define threshold for scoring textcat in TextCat config (#6055)

* define threshold for scoring textcat in TextCat config

* fix unit test and documentation
This commit is contained in:
Sofie Van Landeghem 2020-09-13 14:15:52 +02:00 committed by GitHub
parent ab270364f1
commit 744df9814a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 17 additions and 9 deletions

View File

@ -56,7 +56,7 @@ subword_features = true
@Language.factory( @Language.factory(
"textcat", "textcat",
assigns=["doc.cats"], assigns=["doc.cats"],
default_config={"labels": [], "model": DEFAULT_TEXTCAT_MODEL}, default_config={"labels": [], "threshold": 0.5, "model": DEFAULT_TEXTCAT_MODEL},
scores=[ scores=[
"cats_score", "cats_score",
"cats_score_desc", "cats_score_desc",
@ -75,6 +75,7 @@ def make_textcat(
name: str, name: str,
model: Model[List[Doc], List[Floats2d]], model: Model[List[Doc], List[Floats2d]],
labels: Iterable[str], labels: Iterable[str],
threshold: float,
) -> "TextCategorizer": ) -> "TextCategorizer":
"""Create a TextCategorizer compoment. The text categorizer predicts categories """Create a TextCategorizer compoment. The text categorizer predicts categories
over a whole document. It can learn one or more labels, and the labels can over a whole document. It can learn one or more labels, and the labels can
@ -86,8 +87,9 @@ def make_textcat(
scores for each category. scores for each category.
labels (list): A list of categories to learn. If empty, the model infers the labels (list): A list of categories to learn. If empty, the model infers the
categories from the data. categories from the data.
threshold (float): Cutoff to consider a prediction "positive".
""" """
return TextCategorizer(nlp.vocab, model, name, labels=labels) return TextCategorizer(nlp.vocab, model, name, labels=labels, threshold=threshold)
class TextCategorizer(Pipe): class TextCategorizer(Pipe):
@ -103,6 +105,7 @@ class TextCategorizer(Pipe):
name: str = "textcat", name: str = "textcat",
*, *,
labels: Iterable[str], labels: Iterable[str],
threshold: float,
) -> None: ) -> None:
"""Initialize a text categorizer. """Initialize a text categorizer.
@ -111,6 +114,7 @@ class TextCategorizer(Pipe):
name (str): The component instance name, used to add entries to the name (str): The component instance name, used to add entries to the
losses during training. losses during training.
labels (Iterable[str]): The labels to use. labels (Iterable[str]): The labels to use.
threshold (float): Cutoff to consider a prediction "positive".
DOCS: https://nightly.spacy.io/api/textcategorizer#init DOCS: https://nightly.spacy.io/api/textcategorizer#init
""" """
@ -118,7 +122,7 @@ class TextCategorizer(Pipe):
self.model = model self.model = model
self.name = name self.name = name
self._rehearsal_model = None self._rehearsal_model = None
cfg = {"labels": labels} cfg = {"labels": labels, "threshold": threshold}
self.cfg = dict(cfg) self.cfg = dict(cfg)
@property @property
@ -371,5 +375,6 @@ class TextCategorizer(Pipe):
labels=self.labels, labels=self.labels,
multi_label=self.model.attrs["multi_label"], multi_label=self.model.attrs["multi_label"],
positive_label=positive_label, positive_label=positive_label,
threshold=self.cfg["threshold"],
**kwargs, **kwargs,
) )

View File

@ -136,7 +136,7 @@ def test_serialize_textcat_empty(en_vocab):
# See issue #1105 # See issue #1105
cfg = {"model": DEFAULT_TEXTCAT_MODEL} cfg = {"model": DEFAULT_TEXTCAT_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"] model = registry.make_from_config(cfg, validate=True)["model"]
textcat = TextCategorizer(en_vocab, model, labels=["ENTITY", "ACTION", "MODIFIER"]) textcat = TextCategorizer(en_vocab, model, labels=["ENTITY", "ACTION", "MODIFIER"], threshold=0.5)
textcat.to_bytes(exclude=["vocab"]) textcat.to_bytes(exclude=["vocab"])

View File

@ -30,15 +30,17 @@ architectures and their arguments and hyperparameters.
> from spacy.pipeline.textcat import DEFAULT_TEXTCAT_MODEL > from spacy.pipeline.textcat import DEFAULT_TEXTCAT_MODEL
> config = { > config = {
> "labels": [], > "labels": [],
> "threshold": 0.5,
> "model": DEFAULT_TEXTCAT_MODEL, > "model": DEFAULT_TEXTCAT_MODEL,
> } > }
> nlp.add_pipe("textcat", config=config) > nlp.add_pipe("textcat", config=config)
> ``` > ```
| Setting | Description | | Setting | Description |
| -------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- | | ----------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `labels` | A list of categories to learn. If empty, the model infers the categories from the data. Defaults to `[]`. ~~Iterable[str]~~ | | `labels` | A list of categories to learn. If empty, the model infers the categories from the data. Defaults to `[]`. ~~Iterable[str]~~ |
| `model` | A model instance that predicts scores for each category. Defaults to [TextCatEnsemble](/api/architectures#TextCatEnsemble). ~~Model[List[Doc], List[Floats2d]]~~ | | `threshold` | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~ |
| `model` | A model instance that predicts scores for each category. Defaults to [TextCatEnsemble](/api/architectures#TextCatEnsemble). ~~Model[List[Doc], List[Floats2d]]~~ |
```python ```python
%%GITHUB_SPACY/spacy/pipeline/textcat.py %%GITHUB_SPACY/spacy/pipeline/textcat.py
@ -58,7 +60,7 @@ architectures and their arguments and hyperparameters.
> >
> # Construction from class > # Construction from class
> from spacy.pipeline import TextCategorizer > from spacy.pipeline import TextCategorizer
> textcat = TextCategorizer(nlp.vocab, model) > textcat = TextCategorizer(nlp.vocab, model, labels=[], threshold=0.5)
> ``` > ```
Create a new pipeline instance. In your application, you would normally use a Create a new pipeline instance. In your application, you would normally use a
@ -72,6 +74,7 @@ shortcut for this and instantiate the component using its string name and
| `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ | | `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ |
| _keyword-only_ | | | _keyword-only_ | |
| `labels` | The labels to use. ~~Iterable[str]~~ | | `labels` | The labels to use. ~~Iterable[str]~~ |
| `threshold` | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~ |
## TextCategorizer.\_\_call\_\_ {#call tag="method"} ## TextCategorizer.\_\_call\_\_ {#call tag="method"}