diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx
index 8a638f31d..2daa7986e 100644
--- a/spacy/syntax/arc_eager.pyx
+++ b/spacy/syntax/arc_eager.pyx
@@ -577,15 +577,17 @@ cdef class ArcEager(TransitionSystem):
         raise NotImplementedError
 
     def init_gold_batch(self, examples):
-        examples = [eg for eg in examples if self.has_gold(eg)]
         states = self.init_batch([eg.predicted for eg in examples])
-        keeps = [i for i, s in enumerate(states) if not s.is_final()]
+        keeps = [i for i, (eg, s) in enumerate(zip(examples, states))
+                 if self.has_gold(eg) and not s.is_final()]
         golds = [ArcEagerGold(self, states[i], examples[i]) for i in keeps]
+        # Filter examples with the same indices so they stay aligned with states/golds.
+        examples = [examples[i] for i in keeps]
         states = [states[i] for i in keeps]
         for gold in golds:
             self._replace_unseen_labels(gold)
         n_steps = sum([len(s.queue) * 4 for s in states])
-        return states, golds, n_steps
+        return states, golds, examples, n_steps
 
     def _replace_unseen_labels(self, ArcEagerGold gold):
         backoff_label = self.strings["dep"]
diff --git a/spacy/syntax/ner.pyx b/spacy/syntax/ner.pyx
index 3fc64cb82..ed610bd81 100644
--- a/spacy/syntax/ner.pyx
+++ b/spacy/syntax/ner.pyx
@@ -130,13 +130,15 @@ cdef class BiluoPushDown(TransitionSystem):
         return MOVE_NAMES[move] + '-' + self.strings[label]
 
     def init_gold_batch(self, examples):
-        examples = [eg for eg in examples if self.has_gold(eg)]
         states = self.init_batch([eg.predicted for eg in examples])
-        keeps = [i for i, s in enumerate(states) if not s.is_final()]
+        keeps = [i for i, (s, eg) in enumerate(zip(states, examples))
+                 if not s.is_final() and self.has_gold(eg)]
         golds = [BiluoGold(self, states[i], examples[i]) for i in keeps]
+        # Filter examples with the same indices so they stay aligned with states/golds.
+        examples = [examples[i] for i in keeps]
         states = [states[i] for i in keeps]
         n_steps = sum([len(s.queue) for s in states])
-        return states, golds, n_steps
+        return states, golds, examples, n_steps
 
     cdef Transition lookup_transition(self, object name) except *:
         cdef attr_t label
@@ -262,11 +264,11 @@ cdef class BiluoPushDown(TransitionSystem):
         n_gold = 0
         for i in range(self.n_moves):
             if self.c[i].is_valid(stcls.c, self.c[i].label):
-                is_valid[i] = True
+                is_valid[i] = 1
                 costs[i] = self.c[i].get_cost(stcls, &gold_state, self.c[i].label)
                 n_gold += costs[i] <= 0
             else:
-                is_valid[i] = False
+                is_valid[i] = 0
                 costs[i] = 9000
         if n_gold < 1:
             raise ValueError