mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-10 16:22:29 +03:00
Fix tagger training when some tags are missing
This commit is contained in:
parent
62eec33bc4
commit
164d90878e
|
@ -192,6 +192,9 @@ class Tagger(Pipe):
|
||||||
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
|
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
|
||||||
# Handle cases where there are no tokens in any docs.
|
# Handle cases where there are no tokens in any docs.
|
||||||
return
|
return
|
||||||
|
if not any(eg.reference.is_tagged for eg in examples):
|
||||||
|
# Handle cases where there are no tagged tokens in any docs.
|
||||||
|
return
|
||||||
set_dropout_rate(self.model, drop)
|
set_dropout_rate(self.model, drop)
|
||||||
tag_scores, bp_tag_scores = self.model.begin_update([eg.predicted for eg in examples])
|
tag_scores, bp_tag_scores = self.model.begin_update([eg.predicted for eg in examples])
|
||||||
for sc in tag_scores:
|
for sc in tag_scores:
|
||||||
|
@ -251,7 +254,11 @@ class Tagger(Pipe):
|
||||||
DOCS: https://nightly.spacy.io/api/tagger#get_loss
|
DOCS: https://nightly.spacy.io/api/tagger#get_loss
|
||||||
"""
|
"""
|
||||||
validate_examples(examples, "Tagger.get_loss")
|
validate_examples(examples, "Tagger.get_loss")
|
||||||
loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False)
|
loss_func = SequenceCategoricalCrossentropy(
|
||||||
|
names=self.label,
|
||||||
|
normalize=False,
|
||||||
|
missing_value=""
|
||||||
|
)
|
||||||
truths = [eg.get_aligned("TAG", as_string=True) for eg in examples]
|
truths = [eg.get_aligned("TAG", as_string=True) for eg in examples]
|
||||||
d_scores, loss = loss_func(scores, truths)
|
d_scores, loss = loss_func(scores, truths)
|
||||||
if self.model.ops.xp.isnan(loss):
|
if self.model.ops.xp.isnan(loss):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user