mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-10 15:14:56 +03:00
Fix state array length for split
This commit is contained in:
parent
7ff4d8967f
commit
6cc79fc244
|
@ -1,3 +1,4 @@
|
|||
# cython: infer_types=True
|
||||
from libc.string cimport memcpy, memset, memmove
|
||||
from libc.stdlib cimport malloc, calloc, free, realloc
|
||||
from libc.stdint cimport uint32_t, uint64_t
|
||||
|
@ -13,6 +14,7 @@ from ..symbols cimport punct
|
|||
from ..attrs cimport IS_SPACE
|
||||
from ..typedefs cimport attr_t
|
||||
|
||||
include "compile_time.pxi"
|
||||
|
||||
cdef inline bint is_space_token(const TokenC* token) nogil:
|
||||
return Lexeme.c_check_flag(token.lex, IS_SPACE)
|
||||
|
@ -55,12 +57,13 @@ cdef cppclass StateC:
|
|||
|
||||
__init__(const TokenC* sent, int length) nogil:
|
||||
cdef int PADDING = 5
|
||||
this._buffer = <int*>calloc(length + (PADDING * 2), sizeof(int))
|
||||
this._stack = <int*>calloc(length + (PADDING * 2), sizeof(int))
|
||||
this.was_split = <int*>calloc(length + (PADDING * 2), sizeof(int))
|
||||
this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint))
|
||||
this._sent = <TokenC*>calloc(length + (PADDING * 2), sizeof(TokenC))
|
||||
this._ents = <Entity*>calloc(length + (PADDING * 2), sizeof(Entity))
|
||||
cdef int length_with_split = length * MAX_SPLIT
|
||||
this._buffer = <int*>calloc(length_with_split + (PADDING * 2), sizeof(int))
|
||||
this._stack = <int*>calloc(length_with_split + (PADDING * 2), sizeof(int))
|
||||
this.was_split = <int*>calloc(length_with_split + (PADDING * 2), sizeof(int))
|
||||
this.shifted = <bint*>calloc(length_with_split + (PADDING * 2), sizeof(bint))
|
||||
this._sent = <TokenC*>calloc(length_with_split + (PADDING * 2), sizeof(TokenC))
|
||||
this._ents = <Entity*>calloc(length_with_split + (PADDING * 2), sizeof(Entity))
|
||||
if not (this._buffer and this._stack and this.shifted
|
||||
and this._sent and this._ents):
|
||||
with gil:
|
||||
|
@ -69,7 +72,7 @@ cdef cppclass StateC:
|
|||
memset(&this._hist, 0, sizeof(this._hist))
|
||||
this.offset = 0
|
||||
cdef int i
|
||||
for i in range(length + (PADDING * 2)):
|
||||
for i in range(length_with_split + (PADDING * 2)):
|
||||
this._ents[i].end = -1
|
||||
this._sent[i].l_edge = i
|
||||
this._sent[i].r_edge = i
|
||||
|
@ -82,7 +85,7 @@ cdef cppclass StateC:
|
|||
this.shifted += PADDING
|
||||
this.length = length
|
||||
this.buffer_length = length
|
||||
this.max_split = 0
|
||||
this.max_split = MAX_SPLIT
|
||||
this._n_until_break = -1
|
||||
this._s_i = 0
|
||||
this._b_i = 0
|
||||
|
@ -94,6 +97,8 @@ cdef cppclass StateC:
|
|||
for i in range(length):
|
||||
this._sent[i] = sent[i]
|
||||
this._buffer[i] = i
|
||||
for j in range(1, MAX_SPLIT):
|
||||
this._sent[j*length +i] = sent[i]
|
||||
for i in range(length, length+PADDING):
|
||||
this._sent[i].lex = &EMPTY_LEXEME
|
||||
|
||||
|
@ -199,7 +204,14 @@ cdef cppclass StateC:
|
|||
return True
|
||||
|
||||
int can_split() nogil const:
|
||||
return 0
|
||||
if this.max_split < 2:
|
||||
return 0
|
||||
elif this.buffer_length == 0:
|
||||
return 0
|
||||
elif this.was_split[this.B(0)]:
|
||||
return 0
|
||||
else:
|
||||
return 1
|
||||
|
||||
int S(int i) nogil const:
|
||||
if i >= this._s_i:
|
||||
|
@ -406,6 +418,7 @@ cdef cppclass StateC:
|
|||
# buffer[4:8] = buffer[5:9]
|
||||
# buffer[8:9] = new_tokens
|
||||
cdef int target = this.B(i)
|
||||
this.was_split[target] = n
|
||||
this._b_i -= n
|
||||
memmove(&this._buffer[this._b_i],
|
||||
&this._buffer[this._b_i+n], (i+1)*sizeof(this._buffer[0]))
|
||||
|
@ -498,12 +511,13 @@ cdef cppclass StateC:
|
|||
|
||||
void clone(const StateC* src) nogil:
|
||||
this.length = src.length
|
||||
cdef int length_with_split = this.length * this.max_split
|
||||
this.buffer_length = src.buffer_length
|
||||
memcpy(this._sent, src._sent, this.length * sizeof(TokenC))
|
||||
memcpy(this._stack, src._stack, this.length * sizeof(int))
|
||||
memcpy(this._buffer, src._buffer, this.length * sizeof(int))
|
||||
memcpy(this._ents, src._ents, this.length * sizeof(Entity))
|
||||
memcpy(this.shifted, src.shifted, this.length * sizeof(this.shifted[0]))
|
||||
memcpy(this._sent, src._sent, length_with_split * sizeof(TokenC))
|
||||
memcpy(this._stack, src._stack, length_with_split * sizeof(int))
|
||||
memcpy(this._buffer, src._buffer, length_with_split * sizeof(int))
|
||||
memcpy(this._ents, src._ents, length_with_split * sizeof(Entity))
|
||||
memcpy(this.shifted, src.shifted, length_with_split * sizeof(this.shifted[0]))
|
||||
this._b_i = src._b_i
|
||||
this._s_i = src._s_i
|
||||
this._e_i = src._e_i
|
||||
|
|
Loading…
Reference in New Issue
Block a user