diff --git a/spacy/pipeline/_parser_internals/_beam_utils.pyx b/spacy/pipeline/_parser_internals/_beam_utils.pyx index ef4165505..fa7df2056 100644 --- a/spacy/pipeline/_parser_internals/_beam_utils.pyx +++ b/spacy/pipeline/_parser_internals/_beam_utils.pyx @@ -193,11 +193,7 @@ def update_beam(TransitionSystem moves, states, golds, model, int width, beam_de for i, (d_scores, bp_scores) in enumerate(zip(states_d_scores, backprops)): loss += (d_scores**2).mean() bp_scores(d_scores) - # Return the predicted sequence for each doc. - predicted_histories = [] - for i in range(len(pbeam)): - predicted_histories.append(pbeam[i].histories[0]) - return predicted_histories, loss + return loss def collect_states(beams, docs): diff --git a/spacy/pipeline/_parser_internals/arc_eager.pyx b/spacy/pipeline/_parser_internals/arc_eager.pyx index 7c3d6d275..069b41170 100644 --- a/spacy/pipeline/_parser_internals/arc_eager.pyx +++ b/spacy/pipeline/_parser_internals/arc_eager.pyx @@ -638,17 +638,16 @@ cdef class ArcEager(TransitionSystem): return gold def init_gold_batch(self, examples): + # TODO: Projectivity? all_states = self.init_batch([eg.predicted for eg in examples]) golds = [] states = [] - docs = [] for state, eg in zip(all_states, examples): if self.has_gold(eg) and not state.is_final(): golds.append(self.init_gold(state, eg)) states.append(state) - docs.append(eg.x) n_steps = sum([len(s.queue) for s in states]) - return states, golds, docs + return states, golds, n_steps def _replace_unseen_labels(self, ArcEagerGold gold): backoff_label = self.strings["dep"] diff --git a/spacy/pipeline/_parser_internals/transition_system.pyx b/spacy/pipeline/_parser_internals/transition_system.pyx index 287513a79..9bb4f7f5f 100644 --- a/spacy/pipeline/_parser_internals/transition_system.pyx +++ b/spacy/pipeline/_parser_internals/transition_system.pyx @@ -120,16 +120,6 @@ cdef class TransitionSystem: raise ValueError(Errors.E024) return history - def follow_history(self, doc, history): - """Get the state that results from following a sequence of actions.""" - cdef int clas - cdef StateClass state - state = self.init_batch([doc])[0] - for clas in history: - action = self.c[clas] - action.do(state.c, action.label) - return state - def apply_transition(self, StateClass state, name): if not self.is_valid(state, name): raise ValueError(Errors.E170.format(name=name)) diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx index b93565178..15b07e9b1 100644 --- a/spacy/pipeline/transition_parser.pyx +++ b/spacy/pipeline/transition_parser.pyx @@ -337,22 +337,21 @@ cdef class Parser(TrainablePipe): # Chop sequences into lengths of this many words, to make the # batch uniform length. max_moves = int(random.uniform(max_moves // 2, max_moves * 2)) - states, golds, max_moves, state2doc = self._init_gold_batch( + states, golds, _ = self._init_gold_batch( examples, max_length=max_moves ) else: - states, golds, state2doc = self.moves.init_gold_batch(examples) + states, golds, _ = self.moves.init_gold_batch(examples) if not states: return losses model, backprop_tok2vec = self.model.begin_update([eg.x for eg in examples]) - histories = [[] for example in examples] all_states = list(states) - states_golds = list(zip(states, golds, state2doc)) + states_golds = list(zip(states, golds)) n_moves = 0 while states_golds: - states, golds, state2doc = zip(*states_golds) + states, golds = zip(*states_golds) scores, backprop = model.begin_update(states) d_scores = self.get_batch_loss(states, golds, scores, losses) # Note that the gradient isn't normalized by the batch size @@ -361,13 +360,8 @@ cdef class Parser(TrainablePipe): # be getting smaller gradients for states in long sequences. backprop(d_scores) # Follow the predicted action - actions = self.transition_states(states, scores) - for i, action in enumerate(actions): - histories[i].append(action) - states_golds = [ - s for s in zip(states, golds, state2doc) - if not s[0].is_final() - ] + self.transition_states(states, scores) + states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()] if max_moves >= 1 and n_moves >= max_moves: break n_moves += 1 @@ -376,11 +370,11 @@ cdef class Parser(TrainablePipe): if sgd not in (None, False): self.finish_update(sgd) docs = [eg.predicted for eg in examples] - states = [ - self.moves.follow_history(doc, history) - for doc, history in zip(docs, histories) - ] - self.set_annotations(docs, self._get_states(docs, states)) + # TODO: Refactor so we don't have to parse twice like this (ugh) + # The issue is that we cut up the gold batch into sub-states, and that + # makes it hard to get the actual predicted transition sequence. + predicted_states = self.predict(docs) + self.set_annotations(docs, predicted_states) # Ugh, this is annoying. If we're working on GPU, we want to free the # memory ASAP. It seems that Python doesn't necessarily get around to # removing these in time if we don't explicitly delete? It's confusing. @@ -441,16 +435,13 @@ cdef class Parser(TrainablePipe): def update_beam(self, examples, *, beam_width, drop=0., sgd=None, losses=None, beam_density=0.0): - if losses is None: - losses = {} - losses.setdefault(self.name, 0.0) - states, golds, docs = self.moves.init_gold_batch(examples) + states, golds, _ = self.moves.init_gold_batch(examples) if not states: return losses # Prepare the stepwise model, and get the callback for finishing the batch model, backprop_tok2vec = self.model.begin_update( [eg.predicted for eg in examples]) - predicted_histories, loss = _beam_utils.update_beam( + loss = _beam_utils.update_beam( self.moves, states, golds, @@ -462,12 +453,6 @@ cdef class Parser(TrainablePipe): backprop_tok2vec(golds) if sgd is not None: self.finish_update(sgd) - states = [ - self.moves.follow_history(doc, history) - for doc, history in zip(docs, predicted_histories) - ] - self.set_annotations(docs, states) - return losses def get_batch_loss(self, states, golds, float[:, ::1] scores, losses): cdef StateClass state @@ -610,24 +595,18 @@ cdef class Parser(TrainablePipe): states = [] golds = [] to_cut = [] - # Return a list indicating the position in the batch that each state - # refers to. This lets us put together the full list of predicted - # histories. - state2doc = [] - doc2i = {eg.x: i for i, eg in enumerate(examples)} for state, eg in zip(all_states, examples): if self.moves.has_gold(eg) and not state.is_final(): gold = self.moves.init_gold(state, eg) if len(eg.x) < max_length: states.append(state) golds.append(gold) - state2doc.append(doc2i[eg.x]) else: oracle_actions = self.moves.get_oracle_sequence_from_state( state.copy(), gold) to_cut.append((eg, state, gold, oracle_actions)) if not to_cut: - return states, golds, 0, state2doc + return states, golds, 0 cdef int clas for eg, state, gold, oracle_actions in to_cut: for i in range(0, len(oracle_actions), max_length): @@ -640,7 +619,6 @@ cdef class Parser(TrainablePipe): if self.moves.has_gold(eg, start_state.B(0), state.B(0)): states.append(start_state) golds.append(gold) - state2doc.append(doc2i[eg.x]) if state.is_final(): break - return states, golds, max_length, state2doc + return states, golds, max_length