diff --git a/spacy/syntax/_state.pxd b/spacy/syntax/_state.pxd index a95a1910f..65ff33def 100644 --- a/spacy/syntax/_state.pxd +++ b/spacy/syntax/_state.pxd @@ -319,6 +319,21 @@ cdef cppclass StateC: if this._b_i > this._break: this._break = -1 + void split(int i, int n) nogil: + '''Split token i of the buffer into N pieces.''' + # Let's say we've got a length 10 sentence. + # state.split(5, 2) + # Before: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + # After: [0, 1, 2, 3, 4, 5.0, 5.1, 5.2, 6, 7, 8, 9, 10] + # Sentence grows to length 12. + this.length += n + this._sent -= PADDING + this._sent = realloc(this.length + (PADDING * 2), sizeof(TokenC)) + this._sent += PADDING + # Words 6-10 move to positions 8-12 + memmove(&this._sent[i+1], &this._sent[i+1+n], (this.length-i)+PADDING*sizeof(TokenC)) + # Words 0-5 stay where they are. + void pop() nogil: if this._s_i >= 1: this._s_i -= 1 diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 28e1a0292..a5af8bb7c 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -122,6 +122,8 @@ cdef class Shift: @staticmethod cdef int transition(StateC* st, attr_t label) nogil: + if label != 0: + st.split(st.B(1), label) st.push() st.fast_forward() @@ -135,7 +137,10 @@ cdef class Shift: @staticmethod cdef inline weight_t label_cost(StateClass s, const GoldParseC* gold, attr_t label) nogil: - return 0 + if gold.fused_tokens[s.B(1)] == label: + return 0 + else: + return 1 cdef class Reduce: