mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
define threshold for scoring textcat in TextCat config (#6055)
* define threshold for scoring textcat in TextCat config * fix unit test and documentation
This commit is contained in:
parent
ab270364f1
commit
744df9814a
|
@ -56,7 +56,7 @@ subword_features = true
|
||||||
@Language.factory(
|
@Language.factory(
|
||||||
"textcat",
|
"textcat",
|
||||||
assigns=["doc.cats"],
|
assigns=["doc.cats"],
|
||||||
default_config={"labels": [], "model": DEFAULT_TEXTCAT_MODEL},
|
default_config={"labels": [], "threshold": 0.5, "model": DEFAULT_TEXTCAT_MODEL},
|
||||||
scores=[
|
scores=[
|
||||||
"cats_score",
|
"cats_score",
|
||||||
"cats_score_desc",
|
"cats_score_desc",
|
||||||
|
@ -75,6 +75,7 @@ def make_textcat(
|
||||||
name: str,
|
name: str,
|
||||||
model: Model[List[Doc], List[Floats2d]],
|
model: Model[List[Doc], List[Floats2d]],
|
||||||
labels: Iterable[str],
|
labels: Iterable[str],
|
||||||
|
threshold: float,
|
||||||
) -> "TextCategorizer":
|
) -> "TextCategorizer":
|
||||||
"""Create a TextCategorizer compoment. The text categorizer predicts categories
|
"""Create a TextCategorizer compoment. The text categorizer predicts categories
|
||||||
over a whole document. It can learn one or more labels, and the labels can
|
over a whole document. It can learn one or more labels, and the labels can
|
||||||
|
@ -86,8 +87,9 @@ def make_textcat(
|
||||||
scores for each category.
|
scores for each category.
|
||||||
labels (list): A list of categories to learn. If empty, the model infers the
|
labels (list): A list of categories to learn. If empty, the model infers the
|
||||||
categories from the data.
|
categories from the data.
|
||||||
|
threshold (float): Cutoff to consider a prediction "positive".
|
||||||
"""
|
"""
|
||||||
return TextCategorizer(nlp.vocab, model, name, labels=labels)
|
return TextCategorizer(nlp.vocab, model, name, labels=labels, threshold=threshold)
|
||||||
|
|
||||||
|
|
||||||
class TextCategorizer(Pipe):
|
class TextCategorizer(Pipe):
|
||||||
|
@ -103,6 +105,7 @@ class TextCategorizer(Pipe):
|
||||||
name: str = "textcat",
|
name: str = "textcat",
|
||||||
*,
|
*,
|
||||||
labels: Iterable[str],
|
labels: Iterable[str],
|
||||||
|
threshold: float,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize a text categorizer.
|
"""Initialize a text categorizer.
|
||||||
|
|
||||||
|
@ -111,6 +114,7 @@ class TextCategorizer(Pipe):
|
||||||
name (str): The component instance name, used to add entries to the
|
name (str): The component instance name, used to add entries to the
|
||||||
losses during training.
|
losses during training.
|
||||||
labels (Iterable[str]): The labels to use.
|
labels (Iterable[str]): The labels to use.
|
||||||
|
threshold (float): Cutoff to consider a prediction "positive".
|
||||||
|
|
||||||
DOCS: https://nightly.spacy.io/api/textcategorizer#init
|
DOCS: https://nightly.spacy.io/api/textcategorizer#init
|
||||||
"""
|
"""
|
||||||
|
@ -118,7 +122,7 @@ class TextCategorizer(Pipe):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.name = name
|
self.name = name
|
||||||
self._rehearsal_model = None
|
self._rehearsal_model = None
|
||||||
cfg = {"labels": labels}
|
cfg = {"labels": labels, "threshold": threshold}
|
||||||
self.cfg = dict(cfg)
|
self.cfg = dict(cfg)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -371,5 +375,6 @@ class TextCategorizer(Pipe):
|
||||||
labels=self.labels,
|
labels=self.labels,
|
||||||
multi_label=self.model.attrs["multi_label"],
|
multi_label=self.model.attrs["multi_label"],
|
||||||
positive_label=positive_label,
|
positive_label=positive_label,
|
||||||
|
threshold=self.cfg["threshold"],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
|
@ -136,7 +136,7 @@ def test_serialize_textcat_empty(en_vocab):
|
||||||
# See issue #1105
|
# See issue #1105
|
||||||
cfg = {"model": DEFAULT_TEXTCAT_MODEL}
|
cfg = {"model": DEFAULT_TEXTCAT_MODEL}
|
||||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||||
textcat = TextCategorizer(en_vocab, model, labels=["ENTITY", "ACTION", "MODIFIER"])
|
textcat = TextCategorizer(en_vocab, model, labels=["ENTITY", "ACTION", "MODIFIER"], threshold=0.5)
|
||||||
textcat.to_bytes(exclude=["vocab"])
|
textcat.to_bytes(exclude=["vocab"])
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -30,15 +30,17 @@ architectures and their arguments and hyperparameters.
|
||||||
> from spacy.pipeline.textcat import DEFAULT_TEXTCAT_MODEL
|
> from spacy.pipeline.textcat import DEFAULT_TEXTCAT_MODEL
|
||||||
> config = {
|
> config = {
|
||||||
> "labels": [],
|
> "labels": [],
|
||||||
|
> "threshold": 0.5,
|
||||||
> "model": DEFAULT_TEXTCAT_MODEL,
|
> "model": DEFAULT_TEXTCAT_MODEL,
|
||||||
> }
|
> }
|
||||||
> nlp.add_pipe("textcat", config=config)
|
> nlp.add_pipe("textcat", config=config)
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
| Setting | Description |
|
| Setting | Description |
|
||||||
| -------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
| ----------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| `labels` | A list of categories to learn. If empty, the model infers the categories from the data. Defaults to `[]`. ~~Iterable[str]~~ |
|
| `labels` | A list of categories to learn. If empty, the model infers the categories from the data. Defaults to `[]`. ~~Iterable[str]~~ |
|
||||||
| `model` | A model instance that predicts scores for each category. Defaults to [TextCatEnsemble](/api/architectures#TextCatEnsemble). ~~Model[List[Doc], List[Floats2d]]~~ |
|
| `threshold` | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~ |
|
||||||
|
| `model` | A model instance that predicts scores for each category. Defaults to [TextCatEnsemble](/api/architectures#TextCatEnsemble). ~~Model[List[Doc], List[Floats2d]]~~ |
|
||||||
|
|
||||||
```python
|
```python
|
||||||
%%GITHUB_SPACY/spacy/pipeline/textcat.py
|
%%GITHUB_SPACY/spacy/pipeline/textcat.py
|
||||||
|
@ -58,7 +60,7 @@ architectures and their arguments and hyperparameters.
|
||||||
>
|
>
|
||||||
> # Construction from class
|
> # Construction from class
|
||||||
> from spacy.pipeline import TextCategorizer
|
> from spacy.pipeline import TextCategorizer
|
||||||
> textcat = TextCategorizer(nlp.vocab, model)
|
> textcat = TextCategorizer(nlp.vocab, model, labels=[], threshold=0.5)
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
Create a new pipeline instance. In your application, you would normally use a
|
Create a new pipeline instance. In your application, you would normally use a
|
||||||
|
@ -72,6 +74,7 @@ shortcut for this and instantiate the component using its string name and
|
||||||
| `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ |
|
| `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ |
|
||||||
| _keyword-only_ | |
|
| _keyword-only_ | |
|
||||||
| `labels` | The labels to use. ~~Iterable[str]~~ |
|
| `labels` | The labels to use. ~~Iterable[str]~~ |
|
||||||
|
| `threshold` | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~ |
|
||||||
|
|
||||||
## TextCategorizer.\_\_call\_\_ {#call tag="method"}
|
## TextCategorizer.\_\_call\_\_ {#call tag="method"}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user