diff --git a/spacy/pipeline/_parser_internals/_state.pxd b/spacy/pipeline/_parser_internals/_state.pxd
index 0d0dd8c05..d31430124 100644
--- a/spacy/pipeline/_parser_internals/_state.pxd
+++ b/spacy/pipeline/_parser_internals/_state.pxd
@@ -42,6 +42,7 @@ cdef cppclass StateC:
     RingBufferC _hist
     int length
     int offset
+    int n_pushes
     int _s_i
     int _b_i
     int _e_i
@@ -49,6 +50,7 @@ cdef cppclass StateC:

     __init__(const TokenC* sent, int length) nogil:
         cdef int PADDING = 5
+        this.n_pushes = 0
         this._buffer = <int*>calloc(length + (PADDING * 2), sizeof(int))
         this._stack = <int*>calloc(length + (PADDING * 2), sizeof(int))
         this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint))
@@ -335,6 +337,7 @@ cdef cppclass StateC:
             this.set_break(this.B_(0).l_edge)
         if this._b_i > this._break:
             this._break = -1
+        this.n_pushes += 1

     void pop() nogil:
         if this._s_i >= 1:
@@ -351,6 +354,7 @@ cdef cppclass StateC:
             this._buffer[this._b_i] = this.S(0)
         this._s_i -= 1
         this.shifted[this.B(0)] = True
+        this.n_pushes -= 1

     void add_arc(int head, int child, attr_t label) nogil:
         if this.has_head(child):
@@ -431,6 +435,7 @@ cdef cppclass StateC:
         this._break = src._break
         this.offset = src.offset
         this._empty_token = src._empty_token
+        this.n_pushes = src.n_pushes

     void fast_forward() nogil:
         # space token attachement policy:
diff --git a/spacy/pipeline/_parser_internals/stateclass.pyx b/spacy/pipeline/_parser_internals/stateclass.pyx
index 880cf6cc5..d59ade467 100644
--- a/spacy/pipeline/_parser_internals/stateclass.pyx
+++ b/spacy/pipeline/_parser_internals/stateclass.pyx
@@ -36,6 +36,10 @@ cdef class StateClass:
             hist[i] = self.c.get_hist(i+1)
         return hist

+    @property
+    def n_pushes(self):
+        return self.c.n_pushes
+
     def is_final(self):
         return self.c.is_final()

diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx
index 2eadfa6aa..2169b4c17 100644
--- a/spacy/pipeline/transition_parser.pyx
+++ b/spacy/pipeline/transition_parser.pyx
@@ -279,14 +279,14 @@ cdef class Parser(Pipe):
             # Chop sequences into lengths of this many transitions, to make the
             # batch uniform length.
             # We used to randomize this, but it's not clear that actually helps?
-            cut_size = self.cfg["update_with_oracle_cut_size"]
-            states, golds, max_steps = self._init_gold_batch(
+            max_pushes = self.cfg["update_with_oracle_cut_size"]
+            states, golds, _ = self._init_gold_batch(
                 examples,
-                max_length=cut_size
+                max_length=max_pushes
             )
         else:
             states, golds, _ = self.moves.init_gold_batch(examples)
-            max_steps = max([len(eg.x) for eg in examples])
+            max_pushes = max([len(eg.x) for eg in examples])
         if not states:
             return losses
         all_states = list(states)
@@ -302,7 +302,8 @@ cdef class Parser(Pipe):
                 backprop(d_scores)
                 # Follow the predicted action
                 self.transition_states(states, scores)
-                states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()]
+                states_golds = [(s, g) for (s, g) in zip(states, golds)
+                                if s.n_pushes < max_pushes and not s.is_final()]

         backprop_tok2vec(golds)
         if sgd not in (None, False):
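
A minimal plain-Python sketch of the cut-off semantics this patch introduces, using hypothetical stand-ins (ToyState, filter_states) rather than spaCy's real StateC and Parser classes. The idea: n_pushes counts words consumed (push() increments it, unshift() decrements it, clone() copies it), so the filter in Parser.update() drops a state once it has consumed max_pushes words, independent of how many other transitions it has performed:

    class ToyState:
        """Stand-in for StateC: tracks net pushes and a terminal condition."""

        def __init__(self, n_words):
            self.n_words = n_words
            self.n_pushes = 0

        def push(self):
            self.n_pushes += 1

        def unshift(self):
            # Unshift returns a word to the buffer, undoing one push.
            self.n_pushes -= 1

        def is_final(self):
            return self.n_pushes >= self.n_words


    def filter_states(states_golds, max_pushes):
        # Mirrors the updated comprehension in Parser.update(): a state is
        # dropped once it has consumed max_pushes words, even if further
        # transitions would still be valid on it.
        return [(s, g) for (s, g) in states_golds
                if s.n_pushes < max_pushes and not s.is_final()]


    pairs = [(ToyState(5), "gold-a"), (ToyState(3), "gold-b")]
    pairs[0][0].push(); pairs[0][0].push()     # net 2 pushes
    pairs[1][0].push(); pairs[1][0].unshift()  # net 0 pushes
    assert len(filter_states(pairs, max_pushes=2)) == 1  # first pair is cut

Counting net pushes rather than raw transition steps keeps the cut comparable across states whose transition sequences differ in length, which is why max_steps can be replaced by max_pushes in the update loop.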