diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index 3d398e6c8..fcc05de3f 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -576,15 +576,20 @@ cdef class ArcEager(TransitionSystem): def is_gold_parse(self, StateClass state, gold): raise NotImplementedError + def init_gold(self, StateClass state, Example example): + gold = ArcEagerGold(self, state, example) + self._replace_unseen_labels(gold) + return gold + def init_gold_batch(self, examples): - states = self.init_batch([eg.predicted for eg in examples]) - keeps = [i for i, (eg, s) in enumerate(zip(examples, states)) - if self.has_gold(eg) and not s.is_final()] - golds = [ArcEagerGold(self, states[i], examples[i]) for i in keeps] - states = [states[i] for i in keeps] - for gold in golds: - self._replace_unseen_labels(gold) - n_steps = sum([len(s.queue) * 4 for s in states]) + all_states = self.init_batch([eg.predicted for eg in examples]) + golds = [] + states = [] + for state, eg in zip(all_states, examples): + if self.has_gold(eg) and not state.is_final(): + golds.append(self.init_gold(state, eg)) + states.append(state) + n_steps = sum([len(s.queue) for s in states]) return states, golds, n_steps def _replace_unseen_labels(self, ArcEagerGold gold): @@ -684,8 +689,12 @@ cdef class ArcEager(TransitionSystem): doc.is_parsed = True set_children_from_heads(doc.c, doc.length) - def has_gold(self, Example eg): - return eg.y.is_parsed + def has_gold(self, Example eg, start=0, end=None): + for word in eg.y[start:end]: + if word.dep != 0: + return True + else: + return False cdef int set_valid(self, int* output, const StateC* st) nogil: cdef bint[N_MOVES] is_valid diff --git a/spacy/syntax/ner.pyx b/spacy/syntax/ner.pyx index 9f686993d..c4125bbdf 100644 --- a/spacy/syntax/ner.pyx +++ b/spacy/syntax/ner.pyx @@ -130,11 +130,13 @@ cdef class BiluoPushDown(TransitionSystem): return MOVE_NAMES[move] + '-' + self.strings[label] def init_gold_batch(self, examples): - states = self.init_batch([eg.predicted for eg in examples]) - keeps = [i for i, (s, eg) in enumerate(zip(states, examples)) - if not s.is_final() and self.has_gold(eg)] - golds = [BiluoGold(self, states[i], examples[i]) for i in keeps] - states = [states[i] for i in keeps] + all_states = self.init_batch([eg.predicted for eg in examples]) + golds = [] + states = [] + for state, eg in zip(all_states, examples): + if self.has_gold(eg) and not state.is_final(): + golds.append(self.init_gold(state, eg)) + states.append(state) n_steps = sum([len(s.queue) for s in states]) return states, golds, n_steps @@ -237,8 +239,15 @@ cdef class BiluoPushDown(TransitionSystem): self.add_action(UNIT, st._sent[i].ent_type) self.add_action(LAST, st._sent[i].ent_type) - def has_gold(self, Example eg): - return eg.y.is_nered + def init_gold(self, StateClass state, Example example): + return BiluoGold(self, state, example) + + def has_gold(self, Example eg, start=0, end=None): + for word in eg.y[start:end]: + if word.ent_iob != 0: + return True + else: + return False def get_cost(self, StateClass stcls, gold, int i): if not isinstance(gold, BiluoGold): diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index 65039a7e5..1f28130fb 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -272,7 +272,7 @@ cdef class Parser: # Prepare the stepwise model, and get the callback for finishing the batch model, backprop_tok2vec = self.model.begin_update( [eg.predicted for eg in examples]) - states, golds, max_steps = self.moves.init_gold_batch(examples) + states, golds, max_steps = self._init_gold_batch(examples) all_states = list(states) states_golds = zip(states, golds) for _ in range(max_steps): @@ -490,3 +490,42 @@ cdef class Parser: except AttributeError: raise ValueError(Errors.E149) return self + + def _init_gold_batch(self, examples, min_length=5, max_length=500): + """Make a square batch, of length equal to the shortest doc. A long + doc will get multiple states. Let's say we have a doc of length 2*N, + where N is the shortest doc. We'll make two states, one representing + long_doc[:N], and another representing long_doc[N:].""" + cdef: + StateClass state + Transition action + all_states = self.moves.init_batch([eg.predicted for eg in examples]) + kept = [] + for state, eg in zip(all_states, examples): + if self.moves.has_gold(eg) and not state.is_final(): + gold = self.moves.init_gold(state, eg) + kept.append((eg, state, gold)) + max_length = max(min_length, min(max_length, min([len(eg.x) for eg in examples]))) + max_moves = 0 + states = [] + golds = [] + for eg, state, gold in kept: + oracle_actions = self.moves.get_oracle_sequence(eg) + start = 0 + while start < len(eg.predicted): + state = state.copy() + n_moves = 0 + while state.B(0) < start and not state.is_final(): + action = self.moves.c[oracle_actions.pop(0)] + action.do(state.c, action.label) + state.c.push_hist(action.clas) + n_moves += 1 + has_gold = self.moves.has_gold(eg, start=start, + end=start+max_length) + if not state.is_final() and has_gold: + states.append(state) + golds.append(gold) + max_moves = max(max_moves, n_moves) + start += min(max_length, len(eg.x)-start) + max_moves = max(max_moves, len(oracle_actions)) + return states, golds, max_moves