diff --git a/spacy/pipeline/_parser_internals/transition_system.pyx b/spacy/pipeline/_parser_internals/transition_system.pyx
index 9bb4f7f5f..61c4544e1 100644
--- a/spacy/pipeline/_parser_internals/transition_system.pyx
+++ b/spacy/pipeline/_parser_internals/transition_system.pyx
@@ -61,6 +61,14 @@ cdef class TransitionSystem:
             offset += len(doc)
         return states

+    def follow_history(self, doc, history):
+        cdef int clas
+        cdef StateClass state = StateClass(doc)
+        for clas in history:
+            action = self.c[clas]
+            action.do(state.c, action.label)
+        return state
+
     def get_oracle_sequence(self, Example example, _debug=False):
         states, golds, _ = self.init_gold_batch([example])
         if not states:
diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx
index 15b07e9b1..8b974a486 100644
--- a/spacy/pipeline/transition_parser.pyx
+++ b/spacy/pipeline/transition_parser.pyx
@@ -317,8 +317,8 @@ cdef class Parser(TrainablePipe):
         for multitask in self._multitasks:
             multitask.update(examples, drop=drop, sgd=sgd)

-        n_examples = len([eg for eg in examples if self.moves.has_gold(eg)])
-        if n_examples == 0:
+        examples = [eg for eg in examples if self.moves.has_gold(eg)]
+        if len(examples) == 0:
             return losses
         set_dropout_rate(self.model, drop)
         # The probability we use beam update, instead of falling back to
@@ -332,6 +332,7 @@ cdef class Parser(TrainablePipe):
                 losses=losses,
                 beam_density=self.cfg["beam_density"]
             )
+        oracle_histories = [self.moves.get_oracle_sequence(eg) for eg in examples]
         max_moves = self.cfg["update_with_oracle_cut_size"]
         if max_moves >= 1:
             # Chop sequences into lengths of this many words, to make the
@@ -339,6 +340,7 @@ cdef class Parser(TrainablePipe):
             max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
             states, golds, _ = self._init_gold_batch(
                 examples,
+                oracle_histories,
                 max_length=max_moves
             )
         else:
@@ -370,11 +372,15 @@ cdef class Parser(TrainablePipe):
         if sgd not in (None, False):
             self.finish_update(sgd)
         docs = [eg.predicted for eg in examples]
-        # TODO: Refactor so we don't have to parse twice like this (ugh)
+        # If we want to set the annotations based on predictions, it's really
+        # hard to avoid parsing the data twice :(.
         # The issue is that we cut up the gold batch into sub-states, and that
-        # makes it hard to get the actual predicted transition sequence.
-        predicted_states = self.predict(docs)
-        self.set_annotations(docs, predicted_states)
+        # means there's no one predicted sequence during the update.
+        gold_states = [
+            self.moves.follow_history(doc, history)
+            for doc, history in zip(docs, oracle_histories)
+        ]
+        self.set_annotations(docs, gold_states)
         # Ugh, this is annoying. If we're working on GPU, we want to free the
         # memory ASAP. It seems that Python doesn't necessarily get around to
         # removing these in time if we don't explicitly delete? It's confusing.
@@ -581,7 +587,7 @@ cdef class Parser(TrainablePipe):
             raise ValueError(Errors.E149) from None
         return self

-    def _init_gold_batch(self, examples, max_length):
+    def _init_gold_batch(self, examples, oracle_histories, max_length):
         """Make a square batch, of length equal to the shortest transition
         sequence or a cap. A long doc will get multiple states.
         Let's say we have a doc of length 2*N,
@@ -594,24 +600,17 @@ cdef class Parser(TrainablePipe):
         all_states = self.moves.init_batch([eg.predicted for eg in examples])
         states = []
         golds = []
-        to_cut = []
-        for state, eg in zip(all_states, examples):
-            if self.moves.has_gold(eg) and not state.is_final():
-                gold = self.moves.init_gold(state, eg)
-                if len(eg.x) < max_length:
-                    states.append(state)
-                    golds.append(gold)
-                else:
-                    oracle_actions = self.moves.get_oracle_sequence_from_state(
-                        state.copy(), gold)
-                    to_cut.append((eg, state, gold, oracle_actions))
-        if not to_cut:
-            return states, golds, 0
-        cdef int clas
-        for eg, state, gold, oracle_actions in to_cut:
-            for i in range(0, len(oracle_actions), max_length):
+        for state, eg, history in zip(all_states, examples, oracle_histories):
+            if state.is_final():
+                continue
+            gold = self.moves.init_gold(state, eg)
+            if len(history) < max_length:
+                states.append(state)
+                golds.append(gold)
+                continue
+            for i in range(0, len(history), max_length):
                 start_state = state.copy()
-                for clas in oracle_actions[i:i+max_length]:
+                for clas in history[i:i+max_length]:
                     action = self.moves.c[clas]
                     action.do(state.c, action.label)
                     if state.is_final():
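
Note on the data flow above (commentary, not part of the diff): `update` now computes one oracle history per example up front; `_init_gold_batch` chops each history into `max_length`-sized segments, yielding several training states per long doc; and `follow_history` later replays a full history so `set_annotations` can use gold states instead of parsing the docs a second time. Below is a minimal pure-Python sketch of just the chopping arithmetic, with `history` standing in for an oracle action sequence (the real code additionally replays each slice through `action.do` on a Cython StateClass; `chop` is a hypothetical helper, not a function in the diff):

    def chop(history, max_length):
        # One sub-state per max_length-sized slice of the oracle history,
        # mirroring the `for i in range(0, len(history), max_length)` loop.
        for i in range(0, len(history), max_length):
            yield i, history[i:i + max_length]

    # A 10-action history with a cut size of 4 yields three segments, i.e.
    # three training states for a single document.
    for start, segment in chop(list(range(10)), max_length=4):
        print(start, segment)
    # 0 [0, 1, 2, 3]
    # 4 [4, 5, 6, 7]
    # 8 [8, 9]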