diff --git a/spacy/pipeline/textcat.py b/spacy/pipeline/textcat.py index bc3f127fc..3e015de07 100644 --- a/spacy/pipeline/textcat.py +++ b/spacy/pipeline/textcat.py @@ -1,4 +1,4 @@ -from typing import Iterable, Tuple, Optional, Dict, List, Callable, Any +from typing import Iterable, Tuple, Optional, Dict, List, Callable, Any, Union from thinc.api import get_array_module, Model, Optimizer, set_dropout_rate, Config from thinc.types import Floats2d import numpy @@ -75,6 +75,7 @@ subword_features = true "threshold": 0.5, "model": DEFAULT_SINGLE_TEXTCAT_MODEL, "scorer": {"@scorers": "spacy.textcat_scorer.v1"}, + "store_activations": False, }, default_score_weights={ "cats_score": 1.0, @@ -96,6 +97,7 @@ def make_textcat( model: Model[List[Doc], List[Floats2d]], threshold: float, scorer: Optional[Callable], + store_activations: Union[bool, List[str]], ) -> "TextCategorizer": """Create a TextCategorizer component. The text categorizer predicts categories over a whole document. It can learn one or more labels, and the labels are considered @@ -106,7 +108,14 @@ def make_textcat( threshold (float): Cutoff to consider a prediction "positive". scorer (Optional[Callable]): The scoring method. """ - return TextCategorizer(nlp.vocab, model, name, threshold=threshold, scorer=scorer) + return TextCategorizer( + nlp.vocab, + model, + name, + threshold=threshold, + scorer=scorer, + store_activations=store_activations, + ) def textcat_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]: @@ -137,6 +146,7 @@ class TextCategorizer(TrainablePipe): *, threshold: float, scorer: Optional[Callable] = textcat_score, + store_activations=False, ) -> None: """Initialize a text categorizer for single-label classification. @@ -157,6 +167,7 @@ class TextCategorizer(TrainablePipe): cfg = {"labels": [], "threshold": threshold, "positive_label": None} self.cfg = dict(cfg) self.scorer = scorer + self.store_activations = store_activations @property def support_missing_values(self): @@ -208,6 +219,9 @@ class TextCategorizer(TrainablePipe): DOCS: https://spacy.io/api/textcategorizer#set_annotations """ for i, doc in enumerate(docs): + doc.activations[self.name] = {} + if "probs" in self.store_activations: + doc.activations[self.name]["probs"] = scores[i] for j, label in enumerate(self.labels): doc.cats[label] = float(scores[i, j]) @@ -398,3 +412,7 @@ class TextCategorizer(TrainablePipe): for ex in examples: if list(ex.reference.cats.values()).count(1.0) > 1: raise ValueError(Errors.E895.format(value=ex.reference.cats)) + + @property + def activations(self): + return ["probs"] diff --git a/spacy/pipeline/textcat_multilabel.py b/spacy/pipeline/textcat_multilabel.py index e33a885f8..1bb817054 100644 --- a/spacy/pipeline/textcat_multilabel.py +++ b/spacy/pipeline/textcat_multilabel.py @@ -1,4 +1,4 @@ -from typing import Iterable, Optional, Dict, List, Callable, Any +from typing import Iterable, Optional, Dict, List, Callable, Any, Union from thinc.types import Floats2d from thinc.api import Model, Config @@ -75,6 +75,7 @@ subword_features = true "threshold": 0.5, "model": DEFAULT_MULTI_TEXTCAT_MODEL, "scorer": {"@scorers": "spacy.textcat_multilabel_scorer.v1"}, + "store_activations": False, }, default_score_weights={ "cats_score": 1.0, @@ -96,6 +97,7 @@ def make_multilabel_textcat( model: Model[List[Doc], List[Floats2d]], threshold: float, scorer: Optional[Callable], + store_activations: Union[bool, List[str]], ) -> "TextCategorizer": """Create a TextCategorizer component. The text categorizer predicts categories over a whole document. It can learn one or more labels, and the labels are considered @@ -107,7 +109,12 @@ def make_multilabel_textcat( threshold (float): Cutoff to consider a prediction "positive". """ return MultiLabel_TextCategorizer( - nlp.vocab, model, name, threshold=threshold, scorer=scorer + nlp.vocab, + model, + name, + threshold=threshold, + scorer=scorer, + store_activations=store_activations, ) @@ -139,6 +146,7 @@ class MultiLabel_TextCategorizer(TextCategorizer): *, threshold: float, scorer: Optional[Callable] = textcat_multilabel_score, + store_activations=False, ) -> None: """Initialize a text categorizer for multi-label classification. @@ -157,6 +165,7 @@ class MultiLabel_TextCategorizer(TextCategorizer): cfg = {"labels": [], "threshold": threshold} self.cfg = dict(cfg) self.scorer = scorer + self.store_activations = store_activations @property def support_missing_values(self): diff --git a/spacy/tests/pipeline/test_textcat.py b/spacy/tests/pipeline/test_textcat.py index 0bb036a33..79d0945a5 100644 --- a/spacy/tests/pipeline/test_textcat.py +++ b/spacy/tests/pipeline/test_textcat.py @@ -1,3 +1,4 @@ +from typing import cast import random import numpy.random @@ -11,7 +12,7 @@ from spacy import util from spacy.cli.evaluate import print_prf_per_type, print_textcats_auc_per_cat from spacy.lang.en import English from spacy.language import Language -from spacy.pipeline import TextCategorizer +from spacy.pipeline import TextCategorizer, TrainablePipe from spacy.pipeline.textcat import single_label_bow_config from spacy.pipeline.textcat import single_label_cnn_config from spacy.pipeline.textcat import single_label_default_config @@ -871,3 +872,53 @@ def test_textcat_multi_threshold(): scores = nlp.evaluate(train_examples, scorer_cfg={"threshold": 0}) assert scores["cats_f_per_type"]["POSITIVE"]["r"] == 1.0 + + +def test_store_activations(): + fix_random_seed(0) + nlp = English() + textcat = cast(TrainablePipe, nlp.add_pipe("textcat")) + + train_examples = [] + for text, annotations in TRAIN_DATA_SINGLE_LABEL: + train_examples.append(Example.from_dict(nlp.make_doc(text), annotations)) + nlp.initialize(get_examples=lambda: train_examples) + nO = textcat.model.get_dim("nO") + + doc = nlp("This is a test.") + assert len(list(doc.activations["textcat"].keys())) == 0 + + textcat.store_activations = True + doc = nlp("This is a test.") + assert list(doc.activations["textcat"].keys()) == ["probs"] + assert doc.activations["textcat"]["probs"].shape == (nO,) + + textcat.store_activations = ["probs"] + doc = nlp("This is a test.") + assert list(doc.activations["textcat"].keys()) == ["probs"] + assert doc.activations["textcat"]["probs"].shape == (nO,) + + +def test_store_activations_multi(): + fix_random_seed(0) + nlp = English() + textcat = cast(TrainablePipe, nlp.add_pipe("textcat_multilabel")) + + train_examples = [] + for text, annotations in TRAIN_DATA_MULTI_LABEL: + train_examples.append(Example.from_dict(nlp.make_doc(text), annotations)) + nlp.initialize(get_examples=lambda: train_examples) + nO = textcat.model.get_dim("nO") + + doc = nlp("This is a test.") + assert len(list(doc.activations["textcat_multilabel"].keys())) == 0 + + textcat.store_activations = True + doc = nlp("This is a test.") + assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"] + assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,) + + textcat.store_activations = ["probs"] + doc = nlp("This is a test.") + assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"] + assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,)