From 456c881ae30aa46905962edeb33202ddab01fb45 Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Mon, 25 Jan 2021 14:40:05 +1100
Subject: [PATCH] Try to fix parser training

---
 .../_parser_internals/transition_system.pyx |  2 ++
 spacy/pipeline/transition_parser.pyx        | 15 ++++++++-------
 2 files changed, 10 insertions(+), 7 deletions(-)

diff --git a/spacy/pipeline/_parser_internals/transition_system.pyx b/spacy/pipeline/_parser_internals/transition_system.pyx
index becaedc60..914b4123c 100644
--- a/spacy/pipeline/_parser_internals/transition_system.pyx
+++ b/spacy/pipeline/_parser_internals/transition_system.pyx
@@ -83,6 +83,8 @@ cdef class TransitionSystem:
     def get_oracle_sequence_from_state(self, StateClass state, gold, _debug=None):
         if state.is_final():
             return []
+        if not self.has_gold(eg):
+            return []
         cdef Pool mem = Pool()
         # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
         assert self.n_moves > 0
diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx
index 8b974a486..fbc93a6d3 100644
--- a/spacy/pipeline/transition_parser.pyx
+++ b/spacy/pipeline/transition_parser.pyx
@@ -316,8 +316,9 @@ cdef class Parser(TrainablePipe):
         validate_examples(examples, "Parser.update")
         for multitask in self._multitasks:
             multitask.update(examples, drop=drop, sgd=sgd)
-
-        examples = [eg for eg in examples if self.moves.has_gold(eg)]
+        # We need to take care to act on the whole batch, because we might be
+        # getting vectors via a listener.
+        n_examples = len([eg for eg in examples if self.moves.has_gold(eg)])
         if len(examples) == 0:
             return losses
         set_dropout_rate(self.model, drop)
@@ -347,7 +348,8 @@ cdef class Parser(TrainablePipe):
         states, golds, _ = self.moves.init_gold_batch(examples)
         if not states:
             return losses
-        model, backprop_tok2vec = self.model.begin_update([eg.x for eg in examples])
+        docs = [eg.predicted for eg in examples]
+        model, backprop_tok2vec = self.model.begin_update(docs)
 
         all_states = list(states)
         states_golds = list(zip(states, golds))
@@ -371,7 +373,6 @@ cdef class Parser(TrainablePipe):
         backprop_tok2vec(golds)
         if sgd not in (None, False):
             self.finish_update(sgd)
-        docs = [eg.predicted for eg in examples]
         # If we want to set the annotations based on predictions, it's really
         # hard to avoid parsing the data twice :(.
         # The issue is that we cut up the gold batch into sub-states, and that
@@ -601,7 +602,7 @@ cdef class Parser(TrainablePipe):
         states = []
         golds = []
         for state, eg, history in zip(all_states, examples, oracle_histories):
-            if state.is_final():
+            if not history:
                 continue
             gold = self.moves.init_gold(state, eg)
             if len(history) < max_length:
@@ -609,6 +610,8 @@ cdef class Parser(TrainablePipe):
                 golds.append(gold)
                 continue
             for i in range(0, len(history), max_length):
+                if state.is_final():
+                    break
                 start_state = state.copy()
                 for clas in history[i:i+max_length]:
                     action = self.moves.c[clas]
@@ -618,6 +621,4 @@ cdef class Parser(TrainablePipe):
                 if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
                     states.append(start_state)
                     golds.append(gold)
-                if state.is_final():
-                    break
         return states, golds, max_length
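
Note on the Parser.update hunk: the new comment says the batch may be getting
its vectors via a listener, which is why examples without gold can no longer
be filtered out of `examples` before `begin_update`. Below is a minimal sketch
of that contract in plain Python; ToyListener, update, and has_gold are
illustrative stand-ins, not spaCy's Tok2VecListener API.

    # Sketch only: ToyListener stands in for a tok2vec listener, which hands
    # out one vector row per doc and expects one gradient row back per doc.
    class ToyListener:
        def __init__(self):
            self.batch_size = None

        def begin_update(self, docs):
            self.batch_size = len(docs)
            vectors = [[1.0] for _ in docs]  # pretend tok2vec output

            def backprop(d_vectors):
                # The listener's contract: backward batch == forward batch.
                assert len(d_vectors) == self.batch_size
                return d_vectors

            return vectors, backprop


    def update(examples, listener, has_gold):
        # As in the patch: count the usable examples rather than filtering
        # the batch, so the forward and backward batches stay aligned.
        n_examples = sum(1 for eg in examples if has_gold(eg))
        if n_examples == 0:
            return  # no gold anywhere, so skipping the whole batch is safe
        vectors, backprop = listener.begin_update(examples)
        # Zero gradient for no-gold docs; the rows still have to exist.
        d_vectors = [v if has_gold(eg) else [0.0]
                     for v, eg in zip(vectors, examples)]
        backprop(d_vectors)


    update(["gold", "no-gold"], ToyListener(), lambda eg: eg == "gold")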
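
And for the _init_gold_batch hunks: long oracle histories are cut into
segments of at most `max_length` transitions, and the patch moves the
is-final check so it guards each segment instead of only breaking after a
whole history has been replayed. A self-contained sketch of that cutting
logic follows; ToyState and cut_histories are hypothetical stand-ins for
spaCy's StateClass and Parser._init_gold_batch, not the real API.

    class ToyState:
        # Toy parser state: just counts how many transitions were applied.
        def __init__(self, n_steps, applied=0):
            self.n_steps = n_steps  # transitions until this state is final
            self.applied = applied

        def copy(self):
            return ToyState(self.n_steps, self.applied)

        def apply(self, clas):
            self.applied += 1  # stand-in for action.do(state.c, action.label)

        def is_final(self):
            return self.applied >= self.n_steps


    def cut_histories(states, histories, max_length):
        # Cut long oracle histories into sub-states of <= max_length steps,
        # mirroring the patched control flow in _init_gold_batch.
        segments = []
        for state, history in zip(states, histories):
            if not history:  # patch: skip examples with no oracle moves
                continue
            if len(history) < max_length:
                segments.append((state, history))
                continue
            for i in range(0, len(history), max_length):
                if state.is_final():  # patch: checked before each segment
                    break
                start_state = state.copy()
                chunk = history[i:i + max_length]
                for clas in chunk:
                    state.apply(clas)
                segments.append((start_state, chunk))
        return segments


    states = [ToyState(5), ToyState(0)]
    histories = [[0, 1, 0, 1, 0], []]  # second example has no gold moves
    assert len(cut_histories(states, histories, max_length=2)) == 3

The point of guarding each segment is that replaying a history can finish the
state mid-loop; with the old placement, that was only noticed after a segment
had already been built from a final state.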