Tagger label smoothing (#12293)

* add label smoothing

* use True/False instead of floats

* add entropy to debug data

* formatting

* docs

* change test to check difference in distributions

* Update website/docs/api/tagger.mdx

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>

* Update spacy/pipeline/tagger.pyx

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>

* bool -> float

* update docs

* fix seed

* black

* update tests to use label_smoothing = 0.0

* set default to 0.0, update quickstart

* Update spacy/pipeline/tagger.pyx

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>

* update morphologizer, tagger test

* fix morph docs

* add url to docs

---------

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
Vinit Ravishankar 2023-03-22 12:17:56 +01:00 committed by GitHub
parent b479f8bfa5
commit 28de85737f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 87 additions and 23 deletions

View File

@ -7,6 +7,7 @@ import srsly
from wasabi import Printer, MESSAGES, msg
import typer
import math
import numpy
from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
from ._util import import_code, debug_cli, _format_number
@ -521,9 +522,13 @@ def debug_data(
if "tagger" in factory_names:
msg.divider("Part-of-speech Tagging")
label_list = [label for label in gold_train_data["tags"]]
model_labels = _get_labels_from_model(nlp, "tagger")
label_list, counts = zip(*gold_train_data["tags"].items())
msg.info(f"{len(label_list)} label(s) in train data")
p = numpy.array(counts)
p = p / p.sum()
norm_entropy = (-p * numpy.log2(p)).sum() / numpy.log2(len(label_list))
msg.info(f"{norm_entropy} is the normalised label entropy")
model_labels = _get_labels_from_model(nlp, "tagger")
labels = set(label_list)
missing_labels = model_labels - labels
if missing_labels:

View File

@ -331,6 +331,7 @@ maxout_pieces = 3
{% if "morphologizer" in components %}
[components.morphologizer]
factory = "morphologizer"
label_smoothing = 0.05
[components.morphologizer.model]
@architectures = "spacy.Tagger.v2"
@ -344,6 +345,7 @@ width = ${components.tok2vec.model.encode.width}
{% if "tagger" in components %}
[components.tagger]
factory = "tagger"
label_smoothing = 0.05
[components.tagger.model]
@architectures = "spacy.Tagger.v2"

View File

@ -52,7 +52,8 @@ DEFAULT_MORPH_MODEL = Config().from_str(default_model_config)["model"]
@Language.factory(
"morphologizer",
assigns=["token.morph", "token.pos"],
default_config={"model": DEFAULT_MORPH_MODEL, "overwrite": True, "extend": False, "scorer": {"@scorers": "spacy.morphologizer_scorer.v1"}},
default_config={"model": DEFAULT_MORPH_MODEL, "overwrite": True, "extend": False,
"scorer": {"@scorers": "spacy.morphologizer_scorer.v1"}, "label_smoothing": 0.0},
default_score_weights={"pos_acc": 0.5, "morph_acc": 0.5, "morph_per_feat": None},
)
def make_morphologizer(
@ -61,9 +62,10 @@ def make_morphologizer(
name: str,
overwrite: bool,
extend: bool,
label_smoothing: float,
scorer: Optional[Callable],
):
return Morphologizer(nlp.vocab, model, name, overwrite=overwrite, extend=extend, scorer=scorer)
return Morphologizer(nlp.vocab, model, name, overwrite=overwrite, extend=extend, label_smoothing=label_smoothing, scorer=scorer)
def morphologizer_score(examples, **kwargs):
@ -94,6 +96,7 @@ class Morphologizer(Tagger):
*,
overwrite: bool = BACKWARD_OVERWRITE,
extend: bool = BACKWARD_EXTEND,
label_smoothing: float = 0.0,
scorer: Optional[Callable] = morphologizer_score,
):
"""Initialize a morphologizer.
@ -121,6 +124,7 @@ class Morphologizer(Tagger):
"labels_pos": {},
"overwrite": overwrite,
"extend": extend,
"label_smoothing": label_smoothing,
}
self.cfg = dict(sorted(cfg.items()))
self.scorer = scorer
@ -270,7 +274,8 @@ class Morphologizer(Tagger):
DOCS: https://spacy.io/api/morphologizer#get_loss
"""
validate_examples(examples, "Morphologizer.get_loss")
loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False)
loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False,
label_smoothing=self.cfg["label_smoothing"])
truths = []
for eg in examples:
eg_truths = []

View File

