Fix state array length for split

This commit is contained in:
Matthew Honnibal 2018-04-03 15:43:13 +02:00
parent 7ff4d8967f
commit 6cc79fc244

View File

@ -1,3 +1,4 @@
# cython: infer_types=True
from libc.string cimport memcpy, memset, memmove
from libc.stdlib cimport malloc, calloc, free, realloc
from libc.stdint cimport uint32_t, uint64_t
@ -13,6 +14,7 @@ from ..symbols cimport punct
from ..attrs cimport IS_SPACE
from ..typedefs cimport attr_t
include "compile_time.pxi"
# Return the result of Lexeme.c_check_flag for the IS_SPACE flag on the
# token's lexeme (token.lex). Declared nogil, so it is callable without the GIL.
cdef inline bint is_space_token(const TokenC* token) nogil:
return Lexeme.c_check_flag(token.lex, IS_SPACE)
@ -55,12 +57,13 @@ cdef cppclass StateC:
__init__(const TokenC* sent, int length) nogil:
cdef int PADDING = 5
this._buffer = <int*>calloc(length + (PADDING * 2), sizeof(int))
this._stack = <int*>calloc(length + (PADDING * 2), sizeof(int))
this.was_split = <int*>calloc(length + (PADDING * 2), sizeof(int))
this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint))
this._sent = <TokenC*>calloc(length + (PADDING * 2), sizeof(TokenC))
this._ents = <Entity*>calloc(length + (PADDING * 2), sizeof(Entity))
cdef int length_with_split = length * MAX_SPLIT
this._buffer = <int*>calloc(length_with_split + (PADDING * 2), sizeof(int))
this._stack = <int*>calloc(length_with_split + (PADDING * 2), sizeof(int))
this.was_split = <int*>calloc(length_with_split + (PADDING * 2), sizeof(int))
this.shifted = <bint*>calloc(length_with_split + (PADDING * 2), sizeof(bint))
this._sent = <TokenC*>calloc(length_with_split + (PADDING * 2), sizeof(TokenC))
this._ents = <Entity*>calloc(length_with_split + (PADDING * 2), sizeof(Entity))
if not (this._buffer and this._stack and this.shifted
and this._sent and this._ents):
with gil:
@ -69,7 +72,7 @@ cdef cppclass StateC:
memset(&this._hist, 0, sizeof(this._hist))
this.offset = 0
cdef int i
for i in range(length + (PADDING * 2)):
for i in range(length_with_split + (PADDING * 2)):
this._ents[i].end = -1
this._sent[i].l_edge = i
this._sent[i].r_edge = i
@ -82,7 +85,7 @@ cdef cppclass StateC:
this.shifted += PADDING
this.length = length
this.buffer_length = length
this.max_split = 0
this.max_split = MAX_SPLIT
this._n_until_break = -1
this._s_i = 0
this._b_i = 0
@ -94,6 +97,8 @@ cdef cppclass StateC:
for i in range(length):
this._sent[i] = sent[i]
this._buffer[i] = i
for j in range(1, MAX_SPLIT):
this._sent[j*length +i] = sent[i]
for i in range(length, length+PADDING):
this._sent[i].lex = &EMPTY_LEXEME
@ -199,7 +204,14 @@ cdef cppclass StateC:
return True
# Report (as 0/1) whether a split may be applied to the token at B(0).
# The if/elif chain yields 1 only when max_split is at least 2, buffer_length
# is non-zero, and was_split[B(0)] is zero; every other case yields 0.
int can_split() nogil const:
# NOTE(review): this listing has lost its +/- diff markers. The bare
# `return 0` directly below looks like the pre-change body that the branches
# after it replace -- confirm against the committed file. Read literally, it
# would make every branch below unreachable.
return 0
# max_split below 2: no split.
if this.max_split < 2:
return 0
# buffer_length of zero: no split. This check runs before B(0) is used as an
# index in the next branch.
elif this.buffer_length == 0:
return 0
# was_split[B(0)] is already non-zero: no split.
elif this.was_split[this.B(0)]:
return 0
else:
return 1
int S(int i) nogil const:
if i >= this._s_i:
@ -406,6 +418,7 @@ cdef cppclass StateC:
# buffer[4:8] = buffer[5:9]
# buffer[8:9] = new_tokens
cdef int target = this.B(i)
this.was_split[target] = n
this._b_i -= n
memmove(&this._buffer[this._b_i],
&this._buffer[this._b_i+n], (i+1)*sizeof(this._buffer[0]))
@ -498,12 +511,13 @@ cdef cppclass StateC:
void clone(const StateC* src) nogil:
this.length = src.length
cdef int length_with_split = this.length * this.max_split
this.buffer_length = src.buffer_length
memcpy(this._sent, src._sent, this.length * sizeof(TokenC))
memcpy(this._stack, src._stack, this.length * sizeof(int))
memcpy(this._buffer, src._buffer, this.length * sizeof(int))
memcpy(this._ents, src._ents, this.length * sizeof(Entity))
memcpy(this.shifted, src.shifted, this.length * sizeof(this.shifted[0]))
memcpy(this._sent, src._sent, length_with_split * sizeof(TokenC))
memcpy(this._stack, src._stack, length_with_split * sizeof(int))
memcpy(this._buffer, src._buffer, length_with_split * sizeof(int))
memcpy(this._ents, src._ents, length_with_split * sizeof(Entity))
memcpy(this.shifted, src.shifted, length_with_split * sizeof(this.shifted[0]))
this._b_i = src._b_i
this._s_i = src._s_i
this._e_i = src._e_i