textcat/textcat_multilabel: add store_activations option

This commit is contained in:
Daniël de Kok 2022-06-22 20:29:47 +02:00
parent 42526a6fa1
commit 8772b9ccc4
3 changed files with 83 additions and 5 deletions

View File

@ -1,4 +1,4 @@
from typing import Iterable, Tuple, Optional, Dict, List, Callable, Any
from typing import Iterable, Tuple, Optional, Dict, List, Callable, Any, Union
from thinc.api import get_array_module, Model, Optimizer, set_dropout_rate, Config
from thinc.types import Floats2d
import numpy
@ -75,6 +75,7 @@ subword_features = true
"threshold": 0.5,
"model": DEFAULT_SINGLE_TEXTCAT_MODEL,
"scorer": {"@scorers": "spacy.textcat_scorer.v1"},
"store_activations": False,
},
default_score_weights={
"cats_score": 1.0,
@ -96,6 +97,7 @@ def make_textcat(
model: Model[List[Doc], List[Floats2d]],
threshold: float,
scorer: Optional[Callable],
store_activations: Union[bool, List[str]],
) -> "TextCategorizer":
"""Create a TextCategorizer component. The text categorizer predicts categories
over a whole document. It can learn one or more labels, and the labels are considered
@ -106,7 +108,14 @@ def make_textcat(
threshold (float): Cutoff to consider a prediction "positive".
scorer (Optional[Callable]): The scoring method.
"""
return TextCategorizer(nlp.vocab, model, name, threshold=threshold, scorer=scorer)
return TextCategorizer(
nlp.vocab,
model,
name,
threshold=threshold,
scorer=scorer,
store_activations=store_activations,
)
def textcat_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
@ -137,6 +146,7 @@ class TextCategorizer(TrainablePipe):
*,
threshold: float,
scorer: Optional[Callable] = textcat_score,
store_activations=False,
) -> None:
"""Initialize a text categorizer for single-label classification.
@ -157,6 +167,7 @@ class TextCategorizer(TrainablePipe):
cfg = {"labels": [], "threshold": threshold, "positive_label": None}
self.cfg = dict(cfg)
self.scorer = scorer
self.store_activations = store_activations
@property
def support_missing_values(self):
@ -208,6 +219,9 @@ class TextCategorizer(TrainablePipe):
DOCS: https://spacy.io/api/textcategorizer#set_annotations
"""
for i, doc in enumerate(docs):
doc.activations[self.name] = {}
if "probs" in self.store_activations:
doc.activations[self.name]["probs"] = scores[i]
for j, label in enumerate(self.labels):
doc.cats[label] = float(scores[i, j])
@ -398,3 +412,7 @@ class TextCategorizer(TrainablePipe):
for ex in examples:
if list(ex.reference.cats.values()).count(1.0) > 1:
raise ValueError(Errors.E895.format(value=ex.reference.cats))
@property
def activations(self):
return ["probs"]

View File

@ -1,4 +1,4 @@
from typing import Iterable, Optional, Dict, List, Callable, Any
from typing import Iterable, Optional, Dict, List, Callable, Any, Union
from thinc.types import Floats2d
from thinc.api import Model, Config
@ -75,6 +75,7 @@ subword_features = true
"threshold": 0.5,
"model": DEFAULT_MULTI_TEXTCAT_MODEL,
"scorer": {"@scorers": "spacy.textcat_multilabel_scorer.v1"},
"store_activations": False,
},
default_score_weights={
"cats_score": 1.0,
@ -96,6 +97,7 @@ def make_multilabel_textcat(
model: Model[List[Doc], List[Floats2d]],
threshold: float,
scorer: Optional[Callable],
store_activations: Union[bool, List[str]],
) -> "TextCategorizer":
"""Create a TextCategorizer component. The text categorizer predicts categories
over a whole document. It can learn one or more labels, and the labels are considered
@ -107,7 +109,12 @@ def make_multilabel_textcat(
threshold (float): Cutoff to consider a prediction "positive".
"""
return MultiLabel_TextCategorizer(
nlp.vocab, model, name, threshold=threshold, scorer=scorer
nlp.vocab,
model,
name,
threshold=threshold,
scorer=scorer,
store_activations=store_activations,
)
@ -139,6 +146,7 @@ class MultiLabel_TextCategorizer(TextCategorizer):
*,
threshold: float,
scorer: Optional[Callable] = textcat_multilabel_score,
store_activations=False,
) -> None:
"""Initialize a text categorizer for multi-label classification.
@ -157,6 +165,7 @@ class MultiLabel_TextCategorizer(TextCategorizer):
cfg = {"labels": [], "threshold": threshold}
self.cfg = dict(cfg)
self.scorer = scorer
self.store_activations = store_activations
@property
def support_missing_values(self):

View File

@ -1,3 +1,4 @@
from typing import cast
import random
import numpy.random
@ -11,7 +12,7 @@ from spacy import util
from spacy.cli.evaluate import print_prf_per_type, print_textcats_auc_per_cat
from spacy.lang.en import English
from spacy.language import Language
from spacy.pipeline import TextCategorizer
from spacy.pipeline import TextCategorizer, TrainablePipe
from spacy.pipeline.textcat import single_label_bow_config
from spacy.pipeline.textcat import single_label_cnn_config
from spacy.pipeline.textcat import single_label_default_config
@ -871,3 +872,53 @@ def test_textcat_multi_threshold():
scores = nlp.evaluate(train_examples, scorer_cfg={"threshold": 0})
assert scores["cats_f_per_type"]["POSITIVE"]["r"] == 1.0
def test_store_activations():
fix_random_seed(0)
nlp = English()
textcat = cast(TrainablePipe, nlp.add_pipe("textcat"))
train_examples = []
for text, annotations in TRAIN_DATA_SINGLE_LABEL:
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
nlp.initialize(get_examples=lambda: train_examples)
nO = textcat.model.get_dim("nO")
doc = nlp("This is a test.")
assert len(list(doc.activations["textcat"].keys())) == 0
textcat.store_activations = True
doc = nlp("This is a test.")
assert list(doc.activations["textcat"].keys()) == ["probs"]
assert doc.activations["textcat"]["probs"].shape == (nO,)
textcat.store_activations = ["probs"]
doc = nlp("This is a test.")
assert list(doc.activations["textcat"].keys()) == ["probs"]
assert doc.activations["textcat"]["probs"].shape == (nO,)
def test_store_activations_multi():
fix_random_seed(0)
nlp = English()
textcat = cast(TrainablePipe, nlp.add_pipe("textcat_multilabel"))
train_examples = []
for text, annotations in TRAIN_DATA_MULTI_LABEL:
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
nlp.initialize(get_examples=lambda: train_examples)
nO = textcat.model.get_dim("nO")
doc = nlp("This is a test.")
assert len(list(doc.activations["textcat_multilabel"].keys())) == 0
textcat.store_activations = True
doc = nlp("This is a test.")
assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"]
assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,)
textcat.store_activations = ["probs"]
doc = nlp("This is a test.")
assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"]
assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,)