@ -45,7 +45,7 @@ DEFAULT_TAGGER_MODEL = Config().from_str(default_model_config)["model"]
@Language.factory(
"tagger",
assigns=["token.tag"],
default_config={"model": DEFAULT_TAGGER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.tagger_scorer.v1"}, "neg_prefix": "!"},
default_config={"model": DEFAULT_TAGGER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.tagger_scorer.v1"}, "neg_prefix": "!", "label_smoothing": 0.0},
default_score_weights={"tag_acc": 1.0},
)
def make_tagger(
@ -55,6 +55,7 @@ def make_tagger(
overwrite: bool,
scorer: Optional[Callable],
neg_prefix: str,
label_smoothing: float,
):
"""Construct a part-of-speech tagger component.
@ -63,7 +64,7 @@ def make_tagger(
in size, and be normalized as probabilities (all scores between 0 and 1,
with the rows summing to 1).
"""
return Tagger(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, neg_prefix=neg_prefix)
return Tagger(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, neg_prefix=neg_prefix, label_smoothing=label_smoothing)
def tagger_score(examples, **kwargs):
@ -89,6 +90,7 @@ class Tagger(TrainablePipe):
overwrite=BACKWARD_OVERWRITE,
scorer=tagger_score,
neg_prefix="!",
label_smoothing=0.0,
):
"""Initialize a part-of-speech tagger.
@ -105,7 +107,7 @@ class Tagger(TrainablePipe):
self.model = model
self.name = name
self._rehearsal_model = None
cfg = {"labels": [], "overwrite": overwrite, "neg_prefix": neg_prefix}
cfg = {"labels": [], "overwrite": overwrite, "neg_prefix": neg_prefix, "label_smoothing": label_smoothing}
self.cfg = dict(sorted(cfg.items()))
self.scorer = scorer
@ -256,7 +258,7 @@ class Tagger(TrainablePipe):
DOCS: https://spacy.io/api/tagger#get_loss
"""
validate_examples(examples, "Tagger.get_loss")
loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False, neg_prefix=self.cfg["neg_prefix"])
loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False, neg_prefix=self.cfg["neg_prefix"], label_smoothing=self.cfg["label_smoothing"])
# Convert empty tag "" to missing value None so that both misaligned
# tokens and tokens with missing annotation have the default missing
# value None.

View File

