mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
positive_label config for textcat (#6062)
* hook up positive_label in textcat * unit tests * documentation * formatting * tests * fix typo * move verify_config to after begin_training * revert accidential commit
This commit is contained in:
parent
b854e0bef9
commit
3216a33149
|
@ -89,7 +89,6 @@ def train(
|
|||
nlp, config = util.load_model_from_config(config)
|
||||
if config["training"]["vectors"] is not None:
|
||||
util.load_vectors_into_model(nlp, config["training"]["vectors"])
|
||||
verify_config(nlp)
|
||||
raw_text, tag_map, morph_rules, weights_data = load_from_paths(config)
|
||||
T_cfg = config["training"]
|
||||
optimizer = T_cfg["optimizer"]
|
||||
|
@ -108,6 +107,8 @@ def train(
|
|||
nlp.resume_training(sgd=optimizer)
|
||||
with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
|
||||
nlp.begin_training(lambda: train_corpus(nlp), sgd=optimizer)
|
||||
# Verify the config after calling 'begin_training' to ensure labels are properly initialized
|
||||
verify_config(nlp)
|
||||
|
||||
if tag_map:
|
||||
# Replace tag map with provided mapping
|
||||
|
@ -401,7 +402,7 @@ def verify_cli_args(config_path: Path, output_path: Optional[Path] = None) -> No
|
|||
|
||||
|
||||
def verify_config(nlp: Language) -> None:
|
||||
"""Perform additional checks based on the config and loaded nlp object."""
|
||||
"""Perform additional checks based on the config, loaded nlp object and training data."""
|
||||
# TODO: maybe we should validate based on the actual components, the list
|
||||
# in config["nlp"]["pipeline"] instead?
|
||||
for pipe_config in nlp.config["components"].values():
|
||||
|
@ -415,18 +416,13 @@ def verify_textcat_config(nlp: Language, pipe_config: Dict[str, Any]) -> None:
|
|||
# if 'positive_label' is provided: double check whether it's in the data and
|
||||
# the task is binary
|
||||
if pipe_config.get("positive_label"):
|
||||
textcat_labels = nlp.get_pipe("textcat").cfg.get("labels", [])
|
||||
textcat_labels = nlp.get_pipe("textcat").labels
|
||||
pos_label = pipe_config.get("positive_label")
|
||||
if pos_label not in textcat_labels:
|
||||
msg.fail(
|
||||
f"The textcat's 'positive_label' config setting '{pos_label}' "
|
||||
f"does not match any label in the training data.",
|
||||
exits=1,
|
||||
raise ValueError(
|
||||
Errors.E920.format(pos_label=pos_label, labels=textcat_labels)
|
||||
)
|
||||
if len(textcat_labels) != 2:
|
||||
msg.fail(
|
||||
f"A textcat 'positive_label' '{pos_label}' was "
|
||||
f"provided for training data that does not appear to be a "
|
||||
f"binary classification problem with two labels.",
|
||||
exits=1,
|
||||
if len(list(textcat_labels)) != 2:
|
||||
raise ValueError(
|
||||
Errors.E919.format(pos_label=pos_label, labels=textcat_labels)
|
||||
)
|
||||
|
|
|
@ -480,6 +480,11 @@ class Errors:
|
|||
E201 = ("Span index out of range.")
|
||||
|
||||
# TODO: fix numbering after merging develop into master
|
||||
E919 = ("A textcat 'positive_label' '{pos_label}' was provided for training "
|
||||
"data that does not appear to be a binary classification problem "
|
||||
"with two labels. Labels found: {labels}")
|
||||
E920 = ("The textcat's 'positive_label' config setting '{pos_label}' "
|
||||
"does not match any label in the training data. Labels found: {labels}")
|
||||
E921 = ("The method 'set_output' can only be called on components that have "
|
||||
"a Model with a 'resize_output' attribute. Otherwise, the output "
|
||||
"layer can not be dynamically changed.")
|
||||
|
|
|
@ -56,7 +56,12 @@ subword_features = true
|
|||
@Language.factory(
|
||||
"textcat",
|
||||
assigns=["doc.cats"],
|
||||
default_config={"labels": [], "threshold": 0.5, "model": DEFAULT_TEXTCAT_MODEL},
|
||||
default_config={
|
||||
"labels": [],
|
||||
"threshold": 0.5,
|
||||
"positive_label": None,
|
||||
"model": DEFAULT_TEXTCAT_MODEL,
|
||||
},
|
||||
scores=[
|
||||
"cats_score",
|
||||
"cats_score_desc",
|
||||
|
@ -74,8 +79,9 @@ def make_textcat(
|
|||
nlp: Language,
|
||||
name: str,
|
||||
model: Model[List[Doc], List[Floats2d]],
|
||||
labels: Iterable[str],
|
||||
labels: List[str],
|
||||
threshold: float,
|
||||
positive_label: Optional[str],
|
||||
) -> "TextCategorizer":
|
||||
"""Create a TextCategorizer compoment. The text categorizer predicts categories
|
||||
over a whole document. It can learn one or more labels, and the labels can
|
||||
|
@ -88,8 +94,16 @@ def make_textcat(
|
|||
labels (list): A list of categories to learn. If empty, the model infers the
|
||||
categories from the data.
|
||||
threshold (float): Cutoff to consider a prediction "positive".
|
||||
positive_label (Optional[str]): The positive label for a binary task with exclusive classes, None otherwise.
|
||||
"""
|
||||
return TextCategorizer(nlp.vocab, model, name, labels=labels, threshold=threshold)
|
||||
return TextCategorizer(
|
||||
nlp.vocab,
|
||||
model,
|
||||
name,
|
||||
labels=labels,
|
||||
threshold=threshold,
|
||||
positive_label=positive_label,
|
||||
)
|
||||
|
||||
|
||||
class TextCategorizer(Pipe):
|
||||
|
@ -104,8 +118,9 @@ class TextCategorizer(Pipe):
|
|||
model: Model,
|
||||
name: str = "textcat",
|
||||
*,
|
||||
labels: Iterable[str],
|
||||
labels: List[str],
|
||||
threshold: float,
|
||||
positive_label: Optional[str],
|
||||
) -> None:
|
||||
"""Initialize a text categorizer.
|
||||
|
||||
|
@ -113,8 +128,9 @@ class TextCategorizer(Pipe):
|
|||
model (thinc.api.Model): The Thinc Model powering the pipeline component.
|
||||
name (str): The component instance name, used to add entries to the
|
||||
losses during training.
|
||||
labels (Iterable[str]): The labels to use.
|
||||
labels (List[str]): The labels to use.
|
||||
threshold (float): Cutoff to consider a prediction "positive".
|
||||
positive_label (Optional[str]): The positive label for a binary task with exclusive classes, None otherwise.
|
||||
|
||||
DOCS: https://nightly.spacy.io/api/textcategorizer#init
|
||||
"""
|
||||
|
@ -122,7 +138,11 @@ class TextCategorizer(Pipe):
|
|||
self.model = model
|
||||
self.name = name
|
||||
self._rehearsal_model = None
|
||||
cfg = {"labels": labels, "threshold": threshold}
|
||||
cfg = {
|
||||
"labels": labels,
|
||||
"threshold": threshold,
|
||||
"positive_label": positive_label,
|
||||
}
|
||||
self.cfg = dict(cfg)
|
||||
|
||||
@property
|
||||
|
@ -131,10 +151,10 @@ class TextCategorizer(Pipe):
|
|||
|
||||
DOCS: https://nightly.spacy.io/api/textcategorizer#labels
|
||||
"""
|
||||
return tuple(self.cfg.setdefault("labels", []))
|
||||
return tuple(self.cfg["labels"])
|
||||
|
||||
@labels.setter
|
||||
def labels(self, value: Iterable[str]) -> None:
|
||||
def labels(self, value: List[str]) -> None:
|
||||
self.cfg["labels"] = tuple(value)
|
||||
|
||||
def pipe(self, stream: Iterable[Doc], *, batch_size: int = 128) -> Iterator[Doc]:
|
||||
|
@ -353,17 +373,10 @@ class TextCategorizer(Pipe):
|
|||
sgd = self.create_optimizer()
|
||||
return sgd
|
||||
|
||||
def score(
|
||||
self,
|
||||
examples: Iterable[Example],
|
||||
*,
|
||||
positive_label: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
|
||||
"""Score a batch of examples.
|
||||
|
||||
examples (Iterable[Example]): The examples to score.
|
||||
positive_label (str): Optional positive label.
|
||||
RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_cats.
|
||||
|
||||
DOCS: https://nightly.spacy.io/api/textcategorizer#score
|
||||
|
@ -374,7 +387,7 @@ class TextCategorizer(Pipe):
|
|||
"cats",
|
||||
labels=self.labels,
|
||||
multi_label=self.model.attrs["multi_label"],
|
||||
positive_label=positive_label,
|
||||
positive_label=self.cfg["positive_label"],
|
||||
threshold=self.cfg["threshold"],
|
||||
**kwargs,
|
||||
)
|
||||
|
|
|
@ -10,6 +10,7 @@ from spacy.tokens import Doc
|
|||
from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
|
||||
|
||||
from ..util import make_tempdir
|
||||
from ...cli.train import verify_textcat_config
|
||||
from ...training import Example
|
||||
|
||||
|
||||
|
@ -130,7 +131,10 @@ def test_overfitting_IO():
|
|||
fix_random_seed(0)
|
||||
nlp = English()
|
||||
# Set exclusive labels
|
||||
textcat = nlp.add_pipe("textcat", config={"model": {"exclusive_classes": True}})
|
||||
textcat = nlp.add_pipe(
|
||||
"textcat",
|
||||
config={"model": {"exclusive_classes": True}, "positive_label": "POSITIVE"},
|
||||
)
|
||||
train_examples = []
|
||||
for text, annotations in TRAIN_DATA:
|
||||
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
|
||||
|
@ -159,7 +163,7 @@ def test_overfitting_IO():
|
|||
assert cats2["POSITIVE"] + cats2["NEGATIVE"] == pytest.approx(1.0, 0.001)
|
||||
|
||||
# Test scoring
|
||||
scores = nlp.evaluate(train_examples, scorer_cfg={"positive_label": "POSITIVE"})
|
||||
scores = nlp.evaluate(train_examples)
|
||||
assert scores["cats_micro_f"] == 1.0
|
||||
assert scores["cats_score"] == 1.0
|
||||
assert "cats_score_desc" in scores
|
||||
|
@ -194,3 +198,29 @@ def test_textcat_configs(textcat_config):
|
|||
for i in range(5):
|
||||
losses = {}
|
||||
nlp.update(train_examples, sgd=optimizer, losses=losses)
|
||||
|
||||
|
||||
def test_positive_class():
|
||||
nlp = English()
|
||||
pipe_config = {"positive_label": "POS", "labels": ["POS", "NEG"]}
|
||||
textcat = nlp.add_pipe("textcat", config=pipe_config)
|
||||
assert textcat.labels == ("POS", "NEG")
|
||||
verify_textcat_config(nlp, pipe_config)
|
||||
|
||||
|
||||
def test_positive_class_not_present():
|
||||
nlp = English()
|
||||
pipe_config = {"positive_label": "POS", "labels": ["SOME", "THING"]}
|
||||
textcat = nlp.add_pipe("textcat", config=pipe_config)
|
||||
assert textcat.labels == ("SOME", "THING")
|
||||
with pytest.raises(ValueError):
|
||||
verify_textcat_config(nlp, pipe_config)
|
||||
|
||||
|
||||
def test_positive_class_not_binary():
|
||||
nlp = English()
|
||||
pipe_config = {"positive_label": "POS", "labels": ["SOME", "THING", "POS"]}
|
||||
textcat = nlp.add_pipe("textcat", config=pipe_config)
|
||||
assert textcat.labels == ("SOME", "THING", "POS")
|
||||
with pytest.raises(ValueError):
|
||||
verify_textcat_config(nlp, pipe_config)
|
||||
|
|
|
@ -136,7 +136,7 @@ def test_serialize_textcat_empty(en_vocab):
|
|||
# See issue #1105
|
||||
cfg = {"model": DEFAULT_TEXTCAT_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
textcat = TextCategorizer(en_vocab, model, labels=["ENTITY", "ACTION", "MODIFIER"], threshold=0.5)
|
||||
textcat = TextCategorizer(en_vocab, model, labels=["ENTITY", "ACTION", "MODIFIER"], threshold=0.5, positive_label=None)
|
||||
textcat.to_bytes(exclude=["vocab"])
|
||||
|
||||
|
||||
|
|
|
@ -37,9 +37,10 @@ architectures and their arguments and hyperparameters.
|
|||
> ```
|
||||
|
||||
| Setting | Description |
|
||||
| ----------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| ---------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `labels` | A list of categories to learn. If empty, the model infers the categories from the data. Defaults to `[]`. ~~Iterable[str]~~ |
|
||||
| `threshold` | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~ |
|
||||
| `positive_label` | The positive label for a binary task with exclusive classes, None otherwise and by default. ~~Optional[str]~~ |
|
||||
| `model` | A model instance that predicts scores for each category. Defaults to [TextCatEnsemble](/api/architectures#TextCatEnsemble). ~~Model[List[Doc], List[Floats2d]]~~ |
|
||||
|
||||
```python
|
||||
|
@ -60,7 +61,7 @@ architectures and their arguments and hyperparameters.
|
|||
>
|
||||
> # Construction from class
|
||||
> from spacy.pipeline import TextCategorizer
|
||||
> textcat = TextCategorizer(nlp.vocab, model, labels=[], threshold=0.5)
|
||||
> textcat = TextCategorizer(nlp.vocab, model, labels=[], threshold=0.5, positive_label="POS")
|
||||
> ```
|
||||
|
||||
Create a new pipeline instance. In your application, you would normally use a
|
||||
|
@ -68,13 +69,14 @@ shortcut for this and instantiate the component using its string name and
|
|||
[`nlp.add_pipe`](/api/language#create_pipe).
|
||||
|
||||
| Name | Description |
|
||||
| -------------- | -------------------------------------------------------------------------------------------------------------------------- |
|
||||
| ---------------- | -------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `vocab` | The shared vocabulary. ~~Vocab~~ |
|
||||
| `model` | The Thinc [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. ~~Model[List[Doc], List[Floats2d]]~~ |
|
||||
| `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ |
|
||||
| _keyword-only_ | |
|
||||
| `labels` | The labels to use. ~~Iterable[str]~~ |
|
||||
| `threshold` | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~ |
|
||||
| `positive_label` | The positive label for a binary task with exclusive classes, None otherwise. ~~Optional[str]~~ |
|
||||
|
||||
## TextCategorizer.\_\_call\_\_ {#call tag="method"}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user