define threshold for scoring textcat in TextCat config (#6055)

* define threshold for scoring textcat in TextCat config

* fix unit test and documentation
This commit is contained in:
Sofie Van Landeghem 2020-09-13 14:15:52 +02:00 committed by GitHub
parent ab270364f1
commit 744df9814a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 17 additions and 9 deletions

View File

@ -56,7 +56,7 @@ subword_features = true
@Language.factory(
"textcat",
assigns=["doc.cats"],
default_config={"labels": [], "model": DEFAULT_TEXTCAT_MODEL},
default_config={"labels": [], "threshold": 0.5, "model": DEFAULT_TEXTCAT_MODEL},
scores=[
"cats_score",
"cats_score_desc",
@ -75,6 +75,7 @@ def make_textcat(
name: str,
model: Model[List[Doc], List[Floats2d]],
labels: Iterable[str],
threshold: float,
) -> "TextCategorizer":
"""Create a TextCategorizer compoment. The text categorizer predicts categories
over a whole document. It can learn one or more labels, and the labels can
@ -86,8 +87,9 @@ def make_textcat(
scores for each category.
labels (list): A list of categories to learn. If empty, the model infers the
categories from the data.
threshold (float): Cutoff to consider a prediction "positive".
"""
return TextCategorizer(nlp.vocab, model, name, labels=labels)
return TextCategorizer(nlp.vocab, model, name, labels=labels, threshold=threshold)
class TextCategorizer(Pipe):
@ -103,6 +105,7 @@ class TextCategorizer(Pipe):
name: str = "textcat",
*,
labels: Iterable[str],
threshold: float,
) -> None:
"""Initialize a text categorizer.
@ -111,6 +114,7 @@ class TextCategorizer(Pipe):
name (str): The component instance name, used to add entries to the
losses during training.
labels (Iterable[str]): The labels to use.
threshold (float): Cutoff to consider a prediction "positive".
DOCS: https://nightly.spacy.io/api/textcategorizer#init
"""
@ -118,7 +122,7 @@ class TextCategorizer(Pipe):
self.model = model
self.name = name
self._rehearsal_model = None
cfg = {"labels": labels}
cfg = {"labels": labels, "threshold": threshold}
self.cfg = dict(cfg)
@property
@ -371,5 +375,6 @@ class TextCategorizer(Pipe):
labels=self.labels,
multi_label=self.model.attrs["multi_label"],
positive_label=positive_label,
threshold=self.cfg["threshold"],
**kwargs,
)

View File

@ -136,7 +136,7 @@ def test_serialize_textcat_empty(en_vocab):
# See issue #1105
cfg = {"model": DEFAULT_TEXTCAT_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
textcat = TextCategorizer(en_vocab, model, labels=["ENTITY", "ACTION", "MODIFIER"])
textcat = TextCategorizer(en_vocab, model, labels=["ENTITY", "ACTION", "MODIFIER"], threshold=0.5)
textcat.to_bytes(exclude=["vocab"])

View File

@ -30,15 +30,17 @@ architectures and their arguments and hyperparameters.
> from spacy.pipeline.textcat import DEFAULT_TEXTCAT_MODEL
> config = {
> "labels": [],
> "threshold": 0.5,
> "model": DEFAULT_TEXTCAT_MODEL,
> }
> nlp.add_pipe("textcat", config=config)
> ```
| Setting | Description |
| -------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `labels` | A list of categories to learn. If empty, the model infers the categories from the data. Defaults to `[]`. ~~Iterable[str]~~ |
| `model` | A model instance that predicts scores for each category. Defaults to [TextCatEnsemble](/api/architectures#TextCatEnsemble). ~~Model[List[Doc], List[Floats2d]]~~ |
| Setting | Description |
| ----------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `labels` | A list of categories to learn. If empty, the model infers the categories from the data. Defaults to `[]`. ~~Iterable[str]~~ |
| `threshold` | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~ |
| `model` | A model instance that predicts scores for each category. Defaults to [TextCatEnsemble](/api/architectures#TextCatEnsemble). ~~Model[List[Doc], List[Floats2d]]~~ |
```python
%%GITHUB_SPACY/spacy/pipeline/textcat.py
@ -58,7 +60,7 @@ architectures and their arguments and hyperparameters.
>
> # Construction from class
> from spacy.pipeline import TextCategorizer
> textcat = TextCategorizer(nlp.vocab, model)
> textcat = TextCategorizer(nlp.vocab, model, labels=[], threshold=0.5)
> ```
Create a new pipeline instance. In your application, you would normally use a
@ -72,6 +74,7 @@ shortcut for this and instantiate the component using its string name and
| `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ |
| _keyword-only_ | |
| `labels` | The labels to use. ~~Iterable[str]~~ |
| `threshold` | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~ |
## TextCategorizer.\_\_call\_\_ {#call tag="method"}