mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-14 10:30:34 +03:00
fix tagger
This commit is contained in:
parent
10d396977e
commit
be5934b827
|
@ -343,7 +343,7 @@ class Tagger(Pipe):
|
||||||
|
|
||||||
def get_loss(self, examples, scores):
|
def get_loss(self, examples, scores):
|
||||||
loss_func = SequenceCategoricalCrossentropy(names=self.labels)
|
loss_func = SequenceCategoricalCrossentropy(names=self.labels)
|
||||||
truths = [eg.get_aligned("tag") for eg in examples]
|
truths = [eg.get_aligned("tag", as_string=True) for eg in examples]
|
||||||
d_scores, loss = loss_func(scores, truths)
|
d_scores, loss = loss_func(scores, truths)
|
||||||
if self.model.ops.xp.isnan(loss):
|
if self.model.ops.xp.isnan(loss):
|
||||||
raise ValueError("nan value when computing loss")
|
raise ValueError("nan value when computing loss")
|
||||||
|
@ -679,7 +679,7 @@ class MultitaskObjective(Tagger):
|
||||||
docs = [eg.predicted for eg in examples]
|
docs = [eg.predicted for eg in examples]
|
||||||
for i, eg in enumerate(examples):
|
for i, eg in enumerate(examples):
|
||||||
# Handles alignment for tokenization differences
|
# Handles alignment for tokenization differences
|
||||||
doc_annots = eg.get_aligned()
|
doc_annots = eg.get_aligned() # TODO
|
||||||
for j in range(len(eg.predicted)):
|
for j in range(len(eg.predicted)):
|
||||||
tok_annots = {key: values[j] for key, values in tok_annots.items()}
|
tok_annots = {key: values[j] for key, values in tok_annots.items()}
|
||||||
label = self.make_label(j, tok_annots)
|
label = self.make_label(j, tok_annots)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user