Fix misalignment caued by filtering inputs at wrong point in parser

This commit is contained in:
Matthew Honnibal 2017-08-20 15:59:28 -05:00
parent 78a5f842e9
commit 62878e50db

View File

@ -483,8 +483,7 @@ cdef class Parser:
return beams return beams
def update(self, docs_tokvecs, golds, drop=0., sgd=None, losses=None): def update(self, docs_tokvecs, golds, drop=0., sgd=None, losses=None):
docs_tokvecs, golds = self._filter_unlabelled(docs_tokvecs, golds) if not any(self.moves.has_gold(gold) for gold in golds):
if not golds:
return None return None
if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() >= 0.5: if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() >= 0.5:
return self.update_beam(docs_tokvecs, golds, return self.update_beam(docs_tokvecs, golds,
@ -558,7 +557,8 @@ cdef class Parser:
def update_beam(self, docs_tokvecs, golds, width=None, density=None, def update_beam(self, docs_tokvecs, golds, width=None, density=None,
drop=0., sgd=None, losses=None): drop=0., sgd=None, losses=None):
docs_tokvecs, golds = self._filter_unlabelled(docs_tokvecs, golds) if not any(self.moves.has_gold(gold) for gold in golds):
return None
if not golds: if not golds:
return None return None
if width is None: if width is None:
@ -611,15 +611,6 @@ cdef class Parser:
bp_my_tokvecs(d_tokvecs, sgd=sgd) bp_my_tokvecs(d_tokvecs, sgd=sgd)
return d_tokvecs return d_tokvecs
def _filter_unlabelled(self, docs_tokvecs, golds):
'''Remove inputs that have no relevant labels before update'''
has_golds = [self.moves.has_gold(gold) for gold in golds]
docs, tokvecs = docs_tokvecs
docs = [docs[i] for i, has_gold in enumerate(has_golds) if has_gold]
tokvecs = [tokvecs[i] for i, has_gold in enumerate(has_golds) if has_gold]
golds = [golds[i] for i, has_gold in enumerate(has_golds) if has_gold]
return (docs, tokvecs), golds
def _init_gold_batch(self, whole_docs, whole_golds): def _init_gold_batch(self, whole_docs, whole_golds):
"""Make a square batch, of length equal to the shortest doc. A long """Make a square batch, of length equal to the shortest doc. A long
doc will get multiple states. Let's say we have a doc of length 2*N, doc will get multiple states. Let's say we have a doc of length 2*N,