Fix state array length for split

This commit is contained in:
Matthew Honnibal 2018-04-03 15:43:13 +02:00
parent 7ff4d8967f
commit 6cc79fc244

View File

@ -1,3 +1,4 @@
# cython: infer_types=True
from libc.string cimport memcpy, memset, memmove from libc.string cimport memcpy, memset, memmove
from libc.stdlib cimport malloc, calloc, free, realloc from libc.stdlib cimport malloc, calloc, free, realloc
from libc.stdint cimport uint32_t, uint64_t from libc.stdint cimport uint32_t, uint64_t
@ -13,6 +14,7 @@ from ..symbols cimport punct
from ..attrs cimport IS_SPACE from ..attrs cimport IS_SPACE
from ..typedefs cimport attr_t from ..typedefs cimport attr_t
include "compile_time.pxi"
cdef inline bint is_space_token(const TokenC* token) nogil:
    # Report whether the IS_SPACE flag is set on the lexeme this token
    # points to. The flag lookup itself is delegated to Lexeme.c_check_flag.
    cdef bint has_space_flag = Lexeme.c_check_flag(token.lex, IS_SPACE)
    return has_space_flag
@ -55,12 +57,13 @@ cdef cppclass StateC:
__init__(const TokenC* sent, int length) nogil: __init__(const TokenC* sent, int length) nogil:
cdef int PADDING = 5 cdef int PADDING = 5
this._buffer = <int*>calloc(length + (PADDING * 2), sizeof(int)) cdef int length_with_split = length * MAX_SPLIT
this._stack = <int*>calloc(length + (PADDING * 2), sizeof(int)) this._buffer = <int*>calloc(length_with_split + (PADDING * 2), sizeof(int))
this.was_split = <int*>calloc(length + (PADDING * 2), sizeof(int)) this._stack = <int*>calloc(length_with_split + (PADDING * 2), sizeof(int))
this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint)) this.was_split = <int*>calloc(length_with_split + (PADDING * 2), sizeof(int))
this._sent = <TokenC*>calloc(length + (PADDING * 2), sizeof(TokenC)) this.shifted = <bint*>calloc(length_with_split + (PADDING * 2), sizeof(bint))
this._ents = <Entity*>calloc(length + (PADDING * 2), sizeof(Entity)) this._sent = <TokenC*>calloc(length_with_split + (PADDING * 2), sizeof(TokenC))
this._ents = <Entity*>calloc(length_with_split + (PADDING * 2), sizeof(Entity))
if not (this._buffer and this._stack and this.shifted if not (this._buffer and this._stack and this.shifted
and this._sent and this._ents): and this._sent and this._ents):
with gil: with gil:
@ -69,7 +72,7 @@ cdef cppclass StateC:
memset(&this._hist, 0, sizeof(this._hist)) memset(&this._hist, 0, sizeof(this._hist))
this.offset = 0 this.offset = 0
cdef int i cdef int i
for i in range(length + (PADDING * 2)): for i in range(length_with_split + (PADDING * 2)):
this._ents[i].end = -1 this._ents[i].end = -1
this._sent[i].l_edge = i this._sent[i].l_edge = i
this._sent[i].r_edge = i this._sent[i].r_edge = i
@ -82,7 +85,7 @@ cdef cppclass StateC:
this.shifted += PADDING this.shifted += PADDING
this.length = length this.length = length
this.buffer_length = length this.buffer_length = length
this.max_split = 0 this.max_split = MAX_SPLIT
this._n_until_break = -1 this._n_until_break = -1
this._s_i = 0 this._s_i = 0
this._b_i = 0 this._b_i = 0
@ -94,6 +97,8 @@ cdef cppclass StateC:
for i in range(length): for i in range(length):
this._sent[i] = sent[i] this._sent[i] = sent[i]
this._buffer[i] = i this._buffer[i] = i
for j in range(1, MAX_SPLIT):
this._sent[j*length +i] = sent[i]
for i in range(length, length+PADDING): for i in range(length, length+PADDING):
this._sent[i].lex = &EMPTY_LEXEME this._sent[i].lex = &EMPTY_LEXEME
@ -199,7 +204,14 @@ cdef cppclass StateC:
return True return True
int can_split() nogil const:
    # Return 1 when the first buffer token may be split, otherwise 0.
    #
    # The checks run in this order on purpose: B(0) is only evaluated once
    # we know splitting is enabled and the buffer is non-empty.
    if this.max_split < 2 or this.buffer_length == 0:
        # Either splitting is disabled (fewer than two pieces allowed)
        # or there is no buffer token to split.
        return 0
    if this.was_split[this.B(0)]:
        # A non-zero was_split entry blocks splitting this token.
        # NOTE(review): presumably it marks a token that was already
        # split -- confirm against the code that writes was_split.
        return 0
    return 1
int S(int i) nogil const: int S(int i) nogil const:
if i >= this._s_i: if i >= this._s_i:
@ -406,6 +418,7 @@ cdef cppclass StateC:
# buffer[4:8] = buffer[5:9] # buffer[4:8] = buffer[5:9]
# buffer[8:9] = new_tokens # buffer[8:9] = new_tokens
cdef int target = this.B(i) cdef int target = this.B(i)
this.was_split[target] = n
this._b_i -= n this._b_i -= n
memmove(&this._buffer[this._b_i], memmove(&this._buffer[this._b_i],
&this._buffer[this._b_i+n], (i+1)*sizeof(this._buffer[0])) &this._buffer[this._b_i+n], (i+1)*sizeof(this._buffer[0]))
@ -498,12 +511,13 @@ cdef cppclass StateC:
void clone(const StateC* src) nogil: void clone(const StateC* src) nogil:
this.length = src.length this.length = src.length
cdef int length_with_split = this.length * this.max_split
this.buffer_length = src.buffer_length this.buffer_length = src.buffer_length
memcpy(this._sent, src._sent, this.length * sizeof(TokenC)) memcpy(this._sent, src._sent, length_with_split * sizeof(TokenC))
memcpy(this._stack, src._stack, this.length * sizeof(int)) memcpy(this._stack, src._stack, length_with_split * sizeof(int))
memcpy(this._buffer, src._buffer, this.length * sizeof(int)) memcpy(this._buffer, src._buffer, length_with_split * sizeof(int))
memcpy(this._ents, src._ents, this.length * sizeof(Entity)) memcpy(this._ents, src._ents, length_with_split * sizeof(Entity))
memcpy(this.shifted, src.shifted, this.length * sizeof(this.shifted[0])) memcpy(this.shifted, src.shifted, length_with_split * sizeof(this.shifted[0]))
this._b_i = src._b_i this._b_i = src._b_i
this._s_i = src._s_i this._s_i = src._s_i
this._e_i = src._e_i this._e_i = src._e_i