diff --git a/spacy/pipeline/textcat.py b/spacy/pipeline/textcat.py index 4be6f580d..22d1de08f 100644 --- a/spacy/pipeline/textcat.py +++ b/spacy/pipeline/textcat.py @@ -56,7 +56,7 @@ subword_features = true @Language.factory( "textcat", assigns=["doc.cats"], - default_config={"labels": [], "model": DEFAULT_TEXTCAT_MODEL}, + default_config={"labels": [], "threshold": 0.5, "model": DEFAULT_TEXTCAT_MODEL}, scores=[ "cats_score", "cats_score_desc", @@ -75,6 +75,7 @@ def make_textcat( name: str, model: Model[List[Doc], List[Floats2d]], labels: Iterable[str], + threshold: float, ) -> "TextCategorizer": """Create a TextCategorizer compoment. The text categorizer predicts categories over a whole document. It can learn one or more labels, and the labels can @@ -86,8 +87,9 @@ def make_textcat( scores for each category. labels (list): A list of categories to learn. If empty, the model infers the categories from the data. + threshold (float): Cutoff to consider a prediction "positive". """ - return TextCategorizer(nlp.vocab, model, name, labels=labels) + return TextCategorizer(nlp.vocab, model, name, labels=labels, threshold=threshold) class TextCategorizer(Pipe): @@ -103,6 +105,7 @@ class TextCategorizer(Pipe): name: str = "textcat", *, labels: Iterable[str], + threshold: float, ) -> None: """Initialize a text categorizer. @@ -111,6 +114,7 @@ class TextCategorizer(Pipe): name (str): The component instance name, used to add entries to the losses during training. labels (Iterable[str]): The labels to use. + threshold (float): Cutoff to consider a prediction "positive". DOCS: https://nightly.spacy.io/api/textcategorizer#init """ @@ -118,7 +122,7 @@ class TextCategorizer(Pipe): self.model = model self.name = name self._rehearsal_model = None - cfg = {"labels": labels} + cfg = {"labels": labels, "threshold": threshold} self.cfg = dict(cfg) @property @@ -371,5 +375,6 @@ class TextCategorizer(Pipe): labels=self.labels, multi_label=self.model.attrs["multi_label"], positive_label=positive_label, + threshold=self.cfg["threshold"], **kwargs, ) diff --git a/spacy/tests/serialize/test_serialize_pipeline.py b/spacy/tests/serialize/test_serialize_pipeline.py index db62f6569..e621aebd8 100644 --- a/spacy/tests/serialize/test_serialize_pipeline.py +++ b/spacy/tests/serialize/test_serialize_pipeline.py @@ -136,7 +136,7 @@ def test_serialize_textcat_empty(en_vocab): # See issue #1105 cfg = {"model": DEFAULT_TEXTCAT_MODEL} model = registry.make_from_config(cfg, validate=True)["model"] - textcat = TextCategorizer(en_vocab, model, labels=["ENTITY", "ACTION", "MODIFIER"]) + textcat = TextCategorizer(en_vocab, model, labels=["ENTITY", "ACTION", "MODIFIER"], threshold=0.5) textcat.to_bytes(exclude=["vocab"]) diff --git a/website/docs/api/textcategorizer.md b/website/docs/api/textcategorizer.md index cc20d6fd2..9bdc6324f 100644 --- a/website/docs/api/textcategorizer.md +++ b/website/docs/api/textcategorizer.md @@ -30,15 +30,17 @@ architectures and their arguments and hyperparameters. > from spacy.pipeline.textcat import DEFAULT_TEXTCAT_MODEL > config = { > "labels": [], +> "threshold": 0.5, > "model": DEFAULT_TEXTCAT_MODEL, > } > nlp.add_pipe("textcat", config=config) > ``` -| Setting | Description | -| -------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `labels` | A list of categories to learn. If empty, the model infers the categories from the data. Defaults to `[]`. ~~Iterable[str]~~ | -| `model` | A model instance that predicts scores for each category. Defaults to [TextCatEnsemble](/api/architectures#TextCatEnsemble). ~~Model[List[Doc], List[Floats2d]]~~ | +| Setting | Description | +| ----------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `labels` | A list of categories to learn. If empty, the model infers the categories from the data. Defaults to `[]`. ~~Iterable[str]~~ | +| `threshold` | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~ | +| `model` | A model instance that predicts scores for each category. Defaults to [TextCatEnsemble](/api/architectures#TextCatEnsemble). ~~Model[List[Doc], List[Floats2d]]~~ | ```python %%GITHUB_SPACY/spacy/pipeline/textcat.py @@ -58,7 +60,7 @@ architectures and their arguments and hyperparameters. > > # Construction from class > from spacy.pipeline import TextCategorizer -> textcat = TextCategorizer(nlp.vocab, model) +> textcat = TextCategorizer(nlp.vocab, model, labels=[], threshold=0.5) > ``` Create a new pipeline instance. In your application, you would normally use a @@ -72,6 +74,7 @@ shortcut for this and instantiate the component using its string name and | `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ | | _keyword-only_ | | | `labels` | The labels to use. ~~Iterable[str]~~ | +| `threshold` | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~ | ## TextCategorizer.\_\_call\_\_ {#call tag="method"}