@ -1,5 +1,5 @@
import pytest
from numpy.testing import assert_equal
from numpy.testing import assert_equal, assert_almost_equal
from spacy import util
from spacy.training import Example
@ -19,6 +19,8 @@ def test_label_types():
morphologizer.add_label(9)
TAGS = ["Feat=N", "Feat=V", "Feat=J"]
TRAIN_DATA = [
(
"I like green eggs",
@ -32,6 +34,29 @@ TRAIN_DATA = [
]
def test_label_smoothing():
nlp = Language()
morph_no_ls = nlp.add_pipe("morphologizer", "no_label_smoothing")
morph_ls = nlp.add_pipe(
"morphologizer", "label_smoothing", config=dict(label_smoothing=0.05)
)
train_examples = []
losses = {}
for tag in TAGS:
morph_no_ls.add_label(tag)
morph_ls.add_label(tag)
for t in TRAIN_DATA:
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
nlp.initialize(get_examples=lambda: train_examples)
tag_scores, bp_tag_scores = morph_ls.model.begin_update(
[eg.predicted for eg in train_examples]
)
no_ls_grads = morph_no_ls.get_loss(train_examples, tag_scores)[1][0]
ls_grads = morph_ls.get_loss(train_examples, tag_scores)[1][0]
assert_almost_equal(ls_grads / no_ls_grads, 0.94285715)
def test_no_label():
nlp = Language()
nlp.add_pipe("morphologizer")

View File

@ -1,5 +1,5 @@
import pytest
from numpy.testing import assert_equal
from numpy.testing import assert_equal, assert_almost_equal
from spacy.attrs import TAG
from spacy import util
@ -67,6 +67,29 @@ PARTIAL_DATA = [
]
def test_label_smoothing():
nlp = Language()
tagger_no_ls = nlp.add_pipe("tagger", "no_label_smoothing")
tagger_ls = nlp.add_pipe(
"tagger", "label_smoothing", config=dict(label_smoothing=0.05)
)
train_examples = []
losses = {}
for tag in TAGS:
tagger_no_ls.add_label(tag)
tagger_ls.add_label(tag)
for t in TRAIN_DATA:
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
nlp.initialize(get_examples=lambda: train_examples)
tag_scores, bp_tag_scores = tagger_ls.model.begin_update(
[eg.predicted for eg in train_examples]
)
no_ls_grads = tagger_no_ls.get_loss(train_examples, tag_scores)[1][0]
ls_grads = tagger_ls.get_loss(train_examples, tag_scores)[1][0]
assert_almost_equal(ls_grads / no_ls_grads, 0.925)
def test_no_label():
nlp = Language()
nlp.add_pipe("tagger")

View File

@ -42,12 +42,13 @@ architectures and their arguments and hyperparameters.
> nlp.add_pipe("morphologizer", config=config)
> ```
| Setting | Description |
| ---------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `model` | The model to use. Defaults to [Tagger](/api/architectures#Tagger). ~~Model[List[Doc], List[Floats2d]]~~ |
| `overwrite` <Tag variant="new">3.2</Tag> | Whether the values of existing features are overwritten. Defaults to `True`. ~~bool~~ |
| `extend` <Tag variant="new">3.2</Tag> | Whether existing feature types (whose values may or may not be overwritten depending on `overwrite`) are preserved. Defaults to `False`. ~~bool~~ |
| `scorer` <Tag variant="new">3.2</Tag> | The scoring method. Defaults to [`Scorer.score_token_attr`](/api/scorer#score_token_attr) for the attributes `"pos"` and `"morph"` and [`Scorer.score_token_attr_per_feat`](/api/scorer#score_token_attr_per_feat) for the attribute `"morph"`. ~~Optional[Callable]~~ |
| Setting | Description |
| ---------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `model` | The model to use. Defaults to [Tagger](/api/architectures#Tagger). ~~Model[List[Doc], List[Floats2d]]~~ |
| `overwrite` <Tag variant="new">3.2</Tag> | Whether the values of existing features are overwritten. Defaults to `True`. ~~bool~~ |
| `extend` <Tag variant="new">3.2</Tag> | Whether existing feature types (whose values may or may not be overwritten depending on `overwrite`) are preserved. Defaults to `False`. ~~bool~~ |
| `scorer` <Tag variant="new">3.2</Tag> | The scoring method. Defaults to [`Scorer.score_token_attr`](/api/scorer#score_token_attr) for the attributes `"pos"` and `"morph"` and [`Scorer.score_token_attr_per_feat`](/api/scorer#score_token_attr_per_feat) for the attribute `"morph"`. ~~Optional[Callable]~~ |
| `label_smoothing` <Tag variant="new">3.6</Tag> | [Label smoothing](https://arxiv.org/abs/1906.02629) factor. Defaults to `0.0`. ~~float~~ |
```python
%%GITHUB_SPACY/spacy/pipeline/morphologizer.pyx

View File

@ -40,12 +40,13 @@ architectures and their arguments and hyperparameters.
> nlp.add_pipe("tagger", config=config)
> ```
| Setting | Description |
| ------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `model` | A model instance that predicts the tag probabilities. The output vectors should match the number of tags in size, and be normalized as probabilities (all scores between 0 and 1, with the rows summing to `1`). Defaults to [Tagger](/api/architectures#Tagger). ~~Model[List[Doc], List[Floats2d]]~~ |
| `overwrite` <Tag variant="new">3.2</Tag> | Whether existing annotation is overwritten. Defaults to `False`. ~~bool~~ |
| `scorer` <Tag variant="new">3.2</Tag> | The scoring method. Defaults to [`Scorer.score_token_attr`](/api/scorer#score_token_attr) for the attribute `"tag"`. ~~Optional[Callable]~~ |
| `neg_prefix` <Tag variant="new">3.2.1</Tag> | The prefix used to specify incorrect tags while training. The tagger will learn not to predict exactly this tag. Defaults to `!`. ~~str~~ |
| Setting | Description |
| ---------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `model` | A model instance that predicts the tag probabilities. The output vectors should match the number of tags in size, and be normalized as probabilities (all scores between 0 and 1, with the rows summing to `1`). Defaults to [Tagger](/api/architectures#Tagger). ~~Model[List[Doc], List[Floats2d]]~~ |
| `overwrite` <Tag variant="new">3.2</Tag> | Whether existing annotation is overwritten. Defaults to `False`. ~~bool~~ |
| `scorer` <Tag variant="new">3.2</Tag> | The scoring method. Defaults to [`Scorer.score_token_attr`](/api/scorer#score_token_attr) for the attribute `"tag"`. ~~Optional[Callable]~~ |
| `neg_prefix` <Tag variant="new">3.2.1</Tag> | The prefix used to specify incorrect tags while training. The tagger will learn not to predict exactly this tag. Defaults to `!`. ~~str~~ |
| `label_smoothing` <Tag variant="new">3.6</Tag> | [Label smoothing](https://arxiv.org/abs/1906.02629) factor. Defaults to `0.0`. ~~float~~ |
```python
%%GITHUB_SPACY/spacy/pipeline/tagger.pyx