diff --git a/spacy/syntax/_state.pxd b/spacy/syntax/_state.pxd index 26991a1fc..2362d925c 100644 --- a/spacy/syntax/_state.pxd +++ b/spacy/syntax/_state.pxd @@ -13,7 +13,6 @@ from ..symbols cimport punct from ..attrs cimport IS_SPACE from ..typedefs cimport attr_t -cdef void _split(StateC* this, int i, int n) nogil cdef inline bint is_space_token(const TokenC* token) nogil: return Lexeme.c_check_flag(token.lex, IS_SPACE) @@ -44,12 +43,14 @@ cdef cppclass StateC: Entity* _ents TokenC _empty_token RingBufferC _hist + int buffer_length + int max_split int length int offset int _s_i int _b_i int _e_i - int _break + int _n_until_break __init__(const TokenC* sent, int length) nogil: cdef int PADDING = 5 @@ -78,7 +79,9 @@ cdef cppclass StateC: this._stack += PADDING this.shifted += PADDING this.length = length - this._break = -1 + this.buffer_length = length + this.max_split = 0 + this._n_until_break = -1 this._s_i = 0 this._b_i = 0 this._e_i = 0 @@ -160,7 +163,9 @@ cdef cppclass StateC: return this._stack[this._s_i - (i+1)] int B(int i) nogil const: - if (i + this._b_i) >= this.length: + if i >= this.buffer_length: + return -1 + if this._n_until_break != -1 and i >= this._n_until_break: return -1 return this._buffer[this._b_i + i] @@ -254,13 +259,13 @@ cdef cppclass StateC: return this._s_i <= 0 bint eol() nogil const: - return this.buffer_length() == 0 + return this.buffer_length == 0 or this.at_break() bint at_break() nogil const: - return this._break != -1 + return this._n_until_break == 0 bint is_final() nogil const: - return this.stack_depth() <= 0 and this._b_i >= this.length + return this.stack_depth() <= 0 and this.buffer_length == 0 bint has_head(int i) nogil const: return this.safe_get(i).head != 0 @@ -282,12 +287,6 @@ cdef cppclass StateC: int stack_depth() nogil const: return this._s_i - int buffer_length() nogil const: - if this._break != -1: - return this._break - this._b_i - else: - return this.length - this._b_i - uint64_t hash() nogil const: cdef TokenC[11] sig sig[0] = this.S_(2)[0] @@ -311,46 +310,62 @@ cdef cppclass StateC: return ring_get(&this._hist, i) void push() nogil: - if this.B(0) != -1: - this._stack[this._s_i] = this.B(0) + if this.buffer_length != 0: + this._stack[this._s_i] = this._buffer[this._b_i] + if this._n_until_break != -1: + this._n_until_break -= 1 this._s_i += 1 this._b_i += 1 + this.buffer_length -= 1 if this.B_(0).sent_start == 1: - this.set_break(this.B(0)) - if this._b_i > this._break: - this._break = -1 + this.set_break(0) void split(int i, int n) nogil: '''Split token i of the buffer into N pieces.''' - # Let's say we've got a length 10 sentence. - # state.split(5, 2) - # Before: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] - # After: [0, 1, 2, 3, 4, 5.0, 5.1, 5.2, 6, 7, 8, 9, 10] - # Sentence grows to length 12. - # Words 6-10 move to positions 8-12 - # Words 0-5 stay where they are. - cdef int PADDING = 5 - cdef int j - # Unwind the padding, so we can work with the original pointer. - this._sent -= PADDING - this._sent = realloc(this._sent, - ((this.length+n+1) + (PADDING * 2)) * sizeof(TokenC)) - for j in range(this.length+PADDING*2, this.length+n+1+PADDING*2): - this._sent[j] = this._empty_token - # Put the start padding back in - this._sent += PADDING - # In our example, we want to move words 6-10 to 8-12. So we must move - # a block of 4 words. - cdef int n_moved = this.length - (i+1) - cdef int move_from = i+1 - cdef int move_to = i+n+1 - memmove(&this._sent[move_to], &this._sent[move_from], - n_moved*sizeof(TokenC)) - # Now copy the token that has been split into its neighbours. - for j in range(i+1, i+n+1): - this._sent[j] = this._sent[i] - # Finally, adjust length. - this.length += n + # Let's say we've got a length 10 sentence. 4 is start of buffer. + # We do: state.split(1, 2) + # + # Old buffer: 4,5,6,7,8,9 + # New buffer: 4,5,13,22,6,7,8,9 + if (this._b_i+5*2) < n: + with gil: + raise NotImplementedError + # Let's say we're at token index 4. this._b_i will be 4, so that we + # point forward into the buffer. To insert, we don't need to reallocate + # -- we have space at the start; we can just shift the tokens between + # where we are at the buffer and where the split starts backwards to + # make room. + # + # For b_i=4, i=1, n=2 we want to have: + # Old buffer: [_, _, _, _, 4, 5, 6, 7, 8, 9] and b_i=4 + # New buffer: [_, _, 4, 5, 13, 22, 6, 7, 8, 9] and b_i=2 + # b_i will always move back by n in total, as that's + # the size of the gap we're creating. + # The number of values we have to copy will be i+1 + # Another way to see it: + # For b_i=4, i=1, n=2 + # buffer[2:4] = buffer[4:6] + # buffer[4:6] = new_tokens + # For b_i=7, i=1, n=1 + # buffer[6:8] = buffer[7:9] + # buffer[8:9] = new_tokens + # For b_i=3, i=1, n=3 + # buffer[0:2] = buffer[3:5] + # buffer[2:5] = new_tokens + # For b_i=5, i=3, n=1 + # buffer[4:8] = buffer[5:9] + # buffer[8:9] = new_tokens + cdef int target = this.B(i) + this._b_i -= n + memmove(&this._buffer[this._b_i], + &this._buffer[this._b_i+n], (i+1)*sizeof(this._buffer[0])) + cdef int subtoken, new_token + for subtoken in range(n): + new_token = (subtoken+1) * this.length + target + this._buffer[this._b_i+(i+1)+subtoken] = new_token + this.buffer_length += n + if this._n_until_break != -1: + this._n_until_break += n void pop() nogil: if this._s_i >= 1: @@ -361,6 +376,9 @@ cdef cppclass StateC: this._buffer[this._b_i] = this.S(0) this._s_i -= 1 this.shifted[this.B(0)] = True + this.buffer_length += 1 + if this._n_until_break != -1: + this._n_until_break += 1 void add_arc(int head, int child, attr_t label) nogil: if this.has_head(child): @@ -424,12 +442,13 @@ cdef cppclass StateC: this._sent[i].ent_type = ent_type void set_break(int i) nogil: - if 0 <= i < this.length: - this._sent[i].sent_start = 1 - this._break = this._b_i + if 0 <= i < this.buffer_length: + this._sent[this.B_(i).l_edge].sent_start = 1 + this._n_until_break = i void clone(const StateC* src) nogil: this.length = src.length + this.buffer_length = src.buffer_length memcpy(this._sent, src._sent, this.length * sizeof(TokenC)) memcpy(this._stack, src._stack, this.length * sizeof(int)) memcpy(this._buffer, src._buffer, this.length * sizeof(int)) @@ -438,7 +457,7 @@ cdef cppclass StateC: this._b_i = src._b_i this._s_i = src._s_i this._e_i = src._e_i - this._break = src._break + this._n_until_break = src._n_until_break this.offset = src.offset this._empty_token = src._empty_token @@ -450,9 +469,9 @@ cdef cppclass StateC: # then make the last space token the head of all others while is_space_token(this.B_(0)) \ - or this.buffer_length() == 0 \ + or this.eol() \ or this.stack_depth() == 0: - if this.buffer_length() == 0: + if this.eol(): # remove the last sentence's root from the stack if this.stack_depth() == 1: this.pop() @@ -463,7 +482,7 @@ cdef cppclass StateC: else: this.unshift() # stack is empty but there is another sentence on the buffer - elif (this.length - this._b_i) >= 1: + elif this.buffer_length != 0: this.push() else: # stack empty and nothing else coming break @@ -483,7 +502,7 @@ cdef cppclass StateC: elif this.stack_depth() == 0: # store all space tokens on the stack until a real token shows up # or the last token on the buffer is reached - while is_space_token(this.B_(0)) and this.buffer_length() > 1: + while is_space_token(this.B_(0)) and this.buffer_length > 1: this.push() # empty the stack by attaching all space tokens to the # first token on the buffer @@ -497,12 +516,12 @@ cdef cppclass StateC: elif this.stack_depth() == 0: # for one token sentences (?) - if this.buffer_length() == 1: + if this.buffer_length == 1: this.push() this.pop() # with an empty stack and a non-empty buffer # only shift is valid anyway - elif (this.length - this._b_i) >= 1: + elif this.buffer_length != 0: this.push() else: # can this even happen? diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index a5af8bb7c..ca144bde2 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -66,7 +66,7 @@ cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) no cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nogil: cdef weight_t cost = 0 cdef int i, B_i - for i in range(stcls.buffer_length()): + for i in range(stcls.c.buffer_length): B_i = stcls.B(i) cost += gold.heads[B_i] == target cost += gold.heads[target] == B_i @@ -118,7 +118,7 @@ cdef class Shift: @staticmethod cdef bint is_valid(const StateC* st, attr_t label) nogil: sent_start = st._sent[st.B_(0).l_edge].sent_start - return st.buffer_length() >= 2 and not st.shifted[st.B(0)] and sent_start != 1 + return st.buffer_length >= 2 and not st.shifted[st.B(0)] and sent_start != 1 @staticmethod cdef int transition(StateC* st, attr_t label) nogil: @@ -137,10 +137,11 @@ cdef class Shift: @staticmethod cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, attr_t label) nogil: - if gold.fused_tokens[s.B(1)] == label: - return 0 - else: - return 1 + return 0 + #if gold.fused_tokens[s.B(1)] == label: + # return 0 + #else: + # return 1 cdef class Reduce: @@ -265,7 +266,7 @@ cdef class Break: @staticmethod cdef int transition(StateC* st, attr_t label) nogil: - st.set_break(st.B_(0).l_edge) + st.set_break(0) st.fast_forward() @staticmethod @@ -278,7 +279,7 @@ cdef class Break: cdef int i, j, S_i, B_i for i in range(s.stack_depth()): S_i = s.S(i) - for j in range(s.buffer_length()): + for j in range(s.c.buffer_length): B_i = s.B(j) cost += gold.heads[S_i] == B_i cost += gold.heads[B_i] == S_i diff --git a/spacy/syntax/stateclass.pxd b/spacy/syntax/stateclass.pxd index 0a9be3b7f..8331e674d 100644 --- a/spacy/syntax/stateclass.pxd +++ b/spacy/syntax/stateclass.pxd @@ -10,6 +10,7 @@ from ..vocab cimport EMPTY_LEXEME from ._state cimport StateC +@cython.final cdef class StateClass: cdef Pool mem cdef StateC* c @@ -105,7 +106,7 @@ cdef class StateClass: return self.c.stack_depth() cdef inline int buffer_length(self) nogil: - return self.c.buffer_length() + return self.c.buffer_length cdef inline void push(self) nogil: self.c.push() diff --git a/spacy/syntax/stateclass.pyx b/spacy/syntax/stateclass.pyx index 7c46150bd..cd0cab2a6 100644 --- a/spacy/syntax/stateclass.pyx +++ b/spacy/syntax/stateclass.pyx @@ -8,13 +8,14 @@ from ..tokens.doc cimport Doc cdef class StateClass: - def __init__(self, Doc doc=None, int offset=0): + def __init__(self, Doc doc=None, int offset=0, int max_split=0): cdef Pool mem = Pool() self.mem = mem self._borrowed = 0 if doc is not None: self.c = new StateC(doc.c, doc.length) self.c.offset = offset + self.c.max_split = max_split def __dealloc__(self): if self._borrowed != 1: @@ -39,6 +40,14 @@ cdef class StateClass: if fast_forward: self.c.fast_forward() + def unshift(self, fast_forward=True): + self.c.unshift() + if fast_forward: + self.c.fast_forward() + + def set_break(self, int i): + self.c.set_break(i) + def split_token(self, int i, int n, fast_forward=True): self.c.split(i, n) if fast_forward: @@ -57,7 +66,7 @@ cdef class StateClass: @property def queue(self): - return {self.B(i) for i in range(self.c.buffer_length())} + return [self.B(i) for i in range(self.c.buffer_length)] @property def token_vector_lenth(self): diff --git a/spacy/tests/parser/test_split_word.py b/spacy/tests/parser/test_split_word.py index 1d74f3692..e44a27fdb 100644 --- a/spacy/tests/parser/test_split_word.py +++ b/spacy/tests/parser/test_split_word.py @@ -32,37 +32,12 @@ def test_pop(): assert state.get_S(0) == 0 -def toy_split(): - def _realloc(data, new_size): - additions = new_size - len(data) - return data + ['']*additions - length = 10 - sent = list(range(length)) - sent = [None]*pad + sent + [None]*pad # pad - ptr = pad - i = 5 - n = 2 - - ptr -= pad - i += pad - sent = _realloc(sent, length+n+(pad*2)) - n_moved = (length + (pad*2)) - i+1 - - - def test_split(): '''state.split_token should take the ith word of the buffer, and split it into n+1 pieces. n is 0-indexed, i.e. split(i, 0) is a noop, and split(i, 1) creates 1 new token.''' doc = get_doc('abcd') - state = StateClass(doc) - assert len(state) == len(doc) - state.split_token(1, 2) - assert len(state) == len(doc)+2 - stdoc = state.get_doc(doc.vocab) - assert stdoc[0].text == 'a' - assert stdoc[1].text == 'b' - assert stdoc[2].text == 'b' - assert stdoc[3].text == 'b' - assert stdoc[4].text == 'c' - assert stdoc[5].text == 'd' + state = StateClass(doc, max_split=3) + assert state.queue == [0, 1, 2, 3] + state.split_token(1, 2, fast_forward=False) + assert state.queue == [0, 1, 1*4+1, 2*4+1, 2, 3] diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index c7eac15c0..592dafb13 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -320,7 +320,7 @@ cdef class Doc: break else: return 1.0 - + if self.vector_norm == 0 or other.vector_norm == 0: return 0.0 return numpy.dot(self.vector, other.vector) / (self.vector_norm * other.vector_norm)