mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-30 01:43:21 +03:00
Ensure updates aren't made if no gold labels are available
This commit is contained in:
parent
42fa84075f
commit
84b7ed49e4
|
@ -406,11 +406,11 @@ cdef class GoldParse:
|
||||||
if tags is None:
|
if tags is None:
|
||||||
tags = [None for _ in doc]
|
tags = [None for _ in doc]
|
||||||
if heads is None:
|
if heads is None:
|
||||||
heads = [token.i for token in doc]
|
heads = [None for token in doc]
|
||||||
if deps is None:
|
if deps is None:
|
||||||
deps = [None for _ in doc]
|
deps = [None for _ in doc]
|
||||||
if entities is None:
|
if entities is None:
|
||||||
entities = ['-' for _ in doc]
|
entities = [None for _ in doc]
|
||||||
elif len(entities) == 0:
|
elif len(entities) == 0:
|
||||||
entities = ['O' for _ in doc]
|
entities = ['O' for _ in doc]
|
||||||
elif not isinstance(entities[0], basestring):
|
elif not isinstance(entities[0], basestring):
|
||||||
|
|
|
@ -113,7 +113,7 @@ cdef class BiluoPushDown(TransitionSystem):
|
||||||
|
|
||||||
def has_gold(self, GoldParse gold, start=0, end=None):
|
def has_gold(self, GoldParse gold, start=0, end=None):
|
||||||
end = end or len(gold.ner)
|
end = end or len(gold.ner)
|
||||||
if all([tag == '-' for tag in gold.ner[start:end]]):
|
if all([tag in ('-', None) for tag in gold.ner[start:end]]):
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
|
|
|
@ -483,6 +483,9 @@ cdef class Parser:
|
||||||
return beams
|
return beams
|
||||||
|
|
||||||
def update(self, docs_tokvecs, golds, drop=0., sgd=None, losses=None):
|
def update(self, docs_tokvecs, golds, drop=0., sgd=None, losses=None):
|
||||||
|
docs_tokvecs, golds = self._filter_unlabelled(docs_tokvecs, golds)
|
||||||
|
if not golds:
|
||||||
|
return None
|
||||||
if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() >= 0.5:
|
if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() >= 0.5:
|
||||||
return self.update_beam(docs_tokvecs, golds,
|
return self.update_beam(docs_tokvecs, golds,
|
||||||
self.cfg['beam_width'], self.cfg['beam_density'],
|
self.cfg['beam_width'], self.cfg['beam_density'],
|
||||||
|
@ -555,6 +558,9 @@ cdef class Parser:
|
||||||
|
|
||||||
def update_beam(self, docs_tokvecs, golds, width=None, density=None,
|
def update_beam(self, docs_tokvecs, golds, width=None, density=None,
|
||||||
drop=0., sgd=None, losses=None):
|
drop=0., sgd=None, losses=None):
|
||||||
|
docs_tokvecs, golds = self._filter_unlabelled(docs_tokvecs, golds)
|
||||||
|
if not golds:
|
||||||
|
return None
|
||||||
if width is None:
|
if width is None:
|
||||||
width = self.cfg.get('beam_width', 2)
|
width = self.cfg.get('beam_width', 2)
|
||||||
if density is None:
|
if density is None:
|
||||||
|
@ -605,6 +611,15 @@ cdef class Parser:
|
||||||
bp_my_tokvecs(d_tokvecs, sgd=sgd)
|
bp_my_tokvecs(d_tokvecs, sgd=sgd)
|
||||||
return d_tokvecs
|
return d_tokvecs
|
||||||
|
|
||||||
|
def _filter_unlabelled(self, docs_tokvecs, golds):
|
||||||
|
'''Remove inputs that have no relevant labels before update'''
|
||||||
|
has_golds = [self.moves.has_gold(gold) for gold in golds]
|
||||||
|
docs, tokvecs = docs_tokvecs
|
||||||
|
docs = [docs[i] for i, has_gold in enumerate(has_golds) if has_gold]
|
||||||
|
tokvecs = [tokvecs[i] for i, has_gold in enumerate(has_golds) if has_gold]
|
||||||
|
golds = [golds[i] for i, has_gold in enumerate(has_golds) if has_gold]
|
||||||
|
return (docs, tokvecs), golds
|
||||||
|
|
||||||
def _init_gold_batch(self, whole_docs, whole_golds):
|
def _init_gold_batch(self, whole_docs, whole_golds):
|
||||||
"""Make a square batch, of length equal to the shortest doc. A long
|
"""Make a square batch, of length equal to the shortest doc. A long
|
||||||
doc will get multiple states. Let's say we have a doc of length 2*N,
|
doc will get multiple states. Let's say we have a doc of length 2*N,
|
||||||
|
|
Loading…
Reference in New Issue
Block a user