mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-10 08:12:24 +03:00
Fix tagger training when some tags are missing
This commit is contained in:
parent
62eec33bc4
commit
164d90878e
|
@ -192,6 +192,9 @@ class Tagger(Pipe):
|
|||
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
|
||||
# Handle cases where there are no tokens in any docs.
|
||||
return
|
||||
if not any(eg.reference.is_tagged for eg in examples):
|
||||
# Handle cases where there are no tagged tokens in any docs.
|
||||
return
|
||||
set_dropout_rate(self.model, drop)
|
||||
tag_scores, bp_tag_scores = self.model.begin_update([eg.predicted for eg in examples])
|
||||
for sc in tag_scores:
|
||||
|
@ -251,7 +254,11 @@ class Tagger(Pipe):
|
|||
DOCS: https://nightly.spacy.io/api/tagger#get_loss
|
||||
"""
|
||||
validate_examples(examples, "Tagger.get_loss")
|
||||
loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False)
|
||||
loss_func = SequenceCategoricalCrossentropy(
|
||||
names=self.label,
|
||||
normalize=False,
|
||||
missing_value=""
|
||||
)
|
||||
truths = [eg.get_aligned("TAG", as_string=True) for eg in examples]
|
||||
d_scores, loss = loss_func(scores, truths)
|
||||
if self.model.ops.xp.isnan(loss):
|
||||
|
|
Loading…
Reference in New Issue
Block a user