diff --git a/spacy/pipeline/_parser_internals/_state.pxd b/spacy/pipeline/_parser_internals/_state.pxd
index d31430124..0d0dd8c05 100644
--- a/spacy/pipeline/_parser_internals/_state.pxd
+++ b/spacy/pipeline/_parser_internals/_state.pxd
@@ -42,7 +42,6 @@ cdef cppclass StateC:
     RingBufferC _hist
     int length
     int offset
-    int n_pushes
     int _s_i
     int _b_i
     int _e_i
@@ -50,7 +49,6 @@ cdef cppclass StateC:
 
     __init__(const TokenC* sent, int length) nogil:
         cdef int PADDING = 5
-        this.n_pushes = 0
         this._buffer = <int*>calloc(length + (PADDING * 2), sizeof(int))
         this._stack = <int*>calloc(length + (PADDING * 2), sizeof(int))
         this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint))
@@ -337,7 +335,6 @@ cdef cppclass StateC:
             this.set_break(this.B_(0).l_edge)
         if this._b_i > this._break:
             this._break = -1
-        this.n_pushes += 1
 
     void pop() nogil:
         if this._s_i >= 1:
@@ -354,7 +351,6 @@ cdef cppclass StateC:
         this._buffer[this._b_i] = this.S(0)
         this._s_i -= 1
         this.shifted[this.B(0)] = True
-        this.n_pushes -= 1
 
     void add_arc(int head, int child, attr_t label) nogil:
         if this.has_head(child):
@@ -435,7 +431,6 @@ cdef cppclass StateC:
         this._break = src._break
         this.offset = src.offset
         this._empty_token = src._empty_token
-        this.n_pushes = src.n_pushes
 
     void fast_forward() nogil:
         # space token attachement policy:
diff --git a/spacy/pipeline/_parser_internals/stateclass.pyx b/spacy/pipeline/_parser_internals/stateclass.pyx
index d59ade467..880cf6cc5 100644
--- a/spacy/pipeline/_parser_internals/stateclass.pyx
+++ b/spacy/pipeline/_parser_internals/stateclass.pyx
@@ -36,10 +36,6 @@ cdef class StateClass:
             hist[i] = self.c.get_hist(i+1)
         return hist
 
-    @property
-    def n_pushes(self):
-        return self.c.n_pushes
-
     def is_final(self):
         return self.c.is_final()
 
diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx
index 2169b4c17..5a6b491e0 100644
--- a/spacy/pipeline/transition_parser.pyx
+++ b/spacy/pipeline/transition_parser.pyx
@@ -6,6 +6,7 @@ from itertools import islice
 from libcpp.vector cimport vector
 from libc.string cimport memset
 from libc.stdlib cimport calloc, free
+import random
 
 import srsly
 from thinc.api import set_dropout_rate
@@ -275,22 +276,22 @@ cdef class Parser(Pipe):
         # Prepare the stepwise model, and get the callback for finishing the batch
         model, backprop_tok2vec = self.model.begin_update(
             [eg.predicted for eg in examples])
-        if self.cfg["update_with_oracle_cut_size"] >= 1:
-            # Chop sequences into lengths of this many transitions, to make the
+        max_moves = self.cfg["update_with_oracle_cut_size"]
+        if max_moves >= 1:
+            # Chop sequences into lengths of this many words, to make the
             # batch uniform length.
-            # We used to randomize this, but it's not clear that actually helps?
-            max_pushes = self.cfg["update_with_oracle_cut_size"]
+            max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
             states, golds, _ = self._init_gold_batch(
                 examples,
-                max_length=max_pushes
+                max_length=max_moves
             )
         else:
             states, golds, _ = self.moves.init_gold_batch(examples)
-            max_pushes = max([len(eg.x) for eg in examples])
         if not states:
             return losses
         all_states = list(states)
         states_golds = list(zip(states, golds))
+        n_moves = 0
         while states_golds:
             states, golds = zip(*states_golds)
             scores, backprop = model.begin_update(states)
@@ -302,8 +303,10 @@ cdef class Parser(Pipe):
             backprop(d_scores)
             # Follow the predicted action
             self.transition_states(states, scores)
-            states_golds = [(s, g) for (s, g) in zip(states, golds)
-                            if s.n_pushes < max_pushes and not s.is_final()]
+            states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()]
+            if max_moves >= 1 and n_moves >= max_moves:
+                break
+            n_moves += 1
 
         backprop_tok2vec(golds)
         if sgd not in (None, False):
@@ -499,7 +502,7 @@ cdef class Parser(Pipe):
             raise ValueError(Errors.E149) from None
         return self
 
-    def _init_gold_batch(self, examples, min_length=5, max_length=500):
+    def _init_gold_batch(self, examples, max_length):
         """Make a square batch, of length equal to the shortest transition
         sequence or a cap. A long doc will get multiple states. Let's say we
         have a doc of length 2*N, where N is the shortest doc. We'll make
@@ -512,8 +515,7 @@ cdef class Parser(Pipe):
         all_states = self.moves.init_batch([eg.predicted for eg in examples])
         states = []
         golds = []
-        kept = []
-        max_length_seen = 0
+        to_cut = []
         for state, eg in zip(all_states, examples):
             if self.moves.has_gold(eg) and not state.is_final():
                 gold = self.moves.init_gold(state, eg)
@@ -523,30 +525,22 @@ cdef class Parser(Pipe):
                 else:
                     oracle_actions = self.moves.get_oracle_sequence_from_state(
                         state.copy(), gold)
-                    kept.append((eg, state, gold, oracle_actions))
-                    min_length = min(min_length, len(oracle_actions))
-                    max_length_seen = max(max_length, len(oracle_actions))
-        if not kept:
+                    to_cut.append((eg, state, gold, oracle_actions))
+        if not to_cut:
             return states, golds, 0
-        max_length = max(min_length, min(max_length, max_length_seen))
         cdef int clas
-        max_moves = 0
-        for eg, state, gold, oracle_actions in kept:
+        for eg, state, gold, oracle_actions in to_cut:
             for i in range(0, len(oracle_actions), max_length):
                 start_state = state.copy()
-                n_moves = 0
                 for clas in oracle_actions[i:i+max_length]:
                     action = self.moves.c[clas]
                     action.do(state.c, action.label)
                     state.c.push_hist(action.clas)
-                    n_moves += 1
                     if state.is_final():
                         break
-                max_moves = max(max_moves, n_moves)
                 if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
                     states.append(start_state)
                     golds.append(gold)
-                max_moves = max(max_moves, n_moves)
                 if state.is_final():
                     break
-        return states, golds, max_moves
+        return states, golds, max_length
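Note on the behavioural change in Parser.update: the per-state n_pushes counter is gone; instead the oracle cut size is jittered once per batch and long oracle sequences are chopped into windows of that length, with the update loop stopping after that many transitions. A minimal standalone sketch of the cutting logic is below; the helper names (randomized_cut_size, chop_sequence) are illustrative only and not part of spaCy's API.

import random

def randomized_cut_size(cut_size):
    # Same jitter as the patch: draw a per-batch cut length roughly in
    # [cut_size // 2, cut_size * 2), so batches are not always cut at the
    # same sequence positions.
    return int(random.uniform(cut_size // 2, cut_size * 2))

def chop_sequence(oracle_actions, max_length):
    # Same windowing as _init_gold_batch: consecutive slices of at most
    # max_length actions, via range(0, len(oracle_actions), max_length).
    return [oracle_actions[i:i + max_length]
            for i in range(0, len(oracle_actions), max_length)]

if __name__ == "__main__":
    random.seed(0)
    cut = randomized_cut_size(100)      # somewhere around 50..200
    actions = list(range(230))          # stand-in for a long oracle sequence
    print(cut, [len(piece) for piece in chop_sequence(actions, cut)])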