mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
Make the Tagger neg_prefix configurable (#9802)
This commit is contained in:
parent
b56b9e7f31
commit
9964243eb2
|
@ -45,7 +45,7 @@ DEFAULT_TAGGER_MODEL = Config().from_str(default_model_config)["model"]
|
|||
@Language.factory(
|
||||
"tagger",
|
||||
assigns=["token.tag"],
|
||||
default_config={"model": DEFAULT_TAGGER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.tagger_scorer.v1"}},
|
||||
default_config={"model": DEFAULT_TAGGER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.tagger_scorer.v1"}, "neg_prefix": "!"},
|
||||
default_score_weights={"tag_acc": 1.0},
|
||||
)
|
||||
def make_tagger(
|
||||
|
@ -54,6 +54,7 @@ def make_tagger(
|
|||
model: Model,
|
||||
overwrite: bool,
|
||||
scorer: Optional[Callable],
|
||||
neg_prefix: str,
|
||||
):
|
||||
"""Construct a part-of-speech tagger component.
|
||||
|
||||
|
@ -62,7 +63,7 @@ def make_tagger(
|
|||
in size, and be normalized as probabilities (all scores between 0 and 1,
|
||||
with the rows summing to 1).
|
||||
"""
|
||||
return Tagger(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer)
|
||||
return Tagger(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, neg_prefix=neg_prefix)
|
||||
|
||||
|
||||
def tagger_score(examples, **kwargs):
|
||||
|
@ -87,6 +88,7 @@ class Tagger(TrainablePipe):
|
|||
*,
|
||||
overwrite=BACKWARD_OVERWRITE,
|
||||
scorer=tagger_score,
|
||||
neg_prefix="!",
|
||||
):
|
||||
"""Initialize a part-of-speech tagger.
|
||||
|
||||
|
@ -103,7 +105,7 @@ class Tagger(TrainablePipe):
|
|||
self.model = model
|
||||
self.name = name
|
||||
self._rehearsal_model = None
|
||||
cfg = {"labels": [], "overwrite": overwrite}
|
||||
cfg = {"labels": [], "overwrite": overwrite, "neg_prefix": neg_prefix}
|
||||
self.cfg = dict(sorted(cfg.items()))
|
||||
self.scorer = scorer
|
||||
|
||||
|
@ -253,7 +255,7 @@ class Tagger(TrainablePipe):
|
|||
DOCS: https://spacy.io/api/tagger#get_loss
|
||||
"""
|
||||
validate_examples(examples, "Tagger.get_loss")
|
||||
loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False, neg_prefix="!")
|
||||
loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False, neg_prefix=self.cfg["neg_prefix"])
|
||||
# Convert empty tag "" to missing value None so that both misaligned
|
||||
# tokens and tokens with missing annotation have the default missing
|
||||
# value None.
|
||||
|
|
Loading…
Reference in New Issue
Block a user