mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-30 18:03:04 +03:00
Support history features in stateclass
This commit is contained in:
parent
6aa6a5bc25
commit
ee41e4fea7
|
@ -1,4 +1,4 @@
|
||||||
from libc.string cimport memcpy, memset
|
from libc.string cimport memcpy, memset, memmove
|
||||||
from libc.stdlib cimport malloc, calloc, free
|
from libc.stdlib cimport malloc, calloc, free
|
||||||
from libc.stdint cimport uint32_t, uint64_t
|
from libc.stdint cimport uint32_t, uint64_t
|
||||||
|
|
||||||
|
@ -15,6 +15,23 @@ from ..typedefs cimport attr_t
|
||||||
cdef inline bint is_space_token(const TokenC* token) nogil:
|
cdef inline bint is_space_token(const TokenC* token) nogil:
|
||||||
return Lexeme.c_check_flag(token.lex, IS_SPACE)
|
return Lexeme.c_check_flag(token.lex, IS_SPACE)
|
||||||
|
|
||||||
|
cdef struct RingBufferC:
    # Storage for the 8 most recently pushed values.
    int[8] data
    # Push counter. While fewer than 8 values have been pushed it equals the
    # number of values stored; once the buffer is full it cycles through
    # 8..15. In both phases `i % 8` is the slot the next push writes to.
    # An all-zero struct (e.g. after memset) is a valid empty buffer.
    int i
    # Value ring_get() returns when the requested entry does not exist.
    int default


cdef inline int ring_push(RingBufferC* ring, int value) nogil:
    # Append `value` to the buffer, overwriting the oldest entry once 8
    # values are stored. Always returns 0.
    ring.data[ring.i % 8] = value
    ring.i += 1
    # Fold the counter back into 8..15 instead of resetting it to 0: this
    # keeps the write position (i % 8) intact, still records that all 8
    # slots hold live values, and prevents the counter from overflowing.
    if ring.i >= 16:
        ring.i -= 8
    return 0


cdef inline int ring_get(RingBufferC* ring, int i) nogil:
    # Return the i-th most recently pushed value, where i=1 is the latest
    # push and i=8 the oldest one retained. Returns ring.default when `i`
    # is outside 1..8 or when fewer than `i` values have been pushed, so
    # no lookup can read an unwritten slot or index outside `data`.
    if i < 1 or i > 8 or i > ring.i:
        return ring.default
    # ring.i >= i here, so the difference is non-negative before the modulo.
    return ring.data[(ring.i - i) % 8]
|
||||||
|
|
||||||
|
|
||||||
cdef cppclass StateC:
|
cdef cppclass StateC:
|
||||||
int* _stack
|
int* _stack
|
||||||
|
@ -23,6 +40,7 @@ cdef cppclass StateC:
|
||||||
TokenC* _sent
|
TokenC* _sent
|
||||||
Entity* _ents
|
Entity* _ents
|
||||||
TokenC _empty_token
|
TokenC _empty_token
|
||||||
|
RingBufferC _hist
|
||||||
int length
|
int length
|
||||||
int offset
|
int offset
|
||||||
int _s_i
|
int _s_i
|
||||||
|
@ -37,6 +55,7 @@ cdef cppclass StateC:
|
||||||
this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint))
|
this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint))
|
||||||
this._sent = <TokenC*>calloc(length + (PADDING * 2), sizeof(TokenC))
|
this._sent = <TokenC*>calloc(length + (PADDING * 2), sizeof(TokenC))
|
||||||
this._ents = <Entity*>calloc(length + (PADDING * 2), sizeof(Entity))
|
this._ents = <Entity*>calloc(length + (PADDING * 2), sizeof(Entity))
|
||||||
|
memset(&this._hist, 0, sizeof(this._hist))
|
||||||
this.offset = 0
|
this.offset = 0
|
||||||
cdef int i
|
cdef int i
|
||||||
for i in range(length + (PADDING * 2)):
|
for i in range(length + (PADDING * 2)):
|
||||||
|
@ -271,7 +290,14 @@ cdef cppclass StateC:
|
||||||
sig[8] = this.B_(0)[0]
|
sig[8] = this.B_(0)[0]
|
||||||
sig[9] = this.E_(0)[0]
|
sig[9] = this.E_(0)[0]
|
||||||
sig[10] = this.E_(1)[0]
|
sig[10] = this.E_(1)[0]
|
||||||
return hash64(sig, sizeof(sig), this._s_i)
|
return hash64(sig, sizeof(sig), this._s_i) \
|
||||||
|
+ hash64(<void*>&this._hist, sizeof(RingBufferC), 1)
|
||||||
|
|
||||||
|
# Store `act` in this state's history buffer (`_hist`) via ring_push().
void push_hist(int act) nogil:
|
||||||
|
ring_push(&this._hist, act)
|
||||||
|
|
||||||
|
# Return entry `i` of this state's history buffer (`_hist`), as resolved by
# ring_get().
int get_hist(int i) nogil:
|
||||||
|
return ring_get(&this._hist, i)
|
||||||
|
|
||||||
void push() nogil:
|
void push() nogil:
|
||||||
if this.B(0) != -1:
|
if this.B(0) != -1:
|
||||||
|
|
|
@ -4,6 +4,7 @@ from __future__ import unicode_literals
|
||||||
|
|
||||||
from libc.string cimport memcpy, memset
|
from libc.string cimport memcpy, memset
|
||||||
from libc.stdint cimport uint32_t, uint64_t
|
from libc.stdint cimport uint32_t, uint64_t
|
||||||
|
import numpy
|
||||||
|
|
||||||
from ..vocab cimport EMPTY_LEXEME
|
from ..vocab cimport EMPTY_LEXEME
|
||||||
from ..structs cimport Entity
|
from ..structs cimport Entity
|
||||||
|
@ -38,6 +39,13 @@ cdef class StateClass:
|
||||||
def token_vector_lenth(self):
|
def token_vector_lenth(self):
|
||||||
return self.doc.tensor.shape[1]
|
return self.doc.tensor.shape[1]
|
||||||
|
|
||||||
|
@property
def history(self):
    """Return the state's history entries as an 8-element integer array.

    Element ``k`` of the result is ``self.c.get_hist(k + 1)``, so the array
    holds the lookups for positions 1 through 8 in that order.
    """
    entries = [self.c.get_hist(position) for position in range(1, 9)]
    return numpy.array(entries, dtype='i')
|
||||||
|
|
||||||
def is_final(self):
|
def is_final(self):
|
||||||
return self.c.is_final()
|
return self.c.is_final()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user