mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-05 13:43:24 +03:00
positive_label config for textcat (#6062)
* hook up positive_label in textcat * unit tests * documentation * formatting * tests * fix typo * move verify_config to after begin_training * revert accidental commit
This commit is contained in:
parent
b854e0bef9
commit
3216a33149
|
@ -89,7 +89,6 @@ def train(
|
||||||
nlp, config = util.load_model_from_config(config)
|
nlp, config = util.load_model_from_config(config)
|
||||||
if config["training"]["vectors"] is not None:
|
if config["training"]["vectors"] is not None:
|
||||||
util.load_vectors_into_model(nlp, config["training"]["vectors"])
|
util.load_vectors_into_model(nlp, config["training"]["vectors"])
|
||||||
verify_config(nlp)
|
|
||||||
raw_text, tag_map, morph_rules, weights_data = load_from_paths(config)
|
raw_text, tag_map, morph_rules, weights_data = load_from_paths(config)
|
||||||
T_cfg = config["training"]
|
T_cfg = config["training"]
|
||||||
optimizer = T_cfg["optimizer"]
|
optimizer = T_cfg["optimizer"]
|
||||||
|
@ -108,6 +107,8 @@ def train(
|
||||||
nlp.resume_training(sgd=optimizer)
|
nlp.resume_training(sgd=optimizer)
|
||||||
with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
|
with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
|
||||||
nlp.begin_training(lambda: train_corpus(nlp), sgd=optimizer)
|
nlp.begin_training(lambda: train_corpus(nlp), sgd=optimizer)
|
||||||
|
# Verify the config after calling 'begin_training' to ensure labels are properly initialized
|
||||||
|
verify_config(nlp)
|
||||||
|
|
||||||
if tag_map:
|
if tag_map:
|
||||||
# Replace tag map with provided mapping
|
# Replace tag map with provided mapping
|
||||||
|
@ -401,7 +402,7 @@ def verify_cli_args(config_path: Path, output_path: Optional[Path] = None) -> No
|
||||||
|
|
||||||
|
|
||||||
def verify_config(nlp: Language) -> None:
|
def verify_config(nlp: Language) -> None:
|
||||||
"""Perform additional checks based on the config and loaded nlp object."""
|
"""Perform additional checks based on the config, loaded nlp object and training data."""
|
||||||
# TODO: maybe we should validate based on the actual components, the list
|
# TODO: maybe we should validate based on the actual components, the list
|
||||||
# in config["nlp"]["pipeline"] instead?
|
# in config["nlp"]["pipeline"] instead?
|
||||||
for pipe_config in nlp.config["components"].values():
|
for pipe_config in nlp.config["components"].values():
|
||||||
|
@ -415,18 +416,13 @@ def verify_textcat_config(nlp: Language, pipe_config: Dict[str, Any]) -> None:
|
||||||
# if 'positive_label' is provided: double check whether it's in the data and
|
# if 'positive_label' is provided: double check whether it's in the data and
|
||||||
# the task is binary
|
# the task is binary
|
||||||
if pipe_config.get("positive_label"):
|
if pipe_config.get("positive_label"):
|
||||||
textcat_labels = nlp.get_pipe("textcat").cfg.get("labels", [])
|
textcat_labels = nlp.get_pipe("textcat").labels
|
||||||
pos_label = pipe_config.get("positive_label")
|
pos_label = pipe_config.get("positive_label")
|
||||||
if pos_label not in textcat_labels:
|
if pos_label not in textcat_labels:
|
||||||
msg.fail(
|
raise ValueError(
|
||||||
f"The textcat's 'positive_label' config setting '{pos_label}' "
|
Errors.E920.format(pos_label=pos_label, labels=textcat_labels)
|
||||||
f"does not match any label in the training data.",
|
|
||||||
exits=1,
|
|
||||||
)
|
)
|
||||||
if len(textcat_labels) != 2:
|
if len(list(textcat_labels)) != 2:
|
||||||
msg.fail(
|
raise ValueError(
|
||||||
f"A textcat 'positive_label' '{pos_label}' was "
|
Errors.E919.format(pos_label=pos_label, labels=textcat_labels)
|
||||||
f"provided for training data that does not appear to be a "
|
|
||||||
f"binary classification problem with two labels.",
|
|
||||||
exits=1,
|
|
||||||
)
|
)
|
||||||
|
|
|
@ -480,6 +480,11 @@ class Errors:
|
||||||
E201 = ("Span index out of range.")
|
E201 = ("Span index out of range.")
|
||||||
|
|
||||||
# TODO: fix numbering after merging develop into master
|
# TODO: fix numbering after merging develop into master
|
||||||
|
E919 = ("A textcat 'positive_label' '{pos_label}' was provided for training "
|
||||||
|
"data that does not appear to be a binary classification problem "
|
||||||
|
"with two labels. Labels found: {labels}")
|
||||||
|
E920 = ("The textcat's 'positive_label' config setting '{pos_label}' "
|
||||||
|
"does not match any label in the training data. Labels found: {labels}")
|
||||||
E921 = ("The method 'set_output' can only be called on components that have "
|
E921 = ("The method 'set_output' can only be called on components that have "
|
||||||
"a Model with a 'resize_output' attribute. Otherwise, the output "
|
"a Model with a 'resize_output' attribute. Otherwise, the output "
|
||||||
"layer can not be dynamically changed.")
|
"layer can not be dynamically changed.")
|
||||||
|
|
|
@ -56,7 +56,12 @@ subword_features = true
|
||||||
@Language.factory(
|
@Language.factory(
|
||||||
"textcat",
|
"textcat",
|
||||||
assigns=["doc.cats"],
|
assigns=["doc.cats"],
|
||||||
default_config={"labels": [], "threshold": 0.5, "model": DEFAULT_TEXTCAT_MODEL},
|
default_config={
|
||||||
|
"labels": [],
|
||||||
|
"threshold": 0.5,
|
||||||
|
"positive_label": None,
|
||||||
|
"model": DEFAULT_TEXTCAT_MODEL,
|
||||||
|
},
|
||||||
scores=[
|
scores=[
|
||||||
"cats_score",
|
"cats_score",
|
||||||
"cats_score_desc",
|
"cats_score_desc",
|
||||||
|
@ -74,8 +79,9 @@ def make_textcat(
|
||||||
nlp: Language,
|
nlp: Language,
|
||||||
name: str,
|
name: str,
|
||||||
model: Model[List[Doc], List[Floats2d]],
|
model: Model[List[Doc], List[Floats2d]],
|
||||||
labels: Iterable[str],
|
labels: List[str],
|
||||||
threshold: float,
|
threshold: float,
|
||||||
|
positive_label: Optional[str],
|
||||||
) -> "TextCategorizer":
|
) -> "TextCategorizer":
|
||||||
"""Create a TextCategorizer component. The text categorizer predicts categories
|
"""Create a TextCategorizer component. The text categorizer predicts categories
|
||||||
over a whole document. It can learn one or more labels, and the labels can
|
over a whole document. It can learn one or more labels, and the labels can
|
||||||
|
@ -88,8 +94,16 @@ def make_textcat(
|
||||||
labels (list): A list of categories to learn. If empty, the model infers the
|
labels (list): A list of categories to learn. If empty, the model infers the
|
||||||
categories from the data.
|
categories from the data.
|
||||||
threshold (float): Cutoff to consider a prediction "positive".
|
threshold (float): Cutoff to consider a prediction "positive".
|
||||||
|
positive_label (Optional[str]): The positive label for a binary task with exclusive classes, None otherwise.
|
||||||
"""
|
"""
|
||||||
return TextCategorizer(nlp.vocab, model, name, labels=labels, threshold=threshold)
|
return TextCategorizer(
|
||||||
|
nlp.vocab,
|
||||||
|
model,
|
||||||
|
name,
|
||||||
|
labels=labels,
|
||||||
|
threshold=threshold,
|
||||||
|
positive_label=positive_label,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TextCategorizer(Pipe):
|
class TextCategorizer(Pipe):
|
||||||
|
@ -104,8 +118,9 @@ class TextCategorizer(Pipe):
|
||||||
model: Model,
|
model: Model,
|
||||||
name: str = "textcat",
|
name: str = "textcat",
|
||||||
*,
|
*,
|
||||||
labels: Iterable[str],
|
labels: List[str],
|
||||||
threshold: float,
|
threshold: float,
|
||||||
|
positive_label: Optional[str],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize a text categorizer.
|
"""Initialize a text categorizer.
|
||||||
|
|
||||||
|
@ -113,8 +128,9 @@ class TextCategorizer(Pipe):
|
||||||
model (thinc.api.Model): The Thinc Model powering the pipeline component.
|
model (thinc.api.Model): The Thinc Model powering the pipeline component.
|
||||||
name (str): The component instance name, used to add entries to the
|
name (str): The component instance name, used to add entries to the
|
||||||
losses during training.
|
losses during training.
|
||||||
labels (Iterable[str]): The labels to use.
|
labels (List[str]): The labels to use.
|
||||||
threshold (float): Cutoff to consider a prediction "positive".
|
threshold (float): Cutoff to consider a prediction "positive".
|
||||||
|
positive_label (Optional[str]): The positive label for a binary task with exclusive classes, None otherwise.
|
||||||
|
|
||||||
DOCS: https://nightly.spacy.io/api/textcategorizer#init
|
DOCS: https://nightly.spacy.io/api/textcategorizer#init
|
||||||
"""
|
"""
|
||||||
|
@ -122,7 +138,11 @@ class TextCategorizer(Pipe):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.name = name
|
self.name = name
|
||||||
self._rehearsal_model = None
|
self._rehearsal_model = None
|
||||||
cfg = {"labels": labels, "threshold": threshold}
|
cfg = {
|
||||||
|
"labels": labels,
|
||||||
|
"threshold": threshold,
|
||||||
|
"positive_label": positive_label,
|
||||||
|
}
|
||||||
self.cfg = dict(cfg)
|
self.cfg = dict(cfg)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -131,10 +151,10 @@ class TextCategorizer(Pipe):
|
||||||
|
|
||||||
DOCS: https://nightly.spacy.io/api/textcategorizer#labels
|
DOCS: https://nightly.spacy.io/api/textcategorizer#labels
|
||||||
"""
|
"""
|
||||||
return tuple(self.cfg.setdefault("labels", []))
|
return tuple(self.cfg["labels"])
|
||||||
|
|
||||||
@labels.setter
|
@labels.setter
|
||||||
def labels(self, value: Iterable[str]) -> None:
|
def labels(self, value: List[str]) -> None:
|
||||||
self.cfg["labels"] = tuple(value)
|
self.cfg["labels"] = tuple(value)
|
||||||
|
|
||||||
def pipe(self, stream: Iterable[Doc], *, batch_size: int = 128) -> Iterator[Doc]:
|
def pipe(self, stream: Iterable[Doc], *, batch_size: int = 128) -> Iterator[Doc]:
|
||||||
|
@ -353,17 +373,10 @@ class TextCategorizer(Pipe):
|
||||||
sgd = self.create_optimizer()
|
sgd = self.create_optimizer()
|
||||||
return sgd
|
return sgd
|
||||||
|
|
||||||
def score(
|
def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
|
||||||
self,
|
|
||||||
examples: Iterable[Example],
|
|
||||||
*,
|
|
||||||
positive_label: Optional[str] = None,
|
|
||||||
**kwargs,
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Score a batch of examples.
|
"""Score a batch of examples.
|
||||||
|
|
||||||
examples (Iterable[Example]): The examples to score.
|
examples (Iterable[Example]): The examples to score.
|
||||||
positive_label (str): Optional positive label.
|
|
||||||
RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_cats.
|
RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_cats.
|
||||||
|
|
||||||
DOCS: https://nightly.spacy.io/api/textcategorizer#score
|
DOCS: https://nightly.spacy.io/api/textcategorizer#score
|
||||||
|
@ -374,7 +387,7 @@ class TextCategorizer(Pipe):
|
||||||
"cats",
|
"cats",
|
||||||
labels=self.labels,
|
labels=self.labels,
|
||||||
multi_label=self.model.attrs["multi_label"],
|
multi_label=self.model.attrs["multi_label"],
|
||||||
positive_label=positive_label,
|
positive_label=self.cfg["positive_label"],
|
||||||
threshold=self.cfg["threshold"],
|
threshold=self.cfg["threshold"],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
|
@ -10,6 +10,7 @@ from spacy.tokens import Doc
|
||||||
from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
|
from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
|
||||||
|
|
||||||
from ..util import make_tempdir
|
from ..util import make_tempdir
|
||||||
|
from ...cli.train import verify_textcat_config
|
||||||
from ...training import Example
|
from ...training import Example
|
||||||
|
|
||||||
|
|
||||||
|
@ -130,7 +131,10 @@ def test_overfitting_IO():
|
||||||
fix_random_seed(0)
|
fix_random_seed(0)
|
||||||
nlp = English()
|
nlp = English()
|
||||||
# Set exclusive labels
|
# Set exclusive labels
|
||||||
textcat = nlp.add_pipe("textcat", config={"model": {"exclusive_classes": True}})
|
textcat = nlp.add_pipe(
|
||||||
|
"textcat",
|
||||||
|
config={"model": {"exclusive_classes": True}, "positive_label": "POSITIVE"},
|
||||||
|
)
|
||||||
train_examples = []
|
train_examples = []
|
||||||
for text, annotations in TRAIN_DATA:
|
for text, annotations in TRAIN_DATA:
|
||||||
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
|
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
|
||||||
|
@ -159,7 +163,7 @@ def test_overfitting_IO():
|
||||||
assert cats2["POSITIVE"] + cats2["NEGATIVE"] == pytest.approx(1.0, 0.001)
|
assert cats2["POSITIVE"] + cats2["NEGATIVE"] == pytest.approx(1.0, 0.001)
|
||||||
|
|
||||||
# Test scoring
|
# Test scoring
|
||||||
scores = nlp.evaluate(train_examples, scorer_cfg={"positive_label": "POSITIVE"})
|
scores = nlp.evaluate(train_examples)
|
||||||
assert scores["cats_micro_f"] == 1.0
|
assert scores["cats_micro_f"] == 1.0
|
||||||
assert scores["cats_score"] == 1.0
|
assert scores["cats_score"] == 1.0
|
||||||
assert "cats_score_desc" in scores
|
assert "cats_score_desc" in scores
|
||||||
|
@ -194,3 +198,29 @@ def test_textcat_configs(textcat_config):
|
||||||
for i in range(5):
|
for i in range(5):
|
||||||
losses = {}
|
losses = {}
|
||||||
nlp.update(train_examples, sgd=optimizer, losses=losses)
|
nlp.update(train_examples, sgd=optimizer, losses=losses)
|
||||||
|
|
||||||
|
|
||||||
|
def test_positive_class():
|
||||||
|
nlp = English()
|
||||||
|
pipe_config = {"positive_label": "POS", "labels": ["POS", "NEG"]}
|
||||||
|
textcat = nlp.add_pipe("textcat", config=pipe_config)
|
||||||
|
assert textcat.labels == ("POS", "NEG")
|
||||||
|
verify_textcat_config(nlp, pipe_config)
|
||||||
|
|
||||||
|
|
||||||
|
def test_positive_class_not_present():
|
||||||
|
nlp = English()
|
||||||
|
pipe_config = {"positive_label": "POS", "labels": ["SOME", "THING"]}
|
||||||
|
textcat = nlp.add_pipe("textcat", config=pipe_config)
|
||||||
|
assert textcat.labels == ("SOME", "THING")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
verify_textcat_config(nlp, pipe_config)
|
||||||
|
|
||||||
|
|
||||||
|
def test_positive_class_not_binary():
|
||||||
|
nlp = English()
|
||||||
|
pipe_config = {"positive_label": "POS", "labels": ["SOME", "THING", "POS"]}
|
||||||
|
textcat = nlp.add_pipe("textcat", config=pipe_config)
|
||||||
|
assert textcat.labels == ("SOME", "THING", "POS")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
verify_textcat_config(nlp, pipe_config)
|
||||||
|
|
|
@ -136,7 +136,7 @@ def test_serialize_textcat_empty(en_vocab):
|
||||||
# See issue #1105
|
# See issue #1105
|
||||||
cfg = {"model": DEFAULT_TEXTCAT_MODEL}
|
cfg = {"model": DEFAULT_TEXTCAT_MODEL}
|
||||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||||
textcat = TextCategorizer(en_vocab, model, labels=["ENTITY", "ACTION", "MODIFIER"], threshold=0.5)
|
textcat = TextCategorizer(en_vocab, model, labels=["ENTITY", "ACTION", "MODIFIER"], threshold=0.5, positive_label=None)
|
||||||
textcat.to_bytes(exclude=["vocab"])
|
textcat.to_bytes(exclude=["vocab"])
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -37,9 +37,10 @@ architectures and their arguments and hyperparameters.
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
| Setting | Description |
|
| Setting | Description |
|
||||||
| ----------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
| ---------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| `labels` | A list of categories to learn. If empty, the model infers the categories from the data. Defaults to `[]`. ~~Iterable[str]~~ |
|
| `labels` | A list of categories to learn. If empty, the model infers the categories from the data. Defaults to `[]`. ~~Iterable[str]~~ |
|
||||||
| `threshold` | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~ |
|
| `threshold` | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~ |
|
||||||
|
| `positive_label` | The positive label for a binary task with exclusive classes, `None` otherwise. Defaults to `None`. ~~Optional[str]~~ |
|
||||||
| `model` | A model instance that predicts scores for each category. Defaults to [TextCatEnsemble](/api/architectures#TextCatEnsemble). ~~Model[List[Doc], List[Floats2d]]~~ |
|
| `model` | A model instance that predicts scores for each category. Defaults to [TextCatEnsemble](/api/architectures#TextCatEnsemble). ~~Model[List[Doc], List[Floats2d]]~~ |
|
||||||
|
|
||||||
```python
|
```python
|
||||||
|
@ -60,7 +61,7 @@ architectures and their arguments and hyperparameters.
|
||||||
>
|
>
|
||||||
> # Construction from class
|
> # Construction from class
|
||||||
> from spacy.pipeline import TextCategorizer
|
> from spacy.pipeline import TextCategorizer
|
||||||
> textcat = TextCategorizer(nlp.vocab, model, labels=[], threshold=0.5)
|
> textcat = TextCategorizer(nlp.vocab, model, labels=[], threshold=0.5, positive_label="POS")
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
Create a new pipeline instance. In your application, you would normally use a
|
Create a new pipeline instance. In your application, you would normally use a
|
||||||
|
@ -68,13 +69,14 @@ shortcut for this and instantiate the component using its string name and
|
||||||
[`nlp.add_pipe`](/api/language#create_pipe).
|
[`nlp.add_pipe`](/api/language#create_pipe).
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
| -------------- | -------------------------------------------------------------------------------------------------------------------------- |
|
| ---------------- | -------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| `vocab` | The shared vocabulary. ~~Vocab~~ |
|
| `vocab` | The shared vocabulary. ~~Vocab~~ |
|
||||||
| `model` | The Thinc [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. ~~Model[List[Doc], List[Floats2d]]~~ |
|
| `model` | The Thinc [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. ~~Model[List[Doc], List[Floats2d]]~~ |
|
||||||
| `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ |
|
| `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ |
|
||||||
| _keyword-only_ | |
|
| _keyword-only_ | |
|
||||||
| `labels` | The labels to use. ~~Iterable[str]~~ |
|
| `labels` | The labels to use. ~~Iterable[str]~~ |
|
||||||
| `threshold` | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~ |
|
| `threshold` | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~ |
|
||||||
|
| `positive_label` | The positive label for a binary task with exclusive classes, None otherwise. ~~Optional[str]~~ |
|
||||||
|
|
||||||
## TextCategorizer.\_\_call\_\_ {#call tag="method"}
|
## TextCategorizer.\_\_call\_\_ {#call tag="method"}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user