diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index 9949c0ef3..ceaea3c9c 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -265,11 +265,15 @@ cdef class Parser: free(is_valid) def update(self, examples, drop=0., set_annotations=False, sgd=None, losses=None): + cdef StateClass state if losses is None: losses = {} losses.setdefault(self.name, 0.) for multitask in self._multitasks: multitask.update(examples, drop=drop, sgd=sgd) + n_examples = len([eg for eg in examples if self.moves.has_gold(eg)]) + if n_examples == 0: + return losses set_dropout_rate(self.model, drop) # Prepare the stepwise model, and get the callback for finishing the batch model, backprop_tok2vec = self.model.begin_update( @@ -280,10 +284,13 @@ cdef class Parser: cut_size = self.cfg["update_with_oracle_cut_size"] states, golds, max_steps = self._init_gold_batch( examples, - max_length=numpy.random.choice(range(20, cut_size)) + max_length=numpy.random.choice(range(5, cut_size)) ) else: - states, golds, max_steps = self.moves.init_gold_batch(examples) + states, golds, _ = self.moves.init_gold_batch(examples) + max_steps = max([len(eg.x) for eg in examples]) + if not states: + return losses all_states = list(states) states_golds = zip(states, golds) for _ in range(max_steps): @@ -292,6 +299,17 @@ cdef class Parser: states, golds = zip(*states_golds) scores, backprop = model.begin_update(states) d_scores = self.get_batch_loss(states, golds, scores, losses) + if self.cfg["normalize_gradients_with_batch_size"]: + # We have to be very careful how we do this, because of the way we + # cut up the batch. We subdivide long sequences. If we normalize + # naively, we end up normalizing by sequence length, which + # is bad: that would mean that states in long sequences + # consistently get smaller gradients. Imagine if we have two + # sequences, one length 1000, one length 20. If we cut up + # the 1k sequence so that we have a "batch" of 50 subsequences, + # we don't want the gradients to get 50 times smaller! + d_scores /= n_examples + backprop(d_scores) # Follow the predicted action self.transition_states(states, scores) @@ -389,8 +407,6 @@ cdef class Parser: cpu_log_loss(c_d_scores, costs, is_valid, &scores[i, 0], d_scores.shape[1]) c_d_scores += d_scores.shape[1] - if len(states) and self.cfg["normalize_gradients_with_batch_size"]: - d_scores /= len(states) if losses is not None: losses.setdefault(self.name, 0.) losses[self.name] += (d_scores**2).sum() @@ -503,41 +519,49 @@ cdef class Parser: return self def _init_gold_batch(self, examples, min_length=5, max_length=500): - """Make a square batch, of length equal to the shortest doc. A long + """Make a square batch, of length equal to the shortest transition + sequence or a cap. A long doc will get multiple states. Let's say we have a doc of length 2*N, where N is the shortest doc. We'll make two states, one representing long_doc[:N], and another representing long_doc[N:].""" cdef: + StateClass start_state StateClass state Transition action all_states = self.moves.init_batch([eg.predicted for eg in examples]) kept = [] + max_length_seen = 0 for state, eg in zip(all_states, examples): if self.moves.has_gold(eg) and not state.is_final(): gold = self.moves.init_gold(state, eg) - kept.append((eg, state, gold)) - max_length = max(min_length, min(max_length, min([len(eg.x) for eg in examples]))) - max_moves = 0 + oracle_actions = self.moves.get_oracle_sequence_from_state( + state.copy(), gold) + kept.append((eg, state, gold, oracle_actions)) + min_length = min(min_length, len(oracle_actions)) + max_length_seen = max(max_length, len(oracle_actions)) + if not kept: + return [], [], 0 + max_length = max(min_length, min(max_length, max_length_seen)) states = [] golds = [] - for eg, state, gold in kept: - oracle_actions = self.moves.get_oracle_sequence_from_state( - state, gold) - start = 0 - while start < len(eg.predicted): - state = state.copy() + cdef int clas + max_moves = 0 + for eg, state, gold, oracle_actions in kept: + for i in range(0, len(oracle_actions), max_length): + start_state = state.copy() n_moves = 0 - while state.B(0) < start and not state.is_final(): - action = self.moves.c[oracle_actions.pop(0)] + for clas in oracle_actions[i:i+max_length]: + action = self.moves.c[clas] action.do(state.c, action.label) state.c.push_hist(action.clas) n_moves += 1 - has_gold = self.moves.has_gold(eg, start=start, - end=start+max_length) - if not state.is_final() and has_gold: - states.append(state) + if state.is_final(): + break + max_moves = max(max_moves, n_moves) + if self.moves.has_gold(eg, start_state.B(0), state.B(0)): + states.append(start_state) golds.append(gold) max_moves = max(max_moves, n_moves) - start += min(max_length, len(eg.x)-start) - max_moves = max(max_moves, len(oracle_actions)) + if state.is_final(): + break return states, golds, max_moves