positive_label config for textcat (#6062)

* hook up positive_label in textcat

* unit tests

* documentation

* formatting

* tests

* fix typo

* move verify_config to after begin_training

* revert accidental commit
Sofie Van Landeghem 2020-09-14 17:08:00 +02:00 committed by GitHub
parent b854e0bef9
commit 3216a33149
6 changed files with 93 additions and 47 deletions
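The change moves the positive label into the textcat component's own config instead of a scorer-time override. A minimal sketch of the intended usage, mirroring the new unit tests (the POS/NEG label names are just the placeholders used there):

```python
from spacy.lang.en import English

nlp = English()
# 'positive_label' is now part of the component config itself
pipe_config = {"positive_label": "POS", "labels": ["POS", "NEG"]}
textcat = nlp.add_pipe("textcat", config=pipe_config)
assert textcat.labels == ("POS", "NEG")
```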


@@ -89,7 +89,6 @@ def train(
nlp, config = util.load_model_from_config(config)
if config["training"]["vectors"] is not None:
util.load_vectors_into_model(nlp, config["training"]["vectors"])
verify_config(nlp)
raw_text, tag_map, morph_rules, weights_data = load_from_paths(config)
T_cfg = config["training"]
optimizer = T_cfg["optimizer"]
@@ -108,6 +107,8 @@ def train(
nlp.resume_training(sgd=optimizer)
with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
nlp.begin_training(lambda: train_corpus(nlp), sgd=optimizer)
# Verify the config after calling 'begin_training' to ensure labels are properly initialized
verify_config(nlp)
if tag_map:
# Replace tag map with provided mapping
@@ -401,7 +402,7 @@ def verify_cli_args(config_path: Path, output_path: Optional[Path] = None) -> No
def verify_config(nlp: Language) -> None:
"""Perform additional checks based on the config and loaded nlp object."""
"""Perform additional checks based on the config, loaded nlp object and training data."""
# TODO: maybe we should validate based on the actual components, the list
# in config["nlp"]["pipeline"] instead?
for pipe_config in nlp.config["components"].values():
@@ -415,18 +416,13 @@ def verify_textcat_config(nlp: Language, pipe_config: Dict[str, Any]) -> None:
# if 'positive_label' is provided: double check whether it's in the data and
# the task is binary
if pipe_config.get("positive_label"):
textcat_labels = nlp.get_pipe("textcat").cfg.get("labels", [])
textcat_labels = nlp.get_pipe("textcat").labels
pos_label = pipe_config.get("positive_label")
if pos_label not in textcat_labels:
msg.fail(
f"The textcat's 'positive_label' config setting '{pos_label}' "
f"does not match any label in the training data.",
exits=1,
raise ValueError(
Errors.E920.format(pos_label=pos_label, labels=textcat_labels)
)
if len(textcat_labels) != 2:
msg.fail(
f"A textcat 'positive_label' '{pos_label}' was "
f"provided for training data that does not appear to be a "
f"binary classification problem with two labels.",
exits=1,
if len(list(textcat_labels)) != 2:
raise ValueError(
Errors.E919.format(pos_label=pos_label, labels=textcat_labels)
)
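The reordering of `verify_config` matters because a textcat pipe added without explicit labels only learns its label set from the data during `begin_training`, so checking `positive_label` any earlier would compare against an empty label set. A hedged sketch of that ordering, assuming a couple of binary `Example`s in the style of the test data below:

```python
from spacy.lang.en import English
from spacy.training import Example

nlp = English()
textcat = nlp.add_pipe("textcat", config={"positive_label": "POSITIVE"})
train_examples = [
    Example.from_dict(nlp.make_doc(text), annots)
    for text, annots in [
        ("I'm so happy.", {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}),
        ("I'm so angry", {"cats": {"POSITIVE": 0.0, "NEGATIVE": 1.0}}),
    ]
]
assert textcat.labels == ()                 # no labels yet, nothing to verify against
nlp.begin_training(lambda: train_examples)  # labels are inferred from the data here
assert "POSITIVE" in textcat.labels         # now the positive_label check can pass
```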


@@ -480,6 +480,11 @@ class Errors:
E201 = ("Span index out of range.")
# TODO: fix numbering after merging develop into master
E919 = ("A textcat 'positive_label' '{pos_label}' was provided for training "
"data that does not appear to be a binary classification problem "
"with two labels. Labels found: {labels}")
E920 = ("The textcat's 'positive_label' config setting '{pos_label}' "
"does not match any label in the training data. Labels found: {labels}")
E921 = ("The method 'set_output' can only be called on components that have "
"a Model with a 'resize_output' attribute. Otherwise, the output "
"layer can not be dynamically changed.")


@@ -56,7 +56,12 @@ subword_features = true
@Language.factory(
"textcat",
assigns=["doc.cats"],
default_config={"labels": [], "threshold": 0.5, "model": DEFAULT_TEXTCAT_MODEL},
default_config={
"labels": [],
"threshold": 0.5,
"positive_label": None,
"model": DEFAULT_TEXTCAT_MODEL,
},
scores=[
"cats_score",
"cats_score_desc",
@@ -74,8 +79,9 @@ def make_textcat(
nlp: Language,
name: str,
model: Model[List[Doc], List[Floats2d]],
labels: Iterable[str],
labels: List[str],
threshold: float,
positive_label: Optional[str],
) -> "TextCategorizer":
"""Create a TextCategorizer compoment. The text categorizer predicts categories
over a whole document. It can learn one or more labels, and the labels can
@@ -88,8 +94,16 @@ def make_textcat(
labels (list): A list of categories to learn. If empty, the model infers the
categories from the data.
threshold (float): Cutoff to consider a prediction "positive".
positive_label (Optional[str]): The positive label for a binary task with exclusive classes, None otherwise.
"""
return TextCategorizer(nlp.vocab, model, name, labels=labels, threshold=threshold)
return TextCategorizer(
nlp.vocab,
model,
name,
labels=labels,
threshold=threshold,
positive_label=positive_label,
)
class TextCategorizer(Pipe):
@@ -104,8 +118,9 @@ class TextCategorizer(Pipe):
model: Model,
name: str = "textcat",
*,
labels: Iterable[str],
labels: List[str],
threshold: float,
positive_label: Optional[str],
) -> None:
"""Initialize a text categorizer.
@@ -113,8 +128,9 @@ class TextCategorizer(Pipe):
model (thinc.api.Model): The Thinc Model powering the pipeline component.
name (str): The component instance name, used to add entries to the
losses during training.
labels (Iterable[str]): The labels to use.
labels (List[str]): The labels to use.
threshold (float): Cutoff to consider a prediction "positive".
positive_label (Optional[str]): The positive label for a binary task with exclusive classes, None otherwise.
DOCS: https://nightly.spacy.io/api/textcategorizer#init
"""
@@ -122,7 +138,11 @@ class TextCategorizer(Pipe):
self.model = model
self.name = name
self._rehearsal_model = None
cfg = {"labels": labels, "threshold": threshold}
cfg = {
"labels": labels,
"threshold": threshold,
"positive_label": positive_label,
}
self.cfg = dict(cfg)
@property
@@ -131,10 +151,10 @@
DOCS: https://nightly.spacy.io/api/textcategorizer#labels
"""
return tuple(self.cfg.setdefault("labels", []))
return tuple(self.cfg["labels"])
@labels.setter
def labels(self, value: Iterable[str]) -> None:
def labels(self, value: List[str]) -> None:
self.cfg["labels"] = tuple(value)
def pipe(self, stream: Iterable[Doc], *, batch_size: int = 128) -> Iterator[Doc]:
@@ -353,17 +373,10 @@ class TextCategorizer(Pipe):
sgd = self.create_optimizer()
return sgd
def score(
self,
examples: Iterable[Example],
*,
positive_label: Optional[str] = None,
**kwargs,
) -> Dict[str, Any]:
def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
"""Score a batch of examples.
examples (Iterable[Example]): The examples to score.
positive_label (str): Optional positive label.
RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_cats.
DOCS: https://nightly.spacy.io/api/textcategorizer#score
@@ -374,7 +387,7 @@ class TextCategorizer(Pipe):
"cats",
labels=self.labels,
multi_label=self.model.attrs["multi_label"],
positive_label=positive_label,
positive_label=self.cfg["positive_label"],
threshold=self.cfg["threshold"],
**kwargs,
)
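With the positive label stored on the component, scoring picks it up from `self.cfg` automatically, which is why the overfitting test below drops its `scorer_cfg` override. Roughly, assuming a pipeline trained as in that test:

```python
# previously: nlp.evaluate(train_examples, scorer_cfg={"positive_label": "POSITIVE"})
scores = nlp.evaluate(train_examples)  # positive_label now comes from the textcat's cfg
print(scores["cats_score"], scores["cats_score_desc"])
```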


@@ -10,6 +10,7 @@ from spacy.tokens import Doc
from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
from ..util import make_tempdir
from ...cli.train import verify_textcat_config
from ...training import Example
@@ -130,7 +131,10 @@ def test_overfitting_IO():
fix_random_seed(0)
nlp = English()
# Set exclusive labels
textcat = nlp.add_pipe("textcat", config={"model": {"exclusive_classes": True}})
textcat = nlp.add_pipe(
"textcat",
config={"model": {"exclusive_classes": True}, "positive_label": "POSITIVE"},
)
train_examples = []
for text, annotations in TRAIN_DATA:
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
@@ -159,7 +163,7 @@ def test_overfitting_IO():
assert cats2["POSITIVE"] + cats2["NEGATIVE"] == pytest.approx(1.0, 0.001)
# Test scoring
scores = nlp.evaluate(train_examples, scorer_cfg={"positive_label": "POSITIVE"})
scores = nlp.evaluate(train_examples)
assert scores["cats_micro_f"] == 1.0
assert scores["cats_score"] == 1.0
assert "cats_score_desc" in scores
@@ -194,3 +198,29 @@ def test_textcat_configs(textcat_config):
for i in range(5):
losses = {}
nlp.update(train_examples, sgd=optimizer, losses=losses)
def test_positive_class():
nlp = English()
pipe_config = {"positive_label": "POS", "labels": ["POS", "NEG"]}
textcat = nlp.add_pipe("textcat", config=pipe_config)
assert textcat.labels == ("POS", "NEG")
verify_textcat_config(nlp, pipe_config)
def test_positive_class_not_present():
nlp = English()
pipe_config = {"positive_label": "POS", "labels": ["SOME", "THING"]}
textcat = nlp.add_pipe("textcat", config=pipe_config)
assert textcat.labels == ("SOME", "THING")
with pytest.raises(ValueError):
verify_textcat_config(nlp, pipe_config)
def test_positive_class_not_binary():
nlp = English()
pipe_config = {"positive_label": "POS", "labels": ["SOME", "THING", "POS"]}
textcat = nlp.add_pipe("textcat", config=pipe_config)
assert textcat.labels == ("SOME", "THING", "POS")
with pytest.raises(ValueError):
verify_textcat_config(nlp, pipe_config)


@@ -136,7 +136,7 @@ def test_serialize_textcat_empty(en_vocab):
# See issue #1105
cfg = {"model": DEFAULT_TEXTCAT_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
textcat = TextCategorizer(en_vocab, model, labels=["ENTITY", "ACTION", "MODIFIER"], threshold=0.5)
textcat = TextCategorizer(en_vocab, model, labels=["ENTITY", "ACTION", "MODIFIER"], threshold=0.5, positive_label=None)
textcat.to_bytes(exclude=["vocab"])


@@ -36,11 +36,12 @@ architectures and their arguments and hyperparameters.
> nlp.add_pipe("textcat", config=config)
> ```
| Setting | Description |
| ----------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `labels` | A list of categories to learn. If empty, the model infers the categories from the data. Defaults to `[]`. ~~Iterable[str]~~ |
| `threshold` | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~ |
| `model` | A model instance that predicts scores for each category. Defaults to [TextCatEnsemble](/api/architectures#TextCatEnsemble). ~~Model[List[Doc], List[Floats2d]]~~ |
| Setting | Description |
| ---------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `labels` | A list of categories to learn. If empty, the model infers the categories from the data. Defaults to `[]`. ~~Iterable[str]~~ |
| `threshold` | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~ |
| `positive_label` | The positive label for a binary task with exclusive classes, None otherwise and by default. ~~Optional[str]~~ |
| `model` | A model instance that predicts scores for each category. Defaults to [TextCatEnsemble](/api/architectures#TextCatEnsemble). ~~Model[List[Doc], List[Floats2d]]~~ |
```python
%%GITHUB_SPACY/spacy/pipeline/textcat.py
@@ -60,21 +61,22 @@ architectures and their arguments and hyperparameters.
>
> # Construction from class
> from spacy.pipeline import TextCategorizer
> textcat = TextCategorizer(nlp.vocab, model, labels=[], threshold=0.5)
> textcat = TextCategorizer(nlp.vocab, model, labels=[], threshold=0.5, positive_label="POS")
> ```
Create a new pipeline instance. In your application, you would normally use a
shortcut for this and instantiate the component using its string name and
[`nlp.add_pipe`](/api/language#create_pipe).
| Name | Description |
| -------------- | -------------------------------------------------------------------------------------------------------------------------- |
| `vocab` | The shared vocabulary. ~~Vocab~~ |
| `model` | The Thinc [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. ~~Model[List[Doc], List[Floats2d]]~~ |
| `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ |
| _keyword-only_ | |
| `labels` | The labels to use. ~~Iterable[str]~~ |
| `threshold` | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~ |
| Name | Description |
| ---------------- | -------------------------------------------------------------------------------------------------------------------------- |
| `vocab` | The shared vocabulary. ~~Vocab~~ |
| `model` | The Thinc [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. ~~Model[List[Doc], List[Floats2d]]~~ |
| `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ |
| _keyword-only_ | |
| `labels` | The labels to use. ~~Iterable[str]~~ |
| `threshold` | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~ |
| `positive_label` | The positive label for a binary task with exclusive classes, None otherwise. ~~Optional[str]~~ |
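For direct construction, the updated serialization test above shows the new keyword in context; a comparable hedged sketch (building the model from the default config is just one way to obtain a model instance, and the POS/NEG labels are placeholders):

```python
from spacy.pipeline import TextCategorizer
from spacy.pipeline.textcat import DEFAULT_TEXTCAT_MODEL
from spacy.util import registry
from spacy.vocab import Vocab

# Build the default textcat model, then pass the new positive_label keyword
model = registry.make_from_config({"model": DEFAULT_TEXTCAT_MODEL}, validate=True)["model"]
textcat = TextCategorizer(
    Vocab(), model, labels=["POS", "NEG"], threshold=0.5, positive_label="POS"
)
```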
## TextCategorizer.\_\_call\_\_ {#call tag="method"}