From cd9194c823c6f594a0f1e8a5a4bd33e16039cbf9 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Fri, 4 Sep 2020 03:52:29 +0200 Subject: [PATCH] Try different oracle cuts --- spacy/pipeline/transition_parser.pyx | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx index 5a6b491e0..d46bbb150 100644 --- a/spacy/pipeline/transition_parser.pyx +++ b/spacy/pipeline/transition_parser.pyx @@ -281,7 +281,7 @@ cdef class Parser(Pipe): # Chop sequences into lengths of this many words, to make the # batch uniform length. max_moves = int(random.uniform(max_moves // 2, max_moves * 2)) - states, golds, _ = self._init_gold_batch( + states, golds, max_moves = self._init_gold_batch( examples, max_length=max_moves ) @@ -304,10 +304,9 @@ cdef class Parser(Pipe): # Follow the predicted action self.transition_states(states, scores) states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()] + n_moves += 1 if max_moves >= 1 and n_moves >= max_moves: break - n_moves += 1 - backprop_tok2vec(golds) if sgd not in (None, False): self.model.finish_update(sgd) @@ -513,21 +512,19 @@ cdef class Parser(Pipe): StateClass state Transition action all_states = self.moves.init_batch([eg.predicted for eg in examples]) - states = [] - golds = [] to_cut = [] for state, eg in zip(all_states, examples): if self.moves.has_gold(eg) and not state.is_final(): gold = self.moves.init_gold(state, eg) - if len(eg.x) < max_length: - states.append(state) - golds.append(gold) - else: - oracle_actions = self.moves.get_oracle_sequence_from_state( - state.copy(), gold) - to_cut.append((eg, state, gold, oracle_actions)) + oracle_actions = self.moves.get_oracle_sequence_from_state( + state.copy(), gold) + to_cut.append((eg, state, gold, oracle_actions)) + states = [] + golds = [] if not to_cut: return states, golds, 0 + lengths = [len(x[-1]) for x in to_cut] + max_length = min(min(lengths), max_length) cdef int clas for eg, state, gold, oracle_actions in to_cut: for i in range(0, len(oracle_actions), max_length):