diff --git a/spacy/syntax/_state.pxd b/spacy/syntax/_state.pxd index 388c8133e..8b5a071dc 100644 --- a/spacy/syntax/_state.pxd +++ b/spacy/syntax/_state.pxd @@ -1,3 +1,4 @@ +# cython: infer_types=True from libc.string cimport memcpy, memset, memmove from libc.stdlib cimport malloc, calloc, free, realloc from libc.stdint cimport uint32_t, uint64_t @@ -13,6 +14,7 @@ from ..symbols cimport punct from ..attrs cimport IS_SPACE from ..typedefs cimport attr_t +include "compile_time.pxi" cdef inline bint is_space_token(const TokenC* token) nogil: return Lexeme.c_check_flag(token.lex, IS_SPACE) @@ -55,12 +57,13 @@ cdef cppclass StateC: __init__(const TokenC* sent, int length) nogil: cdef int PADDING = 5 - this._buffer = calloc(length + (PADDING * 2), sizeof(int)) - this._stack = calloc(length + (PADDING * 2), sizeof(int)) - this.was_split = calloc(length + (PADDING * 2), sizeof(int)) - this.shifted = calloc(length + (PADDING * 2), sizeof(bint)) - this._sent = calloc(length + (PADDING * 2), sizeof(TokenC)) - this._ents = calloc(length + (PADDING * 2), sizeof(Entity)) + cdef int length_with_split = length * MAX_SPLIT + this._buffer = calloc(length_with_split + (PADDING * 2), sizeof(int)) + this._stack = calloc(length_with_split + (PADDING * 2), sizeof(int)) + this.was_split = calloc(length_with_split + (PADDING * 2), sizeof(int)) + this.shifted = calloc(length_with_split + (PADDING * 2), sizeof(bint)) + this._sent = calloc(length_with_split + (PADDING * 2), sizeof(TokenC)) + this._ents = calloc(length_with_split + (PADDING * 2), sizeof(Entity)) if not (this._buffer and this._stack and this.shifted and this._sent and this._ents): with gil: @@ -69,7 +72,7 @@ cdef cppclass StateC: memset(&this._hist, 0, sizeof(this._hist)) this.offset = 0 cdef int i - for i in range(length + (PADDING * 2)): + for i in range(length_with_split + (PADDING * 2)): this._ents[i].end = -1 this._sent[i].l_edge = i this._sent[i].r_edge = i @@ -82,7 +85,7 @@ cdef cppclass StateC: this.shifted += PADDING this.length = length this.buffer_length = length - this.max_split = 0 + this.max_split = MAX_SPLIT this._n_until_break = -1 this._s_i = 0 this._b_i = 0 @@ -94,6 +97,8 @@ cdef cppclass StateC: for i in range(length): this._sent[i] = sent[i] this._buffer[i] = i + for j in range(1, MAX_SPLIT): + this._sent[j*length +i] = sent[i] for i in range(length, length+PADDING): this._sent[i].lex = &EMPTY_LEXEME @@ -199,7 +204,14 @@ cdef cppclass StateC: return True int can_split() nogil const: - return 0 + if this.max_split < 2: + return 0 + elif this.buffer_length == 0: + return 0 + elif this.was_split[this.B(0)]: + return 0 + else: + return 1 int S(int i) nogil const: if i >= this._s_i: @@ -406,6 +418,7 @@ cdef cppclass StateC: # buffer[4:8] = buffer[5:9] # buffer[8:9] = new_tokens cdef int target = this.B(i) + this.was_split[target] = n this._b_i -= n memmove(&this._buffer[this._b_i], &this._buffer[this._b_i+n], (i+1)*sizeof(this._buffer[0])) @@ -498,12 +511,13 @@ cdef cppclass StateC: void clone(const StateC* src) nogil: this.length = src.length + cdef int length_with_split = this.length * this.max_split this.buffer_length = src.buffer_length - memcpy(this._sent, src._sent, this.length * sizeof(TokenC)) - memcpy(this._stack, src._stack, this.length * sizeof(int)) - memcpy(this._buffer, src._buffer, this.length * sizeof(int)) - memcpy(this._ents, src._ents, this.length * sizeof(Entity)) - memcpy(this.shifted, src.shifted, this.length * sizeof(this.shifted[0])) + memcpy(this._sent, src._sent, length_with_split * sizeof(TokenC)) + memcpy(this._stack, src._stack, length_with_split * sizeof(int)) + memcpy(this._buffer, src._buffer, length_with_split * sizeof(int)) + memcpy(this._ents, src._ents, length_with_split * sizeof(Entity)) + memcpy(this.shifted, src.shifted, length_with_split * sizeof(this.shifted[0])) this._b_i = src._b_i this._s_i = src._s_i this._e_i = src._e_i