Make the Tagger neg_prefix configurable (#9802)

This commit is contained in:
Adriane Boyd 2021-12-06 18:04:44 +01:00 committed by GitHub
parent b56b9e7f31
commit 9964243eb2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -45,7 +45,7 @@ DEFAULT_TAGGER_MODEL = Config().from_str(default_model_config)["model"]
@Language.factory( @Language.factory(
"tagger", "tagger",
assigns=["token.tag"], assigns=["token.tag"],
default_config={"model": DEFAULT_TAGGER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.tagger_scorer.v1"}}, default_config={"model": DEFAULT_TAGGER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.tagger_scorer.v1"}, "neg_prefix": "!"},
default_score_weights={"tag_acc": 1.0}, default_score_weights={"tag_acc": 1.0},
) )
def make_tagger( def make_tagger(
@ -54,6 +54,7 @@ def make_tagger(
model: Model, model: Model,
overwrite: bool, overwrite: bool,
scorer: Optional[Callable], scorer: Optional[Callable],
neg_prefix: str,
): ):
"""Construct a part-of-speech tagger component. """Construct a part-of-speech tagger component.
@ -62,7 +63,7 @@ def make_tagger(
in size, and be normalized as probabilities (all scores between 0 and 1, in size, and be normalized as probabilities (all scores between 0 and 1,
with the rows summing to 1). with the rows summing to 1).
""" """
return Tagger(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer) return Tagger(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, neg_prefix=neg_prefix)
def tagger_score(examples, **kwargs): def tagger_score(examples, **kwargs):
@ -87,6 +88,7 @@ class Tagger(TrainablePipe):
*, *,
overwrite=BACKWARD_OVERWRITE, overwrite=BACKWARD_OVERWRITE,
scorer=tagger_score, scorer=tagger_score,
neg_prefix="!",
): ):
"""Initialize a part-of-speech tagger. """Initialize a part-of-speech tagger.
@ -103,7 +105,7 @@ class Tagger(TrainablePipe):
self.model = model self.model = model
self.name = name self.name = name
self._rehearsal_model = None self._rehearsal_model = None
cfg = {"labels": [], "overwrite": overwrite} cfg = {"labels": [], "overwrite": overwrite, "neg_prefix": neg_prefix}
self.cfg = dict(sorted(cfg.items())) self.cfg = dict(sorted(cfg.items()))
self.scorer = scorer self.scorer = scorer
@ -253,7 +255,7 @@ class Tagger(TrainablePipe):
DOCS: https://spacy.io/api/tagger#get_loss DOCS: https://spacy.io/api/tagger#get_loss
""" """
validate_examples(examples, "Tagger.get_loss") validate_examples(examples, "Tagger.get_loss")
loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False, neg_prefix="!") loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False, neg_prefix=self.cfg["neg_prefix"])
# Convert empty tag "" to missing value None so that both misaligned # Convert empty tag "" to missing value None so that both misaligned
# tokens and tokens with missing annotation have the default missing # tokens and tokens with missing annotation have the default missing
# value None. # value None.