diff --git a/spacy/pipeline/_parser_internals/_state.pxd b/spacy/pipeline/_parser_internals/_state.pxd
index a6bf926f9..7f644a151 100644
--- a/spacy/pipeline/_parser_internals/_state.pxd
+++ b/spacy/pipeline/_parser_internals/_state.pxd
@@ -32,6 +32,7 @@ cdef cppclass StateC:
     vector[ArcC] _left_arcs
     vector[ArcC] _right_arcs
     vector[libcpp.bool] _unshiftable
+    vector[int] history
     set[int] _sent_starts
     TokenC _empty_token
     int length
@@ -382,3 +383,4 @@ cdef cppclass StateC:
         this._b_i = src._b_i
         this.offset = src.offset
         this._empty_token = src._empty_token
+        this.history = src.history
diff --git a/spacy/pipeline/_parser_internals/arc_eager.pyx b/spacy/pipeline/_parser_internals/arc_eager.pyx
index 03cb8a4d7..b477891f8 100644
--- a/spacy/pipeline/_parser_internals/arc_eager.pyx
+++ b/spacy/pipeline/_parser_internals/arc_eager.pyx
@@ -844,6 +844,7 @@ cdef class ArcEager(TransitionSystem):
                             state.print_state()
                         )))
                     action.do(state.c, action.label)
+                    state.c.history.push_back(i)
                     break
             else:
                 failed = False
diff --git a/spacy/pipeline/_parser_internals/stateclass.pyx b/spacy/pipeline/_parser_internals/stateclass.pyx
index 4eaddd997..208cf061e 100644
--- a/spacy/pipeline/_parser_internals/stateclass.pyx
+++ b/spacy/pipeline/_parser_internals/stateclass.pyx
@@ -20,6 +20,10 @@ cdef class StateClass:
         if self._borrowed != 1:
             del self.c
 
+    @property
+    def history(self):
+        return list(self.c.history)
+
     @property
     def stack(self):
         return [self.S(i) for i in range(self.c.stack_depth())]
diff --git a/spacy/pipeline/_parser_internals/transition_system.pyx b/spacy/pipeline/_parser_internals/transition_system.pyx
index 5bc92f161..181cffd8d 100644
--- a/spacy/pipeline/_parser_internals/transition_system.pyx
+++ b/spacy/pipeline/_parser_internals/transition_system.pyx
@@ -67,6 +67,7 @@ cdef class TransitionSystem:
         for clas in history:
             action = self.c[clas]
             action.do(state.c, action.label)
+            state.c.history.push_back(clas)
         return state
 
     def get_oracle_sequence(self, Example example, _debug=False):
@@ -110,6 +111,7 @@ cdef class TransitionSystem:
                             "S0 head?", str(state.has_head(state.S(0))),
                         )))
                     action.do(state.c, action.label)
+                    state.c.history.push_back(i)
                     break
             else:
                 if _debug:
@@ -137,6 +139,7 @@ cdef class TransitionSystem:
             raise ValueError(Errors.E170.format(name=name))
         action = self.lookup_transition(name)
         action.do(state.c, action.label)
+        state.c.history.push_back(action.clas)
 
     cdef Transition lookup_transition(self, object name) except *:
         raise NotImplementedError
diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx
index fbc93a6d3..3c5e5e9f9 100644
--- a/spacy/pipeline/transition_parser.pyx
+++ b/spacy/pipeline/transition_parser.pyx
@@ -203,15 +203,21 @@ cdef class Parser(TrainablePipe):
             )
 
     def greedy_parse(self, docs, drop=0.):
-        cdef vector[StateC*] states
-        cdef StateClass state
         set_dropout_rate(self.model, drop)
-        batch = self.moves.init_batch(docs)
         # This is pretty dirty, but the NER can resize itself in init_batch,
         # if labels are missing. We therefore have to check whether we need to
         # expand our model output.
         self._resize()
         model = self.model.predict(docs)
+        batch = self.moves.init_batch(docs)
+        states = self._predict_states(model, batch)
+        model.clear_memory()
+        del model
+        return states
+
+    def _predict_states(self, model, batch):
+        cdef vector[StateC*] states
+        cdef StateClass state
         weights = get_c_weights(model)
         for state in batch:
             if not state.is_final():
@@ -220,8 +226,6 @@ cdef class Parser(TrainablePipe):
         sizes = get_c_sizes(model, states.size())
         with nogil:
             self._parseC(&states[0], weights, sizes)
-        model.clear_memory()
-        del model
         return batch
 
     def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
@@ -306,6 +310,7 @@ cdef class Parser(TrainablePipe):
             else:
                 action = self.moves.c[guess]
                 action.do(states[i], action.label)
+                states[i].history.push_back(guess)
         free(is_valid)
 
     def update(self, examples, *, drop=0., sgd=None, losses=None):
@@ -319,7 +324,7 @@ cdef class Parser(TrainablePipe):
         # We need to take care to act on the whole batch, because we might be
         # getting vectors via a listener.
         n_examples = len([eg for eg in examples if self.moves.has_gold(eg)])
-        if len(examples) == 0:
+        if n_examples == 0:
             return losses
         set_dropout_rate(self.model, drop)
         # The probability we use beam update, instead of falling back to
@@ -333,7 +338,11 @@ cdef class Parser(TrainablePipe):
                 losses=losses,
                 beam_density=self.cfg["beam_density"]
             )
-        oracle_histories = [self.moves.get_oracle_sequence(eg) for eg in examples]
+        model, backprop_tok2vec = self.model.begin_update([eg.x for eg in examples])
+        final_states = self.moves.init_batch([eg.x for eg in examples])
+        self._predict_states(model, final_states)
+        histories = [list(state.history) for state in final_states]
+        #oracle_histories = [self.moves.get_oracle_sequence(eg) for eg in examples]
         max_moves = self.cfg["update_with_oracle_cut_size"]
         if max_moves >= 1:
             # Chop sequences into lengths of this many words, to make the
@@ -341,15 +350,13 @@ cdef class Parser(TrainablePipe):
             max_moves = int(random.uniform(max_moves // 2, max_moves * 2))
             states, golds, _ = self._init_gold_batch(
                 examples,
-                oracle_histories,
+                histories,
                 max_length=max_moves
             )
         else:
             states, golds, _ = self.moves.init_gold_batch(examples)
         if not states:
             return losses
-        docs = [eg.predicted for eg in examples]
-        model, backprop_tok2vec = self.model.begin_update(docs)
 
         all_states = list(states)
         states_golds = list(zip(states, golds))
@@ -373,15 +380,7 @@ cdef class Parser(TrainablePipe):
         backprop_tok2vec(golds)
         if sgd not in (None, False):
             self.finish_update(sgd)
-        # If we want to set the annotations based on predictions, it's really
-        # hard to avoid parsing the data twice :(.
-        # The issue is that we cut up the gold batch into sub-states, and that
-        # means there's no one predicted sequence during the update.
-        gold_states = [
-            self.moves.follow_history(doc, history)
-            for doc, history in zip(docs, oracle_histories)
-        ]
-        self.set_annotations(docs, gold_states)
+        self.set_annotations([eg.x for eg in examples], final_states)
         # Ugh, this is annoying. If we're working on GPU, we want to free the
         # memory ASAP. It seems that Python doesn't necessarily get around to
         # removing these in time if we don't explicitly delete? It's confusing.
@@ -599,6 +598,7 @@ cdef class Parser(TrainablePipe):
             StateClass state
             Transition action
         all_states = self.moves.init_batch([eg.predicted for eg in examples])
+        assert len(all_states) == len(examples) == len(oracle_histories)
         states = []
         golds = []
         for state, eg, history in zip(all_states, examples, oracle_histories):
@@ -616,6 +616,7 @@ cdef class Parser(TrainablePipe):
                 for clas in history[i:i+max_length]:
                     action = self.moves.c[clas]
                     action.do(state.c, action.label)
+                    state.c.history.push_back(clas)
                     if state.is_final():
                         break
                 if self.moves.has_gold(eg, start_state.B(0), state.B(0)):
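
For illustration, the bookkeeping this patch introduces can be sketched in plain Python. This is a minimal sketch with hypothetical stand-in names (State, Transition, apply_and_record, replay_prefix), not spaCy's actual API: wherever a transition's do() is applied, its class id is also appended to the state's history, which is what lets update() chop long predicted histories into training prefixes and reuse the final predicted states for set_annotations() instead of parsing twice.

    # Plain-Python sketch of the pattern in this patch; all names here
    # (Transition, State, apply_and_record, replay_prefix) are hypothetical
    # stand-ins, not spaCy's actual API.
    from dataclasses import dataclass, field
    from typing import Callable, List


    @dataclass
    class Transition:
        clas: int                      # transition class id
        do: Callable[["State"], None]  # applies the transition in place


    @dataclass
    class State:
        buffer: List[int]
        history: List[int] = field(default_factory=list)  # mirrors StateC.history

        def is_final(self) -> bool:
            return not self.buffer


    def apply_and_record(state: State, action: Transition) -> None:
        # Wherever the patch calls action.do(...), it now also pushes the
        # class id onto the state's history; this helper keeps the two in sync.
        action.do(state)
        state.history.append(action.clas)


    def replay_prefix(state: State, moves: List[Transition],
                      history: List[int], max_length: int) -> State:
        # Mirrors the _init_gold_batch loop: follow a recorded (now predicted,
        # rather than oracle) history for up to max_length transitions,
        # recording as we go so the replayed state's history stays consistent.
        for clas in history[:max_length]:
            apply_and_record(state, moves[clas])
            if state.is_final():
                break
        return state

With histories recorded during the single prediction pass, update() can derive histories = [list(state.history) for state in final_states] and later pass those same final_states to set_annotations(), which is what the removed follow_history/gold_states block used to reconstruct by re-parsing. Since the history lives on the state itself, is copied with the rest of the state in _state.pxd, and is exposed through the new StateClass.history property, every code path that advances a state keeps the record consistent.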