diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index f1a0bc91c..24e5841a5 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -483,8 +483,7 @@ cdef class Parser: return beams def update(self, docs_tokvecs, golds, drop=0., sgd=None, losses=None): - docs_tokvecs, golds = self._filter_unlabelled(docs_tokvecs, golds) - if not golds: + if not any(self.moves.has_gold(gold) for gold in golds): return None if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() >= 0.5: return self.update_beam(docs_tokvecs, golds, @@ -558,7 +557,8 @@ cdef class Parser: def update_beam(self, docs_tokvecs, golds, width=None, density=None, drop=0., sgd=None, losses=None): - docs_tokvecs, golds = self._filter_unlabelled(docs_tokvecs, golds) + if not any(self.moves.has_gold(gold) for gold in golds): + return None if not golds: return None if width is None: @@ -611,15 +611,6 @@ cdef class Parser: bp_my_tokvecs(d_tokvecs, sgd=sgd) return d_tokvecs - def _filter_unlabelled(self, docs_tokvecs, golds): - '''Remove inputs that have no relevant labels before update''' - has_golds = [self.moves.has_gold(gold) for gold in golds] - docs, tokvecs = docs_tokvecs - docs = [docs[i] for i, has_gold in enumerate(has_golds) if has_gold] - tokvecs = [tokvecs[i] for i, has_gold in enumerate(has_golds) if has_gold] - golds = [golds[i] for i, has_gold in enumerate(has_golds) if has_gold] - return (docs, tokvecs), golds - def _init_gold_batch(self, whole_docs, whole_golds): """Make a square batch, of length equal to the shortest doc. A long doc will get multiple states. Let's say we have a doc of length 2*N,