mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Fix multitask objectives
This commit is contained in:
parent
d1246c95fb
commit
8f06903e09
|
@ -681,13 +681,19 @@ class MultitaskObjective(Tagger):
|
|||
return tokvecs, scores
|
||||
|
||||
def get_loss(self, docs, golds, scores):
|
||||
assert len(docs) == len(golds)
|
||||
cdef int idx = 0
|
||||
correct = numpy.zeros((scores.shape[0],), dtype='i')
|
||||
guesses = scores.argmax(axis=1)
|
||||
for gold in golds:
|
||||
for i in range(len(gold.labels)):
|
||||
label = self.make_label(i, gold.words, gold.tags, gold.heads,
|
||||
gold.labels, gold.ents)
|
||||
for i, gold in enumerate(golds):
|
||||
for j in range(len(docs[i])):
|
||||
# Handes alignment for tokenization differences
|
||||
gold_idx = gold.cand_to_gold[j]
|
||||
if gold_idx is None:
|
||||
idx += 1
|
||||
continue
|
||||
label = self.make_label(gold_idx, gold.words, gold.tags,
|
||||
gold.heads, gold.labels, gold.ents)
|
||||
if label is None or label not in self.labels:
|
||||
correct[idx] = guesses[idx]
|
||||
else:
|
||||
|
|
|
@ -542,6 +542,7 @@ cdef class Parser:
|
|||
def update(self, docs, golds, drop=0., sgd=None, losses=None):
|
||||
if not any(self.moves.has_gold(gold) for gold in golds):
|
||||
return None
|
||||
assert len(docs) == len(golds)
|
||||
if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() >= 0.0:
|
||||
return self.update_beam(docs, golds,
|
||||
self.cfg['beam_width'], self.cfg['beam_density'],
|
||||
|
@ -551,6 +552,8 @@ cdef class Parser:
|
|||
if isinstance(docs, Doc) and isinstance(golds, GoldParse):
|
||||
docs = [docs]
|
||||
golds = [golds]
|
||||
for multitask in self._multitasks:
|
||||
multitask.update(docs, golds, drop=drop, sgd=sgd)
|
||||
cuda_stream = util.get_cuda_stream()
|
||||
states, golds, max_steps = self._init_gold_batch(docs, golds)
|
||||
(tokvecs, bp_tokvecs), state2vec, vec2scores = self.get_batch_model(docs, cuda_stream,
|
||||
|
@ -605,9 +608,7 @@ cdef class Parser:
|
|||
break
|
||||
self._make_updates(d_tokvecs,
|
||||
bp_tokvecs, backprops, sgd, cuda_stream)
|
||||
for multitask in self._multitasks:
|
||||
multitask.update(docs, golds, drop=drop, sgd=sgd)
|
||||
|
||||
|
||||
def update_beam(self, docs, golds, width=None, density=None,
|
||||
drop=0., sgd=None, losses=None):
|
||||
if not any(self.moves.has_gold(gold) for gold in golds):
|
||||
|
|
Loading…
Reference in New Issue
Block a user