Mirror of https://github.com/explosion/spaCy.git (synced 2025-01-01 04:46:38 +03:00)

textcat/textcat_multilabel: add store_activations option

This commit is contained in:
    parent 42526a6fa1
    commit 8772b9ccc4
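
The diff below adds a store_activations option to the textcat and textcat_multilabel factories so that the raw per-label probability scores can be written to doc.activations. As a rough usage sketch (not part of this commit; the config key comes from the factory defaults added below, and the doc.activations["textcat"]["probs"] layout follows the added tests):

```python
# Usage sketch (not part of this commit): enable activation storage on the
# textcat component and read the stored scores back from doc.activations.
import spacy

nlp = spacy.blank("en")
# The factory default added in this commit is store_activations=False;
# it can be overridden via the component config.
nlp.add_pipe("textcat", config={"store_activations": True})

# After nlp.initialize(...)/training, processing a text stores the raw
# per-label probabilities under the component name:
# doc = nlp("This is a test.")
# probs = doc.activations["textcat"]["probs"]  # one score per label
```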
--- a/spacy/pipeline/textcat.py
+++ b/spacy/pipeline/textcat.py
@@ -1,4 +1,4 @@
-from typing import Iterable, Tuple, Optional, Dict, List, Callable, Any
+from typing import Iterable, Tuple, Optional, Dict, List, Callable, Any, Union
 from thinc.api import get_array_module, Model, Optimizer, set_dropout_rate, Config
 from thinc.types import Floats2d
 import numpy
@@ -75,6 +75,7 @@ subword_features = true
         "threshold": 0.5,
         "model": DEFAULT_SINGLE_TEXTCAT_MODEL,
         "scorer": {"@scorers": "spacy.textcat_scorer.v1"},
+        "store_activations": False,
     },
     default_score_weights={
         "cats_score": 1.0,
@@ -96,6 +97,7 @@ def make_textcat(
     model: Model[List[Doc], List[Floats2d]],
     threshold: float,
     scorer: Optional[Callable],
+    store_activations: Union[bool, List[str]],
 ) -> "TextCategorizer":
     """Create a TextCategorizer component. The text categorizer predicts categories
     over a whole document. It can learn one or more labels, and the labels are considered
@@ -106,7 +108,14 @@ def make_textcat(
     threshold (float): Cutoff to consider a prediction "positive".
     scorer (Optional[Callable]): The scoring method.
     """
-    return TextCategorizer(nlp.vocab, model, name, threshold=threshold, scorer=scorer)
+    return TextCategorizer(
+        nlp.vocab,
+        model,
+        name,
+        threshold=threshold,
+        scorer=scorer,
+        store_activations=store_activations,
+    )


 def textcat_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
@@ -137,6 +146,7 @@ class TextCategorizer(TrainablePipe):
         *,
         threshold: float,
         scorer: Optional[Callable] = textcat_score,
+        store_activations=False,
     ) -> None:
         """Initialize a text categorizer for single-label classification.

@@ -157,6 +167,7 @@ class TextCategorizer(TrainablePipe):
         cfg = {"labels": [], "threshold": threshold, "positive_label": None}
         self.cfg = dict(cfg)
         self.scorer = scorer
+        self.store_activations = store_activations

     @property
     def support_missing_values(self):
@@ -208,6 +219,9 @@ class TextCategorizer(TrainablePipe):
         DOCS: https://spacy.io/api/textcategorizer#set_annotations
         """
         for i, doc in enumerate(docs):
+            doc.activations[self.name] = {}
+            if "probs" in self.store_activations:
+                doc.activations[self.name]["probs"] = scores[i]
             for j, label in enumerate(self.labels):
                 doc.cats[label] = float(scores[i, j])

@@ -398,3 +412,7 @@ class TextCategorizer(TrainablePipe):
         for ex in examples:
             if list(ex.reference.cats.values()).count(1.0) > 1:
                 raise ValueError(Errors.E895.format(value=ex.reference.cats))
+
+    @property
+    def activations(self):
+        return ["probs"]
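
One detail worth noting: set_annotations checks "probs" in self.store_activations, while the tests below assign both True and ["probs"] to store_activations. That only works if the store_activations setter (presumably defined on the TrainablePipe base class, which this diff does not show) normalizes a boolean into a list of activation names. A minimal sketch of such a setter, under that assumption:

```python
# Hypothetical sketch of the normalization assumed above -- not part of this
# diff. It turns True into the full list of known activations and False into
# an empty list, so that `"probs" in self.store_activations` always works.
from typing import List, Union


class StoreActivationsSketch:
    @property
    def activations(self) -> List[str]:
        # Names of activations this component can store; TextCategorizer
        # returns ["probs"] in the diff above.
        return []

    @property
    def store_activations(self) -> List[str]:
        return self._store_activations

    @store_activations.setter
    def store_activations(self, value: Union[bool, List[str]]) -> None:
        if isinstance(value, bool):
            self._store_activations = list(self.activations) if value else []
        else:
            self._store_activations = [n for n in value if n in self.activations]
```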
--- a/spacy/pipeline/textcat_multilabel.py
+++ b/spacy/pipeline/textcat_multilabel.py
@@ -1,4 +1,4 @@
-from typing import Iterable, Optional, Dict, List, Callable, Any
+from typing import Iterable, Optional, Dict, List, Callable, Any, Union
 from thinc.types import Floats2d
 from thinc.api import Model, Config

@@ -75,6 +75,7 @@ subword_features = true
         "threshold": 0.5,
         "model": DEFAULT_MULTI_TEXTCAT_MODEL,
         "scorer": {"@scorers": "spacy.textcat_multilabel_scorer.v1"},
+        "store_activations": False,
     },
     default_score_weights={
         "cats_score": 1.0,
@@ -96,6 +97,7 @@ def make_multilabel_textcat(
     model: Model[List[Doc], List[Floats2d]],
     threshold: float,
     scorer: Optional[Callable],
+    store_activations: Union[bool, List[str]],
 ) -> "TextCategorizer":
     """Create a TextCategorizer component. The text categorizer predicts categories
     over a whole document. It can learn one or more labels, and the labels are considered
@@ -107,7 +109,12 @@ def make_multilabel_textcat(
     threshold (float): Cutoff to consider a prediction "positive".
     """
     return MultiLabel_TextCategorizer(
-        nlp.vocab, model, name, threshold=threshold, scorer=scorer
+        nlp.vocab,
+        model,
+        name,
+        threshold=threshold,
+        scorer=scorer,
+        store_activations=store_activations,
     )


@@ -139,6 +146,7 @@ class MultiLabel_TextCategorizer(TextCategorizer):
         *,
         threshold: float,
         scorer: Optional[Callable] = textcat_multilabel_score,
+        store_activations=False,
     ) -> None:
         """Initialize a text categorizer for multi-label classification.

@@ -157,6 +165,7 @@ class MultiLabel_TextCategorizer(TextCategorizer):
         cfg = {"labels": [], "threshold": threshold}
         self.cfg = dict(cfg)
         self.scorer = scorer
+        self.store_activations = store_activations

     @property
     def support_missing_values(self):
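
The multilabel component gets the same treatment: since MultiLabel_TextCategorizer subclasses TextCategorizer, it picks up set_annotations and the activations property unchanged, so only the factory and constructor need the new argument. Because store_activations is typed Union[bool, List[str]], it can also be set to an explicit list of activation names, as in this sketch (same assumptions as the earlier example; not part of this commit):

```python
# Sketch (not part of this commit): store_activations also accepts a list of
# activation names, matching the Union[bool, List[str]] annotation above.
import spacy

nlp = spacy.blank("en")
nlp.add_pipe("textcat_multilabel", config={"store_activations": ["probs"]})

# After initialization/training:
# doc = nlp("This is a test.")
# doc.activations["textcat_multilabel"]["probs"]  # scores for all labels
```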
--- a/spacy/tests/pipeline/test_textcat.py
+++ b/spacy/tests/pipeline/test_textcat.py
@@ -1,3 +1,4 @@
+from typing import cast
 import random

 import numpy.random
@@ -11,7 +12,7 @@ from spacy import util
 from spacy.cli.evaluate import print_prf_per_type, print_textcats_auc_per_cat
 from spacy.lang.en import English
 from spacy.language import Language
-from spacy.pipeline import TextCategorizer
+from spacy.pipeline import TextCategorizer, TrainablePipe
 from spacy.pipeline.textcat import single_label_bow_config
 from spacy.pipeline.textcat import single_label_cnn_config
 from spacy.pipeline.textcat import single_label_default_config
@@ -871,3 +872,53 @@ def test_textcat_multi_threshold():

     scores = nlp.evaluate(train_examples, scorer_cfg={"threshold": 0})
     assert scores["cats_f_per_type"]["POSITIVE"]["r"] == 1.0
+
+
+def test_store_activations():
+    fix_random_seed(0)
+    nlp = English()
+    textcat = cast(TrainablePipe, nlp.add_pipe("textcat"))
+
+    train_examples = []
+    for text, annotations in TRAIN_DATA_SINGLE_LABEL:
+        train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
+    nlp.initialize(get_examples=lambda: train_examples)
+    nO = textcat.model.get_dim("nO")
+
+    doc = nlp("This is a test.")
+    assert len(list(doc.activations["textcat"].keys())) == 0
+
+    textcat.store_activations = True
+    doc = nlp("This is a test.")
+    assert list(doc.activations["textcat"].keys()) == ["probs"]
+    assert doc.activations["textcat"]["probs"].shape == (nO,)
+
+    textcat.store_activations = ["probs"]
+    doc = nlp("This is a test.")
+    assert list(doc.activations["textcat"].keys()) == ["probs"]
+    assert doc.activations["textcat"]["probs"].shape == (nO,)
+
+
+def test_store_activations_multi():
+    fix_random_seed(0)
+    nlp = English()
+    textcat = cast(TrainablePipe, nlp.add_pipe("textcat_multilabel"))
+
+    train_examples = []
+    for text, annotations in TRAIN_DATA_MULTI_LABEL:
+        train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
+    nlp.initialize(get_examples=lambda: train_examples)
+    nO = textcat.model.get_dim("nO")
+
+    doc = nlp("This is a test.")
+    assert len(list(doc.activations["textcat_multilabel"].keys())) == 0
+
+    textcat.store_activations = True
+    doc = nlp("This is a test.")
+    assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"]
+    assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,)
+
+    textcat.store_activations = ["probs"]
+    doc = nlp("This is a test.")
+    assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"]
+    assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,)