textcat/textcat_multilabel: add store_activations option

Daniël de Kok 2022-06-22 20:29:47 +02:00
parent 42526a6fa1
commit 8772b9ccc4
3 changed files with 83 additions and 5 deletions
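
Note (not part of the commit): a minimal usage sketch of the option added here, assuming the experimental doc.activations API this branch builds on. The "store_activations" config key and the "probs" activation name come from the diff below; the rest of the setup (blank English pipeline, training elided) is only illustrative.

# Sketch only: assumes the experimental doc.activations API on this branch.
import spacy

nlp = spacy.blank("en")
# New factory setting from this commit: False (the default) stores nothing,
# True or ["probs"] stores the per-category scores on the Doc.
nlp.add_pipe("textcat", config={"store_activations": True})
# ... train or initialize the component before running it ...
doc = nlp("This is a test.")
probs = doc.activations["textcat"]["probs"]  # one score per label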

View File

@@ -1,4 +1,4 @@
-from typing import Iterable, Tuple, Optional, Dict, List, Callable, Any
+from typing import Iterable, Tuple, Optional, Dict, List, Callable, Any, Union
 from thinc.api import get_array_module, Model, Optimizer, set_dropout_rate, Config
 from thinc.types import Floats2d
 import numpy
@@ -75,6 +75,7 @@ subword_features = true
         "threshold": 0.5,
         "model": DEFAULT_SINGLE_TEXTCAT_MODEL,
         "scorer": {"@scorers": "spacy.textcat_scorer.v1"},
+        "store_activations": False,
     },
     default_score_weights={
         "cats_score": 1.0,
@@ -96,6 +97,7 @@ def make_textcat(
     model: Model[List[Doc], List[Floats2d]],
     threshold: float,
     scorer: Optional[Callable],
+    store_activations: Union[bool, List[str]],
 ) -> "TextCategorizer":
     """Create a TextCategorizer component. The text categorizer predicts categories
     over a whole document. It can learn one or more labels, and the labels are considered
@@ -106,7 +108,14 @@ def make_textcat(
     threshold (float): Cutoff to consider a prediction "positive".
     scorer (Optional[Callable]): The scoring method.
     """
-    return TextCategorizer(nlp.vocab, model, name, threshold=threshold, scorer=scorer)
+    return TextCategorizer(
+        nlp.vocab,
+        model,
+        name,
+        threshold=threshold,
+        scorer=scorer,
+        store_activations=store_activations,
+    )

 def textcat_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
@@ -137,6 +146,7 @@ class TextCategorizer(TrainablePipe):
         *,
         threshold: float,
         scorer: Optional[Callable] = textcat_score,
+        store_activations=False,
     ) -> None:
         """Initialize a text categorizer for single-label classification.
@@ -157,6 +167,7 @@ class TextCategorizer(TrainablePipe):
         cfg = {"labels": [], "threshold": threshold, "positive_label": None}
         self.cfg = dict(cfg)
         self.scorer = scorer
+        self.store_activations = store_activations

     @property
     def support_missing_values(self):
@@ -208,6 +219,9 @@ class TextCategorizer(TrainablePipe):
         DOCS: https://spacy.io/api/textcategorizer#set_annotations
         """
         for i, doc in enumerate(docs):
+            doc.activations[self.name] = {}
+            if "probs" in self.store_activations:
+                doc.activations[self.name]["probs"] = scores[i]
             for j, label in enumerate(self.labels):
                 doc.cats[label] = float(scores[i, j])
@@ -398,3 +412,7 @@ class TextCategorizer(TrainablePipe):
         for ex in examples:
             if list(ex.reference.cats.values()).count(1.0) > 1:
                 raise ValueError(Errors.E895.format(value=ex.reference.cats))
+
+    @property
+    def activations(self):
+        return ["probs"]

View File

@@ -1,4 +1,4 @@
-from typing import Iterable, Optional, Dict, List, Callable, Any
+from typing import Iterable, Optional, Dict, List, Callable, Any, Union
 from thinc.types import Floats2d
 from thinc.api import Model, Config
@@ -75,6 +75,7 @@ subword_features = true
         "threshold": 0.5,
         "model": DEFAULT_MULTI_TEXTCAT_MODEL,
         "scorer": {"@scorers": "spacy.textcat_multilabel_scorer.v1"},
+        "store_activations": False,
     },
     default_score_weights={
         "cats_score": 1.0,
@@ -96,6 +97,7 @@ def make_multilabel_textcat(
     model: Model[List[Doc], List[Floats2d]],
     threshold: float,
     scorer: Optional[Callable],
+    store_activations: Union[bool, List[str]],
 ) -> "TextCategorizer":
     """Create a TextCategorizer component. The text categorizer predicts categories
     over a whole document. It can learn one or more labels, and the labels are considered
@@ -107,7 +109,12 @@ def make_multilabel_textcat(
     threshold (float): Cutoff to consider a prediction "positive".
     """
     return MultiLabel_TextCategorizer(
-        nlp.vocab, model, name, threshold=threshold, scorer=scorer
+        nlp.vocab,
+        model,
+        name,
+        threshold=threshold,
+        scorer=scorer,
+        store_activations=store_activations,
     )
@@ -139,6 +146,7 @@ class MultiLabel_TextCategorizer(TextCategorizer):
         *,
         threshold: float,
         scorer: Optional[Callable] = textcat_multilabel_score,
+        store_activations=False,
     ) -> None:
         """Initialize a text categorizer for multi-label classification.
@@ -157,6 +165,7 @@ class MultiLabel_TextCategorizer(TextCategorizer):
         cfg = {"labels": [], "threshold": threshold}
         self.cfg = dict(cfg)
         self.scorer = scorer
+        self.store_activations = store_activations

     @property
     def support_missing_values(self):

View File

@@ -1,3 +1,4 @@
+from typing import cast
 import random

 import numpy.random
@@ -11,7 +12,7 @@ from spacy import util
 from spacy.cli.evaluate import print_prf_per_type, print_textcats_auc_per_cat
 from spacy.lang.en import English
 from spacy.language import Language
-from spacy.pipeline import TextCategorizer
+from spacy.pipeline import TextCategorizer, TrainablePipe
 from spacy.pipeline.textcat import single_label_bow_config
 from spacy.pipeline.textcat import single_label_cnn_config
 from spacy.pipeline.textcat import single_label_default_config
@@ -871,3 +872,53 @@ def test_textcat_multi_threshold():
     scores = nlp.evaluate(train_examples, scorer_cfg={"threshold": 0})
     assert scores["cats_f_per_type"]["POSITIVE"]["r"] == 1.0
+
+
+def test_store_activations():
+    fix_random_seed(0)
+    nlp = English()
+    textcat = cast(TrainablePipe, nlp.add_pipe("textcat"))
+
+    train_examples = []
+    for text, annotations in TRAIN_DATA_SINGLE_LABEL:
+        train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
+    nlp.initialize(get_examples=lambda: train_examples)
+    nO = textcat.model.get_dim("nO")
+
+    doc = nlp("This is a test.")
+    assert len(list(doc.activations["textcat"].keys())) == 0
+
+    textcat.store_activations = True
+    doc = nlp("This is a test.")
+    assert list(doc.activations["textcat"].keys()) == ["probs"]
+    assert doc.activations["textcat"]["probs"].shape == (nO,)
+
+    textcat.store_activations = ["probs"]
+    doc = nlp("This is a test.")
+    assert list(doc.activations["textcat"].keys()) == ["probs"]
+    assert doc.activations["textcat"]["probs"].shape == (nO,)
+
+
+def test_store_activations_multi():
+    fix_random_seed(0)
+    nlp = English()
+    textcat = cast(TrainablePipe, nlp.add_pipe("textcat_multilabel"))
+
+    train_examples = []
+    for text, annotations in TRAIN_DATA_MULTI_LABEL:
+        train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
+    nlp.initialize(get_examples=lambda: train_examples)
+    nO = textcat.model.get_dim("nO")
+
+    doc = nlp("This is a test.")
+    assert len(list(doc.activations["textcat_multilabel"].keys())) == 0
+
+    textcat.store_activations = True
+    doc = nlp("This is a test.")
+    assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"]
+    assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,)
+
+    textcat.store_activations = ["probs"]
+    doc = nlp("This is a test.")
+    assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"]
+    assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,)