mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
define threshold for scoring textcat in TextCat config (#6055)
* define threshold for scoring textcat in TextCat config * fix unit test and documentation
This commit is contained in:
parent
ab270364f1
commit
744df9814a
|
@ -56,7 +56,7 @@ subword_features = true
|
|||
@Language.factory(
|
||||
"textcat",
|
||||
assigns=["doc.cats"],
|
||||
default_config={"labels": [], "model": DEFAULT_TEXTCAT_MODEL},
|
||||
default_config={"labels": [], "threshold": 0.5, "model": DEFAULT_TEXTCAT_MODEL},
|
||||
scores=[
|
||||
"cats_score",
|
||||
"cats_score_desc",
|
||||
|
@ -75,6 +75,7 @@ def make_textcat(
|
|||
name: str,
|
||||
model: Model[List[Doc], List[Floats2d]],
|
||||
labels: Iterable[str],
|
||||
threshold: float,
|
||||
) -> "TextCategorizer":
|
||||
"""Create a TextCategorizer compoment. The text categorizer predicts categories
|
||||
over a whole document. It can learn one or more labels, and the labels can
|
||||
|
@ -86,8 +87,9 @@ def make_textcat(
|
|||
scores for each category.
|
||||
labels (list): A list of categories to learn. If empty, the model infers the
|
||||
categories from the data.
|
||||
threshold (float): Cutoff to consider a prediction "positive".
|
||||
"""
|
||||
return TextCategorizer(nlp.vocab, model, name, labels=labels)
|
||||
return TextCategorizer(nlp.vocab, model, name, labels=labels, threshold=threshold)
|
||||
|
||||
|
||||
class TextCategorizer(Pipe):
|
||||
|
@ -103,6 +105,7 @@ class TextCategorizer(Pipe):
|
|||
name: str = "textcat",
|
||||
*,
|
||||
labels: Iterable[str],
|
||||
threshold: float,
|
||||
) -> None:
|
||||
"""Initialize a text categorizer.
|
||||
|
||||
|
@ -111,6 +114,7 @@ class TextCategorizer(Pipe):
|
|||
name (str): The component instance name, used to add entries to the
|
||||
losses during training.
|
||||
labels (Iterable[str]): The labels to use.
|
||||
threshold (float): Cutoff to consider a prediction "positive".
|
||||
|
||||
DOCS: https://nightly.spacy.io/api/textcategorizer#init
|
||||
"""
|
||||
|
@ -118,7 +122,7 @@ class TextCategorizer(Pipe):
|
|||
self.model = model
|
||||
self.name = name
|
||||
self._rehearsal_model = None
|
||||
cfg = {"labels": labels}
|
||||
cfg = {"labels": labels, "threshold": threshold}
|
||||
self.cfg = dict(cfg)
|
||||
|
||||
@property
|
||||
|
@ -371,5 +375,6 @@ class TextCategorizer(Pipe):
|
|||
labels=self.labels,
|
||||
multi_label=self.model.attrs["multi_label"],
|
||||
positive_label=positive_label,
|
||||
threshold=self.cfg["threshold"],
|
||||
**kwargs,
|
||||
)
|
||||
|
|
|
@ -136,7 +136,7 @@ def test_serialize_textcat_empty(en_vocab):
|
|||
# See issue #1105
|
||||
cfg = {"model": DEFAULT_TEXTCAT_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
textcat = TextCategorizer(en_vocab, model, labels=["ENTITY", "ACTION", "MODIFIER"])
|
||||
textcat = TextCategorizer(en_vocab, model, labels=["ENTITY", "ACTION", "MODIFIER"], threshold=0.5)
|
||||
textcat.to_bytes(exclude=["vocab"])
|
||||
|
||||
|
||||
|
|
|
@ -30,15 +30,17 @@ architectures and their arguments and hyperparameters.
|
|||
> from spacy.pipeline.textcat import DEFAULT_TEXTCAT_MODEL
|
||||
> config = {
|
||||
> "labels": [],
|
||||
> "threshold": 0.5,
|
||||
> "model": DEFAULT_TEXTCAT_MODEL,
|
||||
> }
|
||||
> nlp.add_pipe("textcat", config=config)
|
||||
> ```
|
||||
|
||||
| Setting | Description |
|
||||
| -------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `labels` | A list of categories to learn. If empty, the model infers the categories from the data. Defaults to `[]`. ~~Iterable[str]~~ |
|
||||
| `model` | A model instance that predicts scores for each category. Defaults to [TextCatEnsemble](/api/architectures#TextCatEnsemble). ~~Model[List[Doc], List[Floats2d]]~~ |
|
||||
| Setting | Description |
|
||||
| ----------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `labels` | A list of categories to learn. If empty, the model infers the categories from the data. Defaults to `[]`. ~~Iterable[str]~~ |
|
||||
| `threshold` | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~ |
|
||||
| `model` | A model instance that predicts scores for each category. Defaults to [TextCatEnsemble](/api/architectures#TextCatEnsemble). ~~Model[List[Doc], List[Floats2d]]~~ |
|
||||
|
||||
```python
|
||||
%%GITHUB_SPACY/spacy/pipeline/textcat.py
|
||||
|
@ -58,7 +60,7 @@ architectures and their arguments and hyperparameters.
|
|||
>
|
||||
> # Construction from class
|
||||
> from spacy.pipeline import TextCategorizer
|
||||
> textcat = TextCategorizer(nlp.vocab, model)
|
||||
> textcat = TextCategorizer(nlp.vocab, model, labels=[], threshold=0.5)
|
||||
> ```
|
||||
|
||||
Create a new pipeline instance. In your application, you would normally use a
|
||||
|
@ -72,6 +74,7 @@ shortcut for this and instantiate the component using its string name and
|
|||
| `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ |
|
||||
| _keyword-only_ | |
|
||||
| `labels` | The labels to use. ~~Iterable[str]~~ |
|
||||
| `threshold` | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~ |
|
||||
|
||||
## TextCategorizer.\_\_call\_\_ {#call tag="method"}
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user