mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-04 06:16:33 +03:00
textcat/textcat_multilabel: add store_activations option
This commit is contained in:
parent
42526a6fa1
commit
8772b9ccc4
|
@ -1,4 +1,4 @@
|
||||||
from typing import Iterable, Tuple, Optional, Dict, List, Callable, Any
|
from typing import Iterable, Tuple, Optional, Dict, List, Callable, Any, Union
|
||||||
from thinc.api import get_array_module, Model, Optimizer, set_dropout_rate, Config
|
from thinc.api import get_array_module, Model, Optimizer, set_dropout_rate, Config
|
||||||
from thinc.types import Floats2d
|
from thinc.types import Floats2d
|
||||||
import numpy
|
import numpy
|
||||||
|
@ -75,6 +75,7 @@ subword_features = true
|
||||||
"threshold": 0.5,
|
"threshold": 0.5,
|
||||||
"model": DEFAULT_SINGLE_TEXTCAT_MODEL,
|
"model": DEFAULT_SINGLE_TEXTCAT_MODEL,
|
||||||
"scorer": {"@scorers": "spacy.textcat_scorer.v1"},
|
"scorer": {"@scorers": "spacy.textcat_scorer.v1"},
|
||||||
|
"store_activations": False,
|
||||||
},
|
},
|
||||||
default_score_weights={
|
default_score_weights={
|
||||||
"cats_score": 1.0,
|
"cats_score": 1.0,
|
||||||
|
@ -96,6 +97,7 @@ def make_textcat(
|
||||||
model: Model[List[Doc], List[Floats2d]],
|
model: Model[List[Doc], List[Floats2d]],
|
||||||
threshold: float,
|
threshold: float,
|
||||||
scorer: Optional[Callable],
|
scorer: Optional[Callable],
|
||||||
|
store_activations: Union[bool, List[str]],
|
||||||
) -> "TextCategorizer":
|
) -> "TextCategorizer":
|
||||||
"""Create a TextCategorizer component. The text categorizer predicts categories
|
"""Create a TextCategorizer component. The text categorizer predicts categories
|
||||||
over a whole document. It can learn one or more labels, and the labels are considered
|
over a whole document. It can learn one or more labels, and the labels are considered
|
||||||
|
@ -106,7 +108,14 @@ def make_textcat(
|
||||||
threshold (float): Cutoff to consider a prediction "positive".
|
threshold (float): Cutoff to consider a prediction "positive".
|
||||||
scorer (Optional[Callable]): The scoring method.
|
scorer (Optional[Callable]): The scoring method.
|
||||||
"""
|
"""
|
||||||
return TextCategorizer(nlp.vocab, model, name, threshold=threshold, scorer=scorer)
|
return TextCategorizer(
|
||||||
|
nlp.vocab,
|
||||||
|
model,
|
||||||
|
name,
|
||||||
|
threshold=threshold,
|
||||||
|
scorer=scorer,
|
||||||
|
store_activations=store_activations,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def textcat_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
|
def textcat_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
|
||||||
|
@ -137,6 +146,7 @@ class TextCategorizer(TrainablePipe):
|
||||||
*,
|
*,
|
||||||
threshold: float,
|
threshold: float,
|
||||||
scorer: Optional[Callable] = textcat_score,
|
scorer: Optional[Callable] = textcat_score,
|
||||||
|
store_activations=False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize a text categorizer for single-label classification.
|
"""Initialize a text categorizer for single-label classification.
|
||||||
|
|
||||||
|
@ -157,6 +167,7 @@ class TextCategorizer(TrainablePipe):
|
||||||
cfg = {"labels": [], "threshold": threshold, "positive_label": None}
|
cfg = {"labels": [], "threshold": threshold, "positive_label": None}
|
||||||
self.cfg = dict(cfg)
|
self.cfg = dict(cfg)
|
||||||
self.scorer = scorer
|
self.scorer = scorer
|
||||||
|
self.store_activations = store_activations
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def support_missing_values(self):
|
def support_missing_values(self):
|
||||||
|
@ -208,6 +219,9 @@ class TextCategorizer(TrainablePipe):
|
||||||
DOCS: https://spacy.io/api/textcategorizer#set_annotations
|
DOCS: https://spacy.io/api/textcategorizer#set_annotations
|
||||||
"""
|
"""
|
||||||
for i, doc in enumerate(docs):
|
for i, doc in enumerate(docs):
|
||||||
|
doc.activations[self.name] = {}
|
||||||
|
if "probs" in self.store_activations:
|
||||||
|
doc.activations[self.name]["probs"] = scores[i]
|
||||||
for j, label in enumerate(self.labels):
|
for j, label in enumerate(self.labels):
|
||||||
doc.cats[label] = float(scores[i, j])
|
doc.cats[label] = float(scores[i, j])
|
||||||
|
|
||||||
|
@ -398,3 +412,7 @@ class TextCategorizer(TrainablePipe):
|
||||||
for ex in examples:
|
for ex in examples:
|
||||||
if list(ex.reference.cats.values()).count(1.0) > 1:
|
if list(ex.reference.cats.values()).count(1.0) > 1:
|
||||||
raise ValueError(Errors.E895.format(value=ex.reference.cats))
|
raise ValueError(Errors.E895.format(value=ex.reference.cats))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def activations(self):
|
||||||
|
return ["probs"]
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Iterable, Optional, Dict, List, Callable, Any
|
from typing import Iterable, Optional, Dict, List, Callable, Any, Union
|
||||||
from thinc.types import Floats2d
|
from thinc.types import Floats2d
|
||||||
from thinc.api import Model, Config
|
from thinc.api import Model, Config
|
||||||
|
|
||||||
|
@ -75,6 +75,7 @@ subword_features = true
|
||||||
"threshold": 0.5,
|
"threshold": 0.5,
|
||||||
"model": DEFAULT_MULTI_TEXTCAT_MODEL,
|
"model": DEFAULT_MULTI_TEXTCAT_MODEL,
|
||||||
"scorer": {"@scorers": "spacy.textcat_multilabel_scorer.v1"},
|
"scorer": {"@scorers": "spacy.textcat_multilabel_scorer.v1"},
|
||||||
|
"store_activations": False,
|
||||||
},
|
},
|
||||||
default_score_weights={
|
default_score_weights={
|
||||||
"cats_score": 1.0,
|
"cats_score": 1.0,
|
||||||
|
@ -96,6 +97,7 @@ def make_multilabel_textcat(
|
||||||
model: Model[List[Doc], List[Floats2d]],
|
model: Model[List[Doc], List[Floats2d]],
|
||||||
threshold: float,
|
threshold: float,
|
||||||
scorer: Optional[Callable],
|
scorer: Optional[Callable],
|
||||||
|
store_activations: Union[bool, List[str]],
|
||||||
) -> "TextCategorizer":
|
) -> "TextCategorizer":
|
||||||
"""Create a TextCategorizer component. The text categorizer predicts categories
|
"""Create a TextCategorizer component. The text categorizer predicts categories
|
||||||
over a whole document. It can learn one or more labels, and the labels are considered
|
over a whole document. It can learn one or more labels, and the labels are considered
|
||||||
|
@ -107,7 +109,12 @@ def make_multilabel_textcat(
|
||||||
threshold (float): Cutoff to consider a prediction "positive".
|
threshold (float): Cutoff to consider a prediction "positive".
|
||||||
"""
|
"""
|
||||||
return MultiLabel_TextCategorizer(
|
return MultiLabel_TextCategorizer(
|
||||||
nlp.vocab, model, name, threshold=threshold, scorer=scorer
|
nlp.vocab,
|
||||||
|
model,
|
||||||
|
name,
|
||||||
|
threshold=threshold,
|
||||||
|
scorer=scorer,
|
||||||
|
store_activations=store_activations,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -139,6 +146,7 @@ class MultiLabel_TextCategorizer(TextCategorizer):
|
||||||
*,
|
*,
|
||||||
threshold: float,
|
threshold: float,
|
||||||
scorer: Optional[Callable] = textcat_multilabel_score,
|
scorer: Optional[Callable] = textcat_multilabel_score,
|
||||||
|
store_activations=False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize a text categorizer for multi-label classification.
|
"""Initialize a text categorizer for multi-label classification.
|
||||||
|
|
||||||
|
@ -157,6 +165,7 @@ class MultiLabel_TextCategorizer(TextCategorizer):
|
||||||
cfg = {"labels": [], "threshold": threshold}
|
cfg = {"labels": [], "threshold": threshold}
|
||||||
self.cfg = dict(cfg)
|
self.cfg = dict(cfg)
|
||||||
self.scorer = scorer
|
self.scorer = scorer
|
||||||
|
self.store_activations = store_activations
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def support_missing_values(self):
|
def support_missing_values(self):
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from typing import cast
|
||||||
import random
|
import random
|
||||||
|
|
||||||
import numpy.random
|
import numpy.random
|
||||||
|
@ -11,7 +12,7 @@ from spacy import util
|
||||||
from spacy.cli.evaluate import print_prf_per_type, print_textcats_auc_per_cat
|
from spacy.cli.evaluate import print_prf_per_type, print_textcats_auc_per_cat
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
from spacy.language import Language
|
from spacy.language import Language
|
||||||
from spacy.pipeline import TextCategorizer
|
from spacy.pipeline import TextCategorizer, TrainablePipe
|
||||||
from spacy.pipeline.textcat import single_label_bow_config
|
from spacy.pipeline.textcat import single_label_bow_config
|
||||||
from spacy.pipeline.textcat import single_label_cnn_config
|
from spacy.pipeline.textcat import single_label_cnn_config
|
||||||
from spacy.pipeline.textcat import single_label_default_config
|
from spacy.pipeline.textcat import single_label_default_config
|
||||||
|
@ -871,3 +872,53 @@ def test_textcat_multi_threshold():
|
||||||
|
|
||||||
scores = nlp.evaluate(train_examples, scorer_cfg={"threshold": 0})
|
scores = nlp.evaluate(train_examples, scorer_cfg={"threshold": 0})
|
||||||
assert scores["cats_f_per_type"]["POSITIVE"]["r"] == 1.0
|
assert scores["cats_f_per_type"]["POSITIVE"]["r"] == 1.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_store_activations():
|
||||||
|
fix_random_seed(0)
|
||||||
|
nlp = English()
|
||||||
|
textcat = cast(TrainablePipe, nlp.add_pipe("textcat"))
|
||||||
|
|
||||||
|
train_examples = []
|
||||||
|
for text, annotations in TRAIN_DATA_SINGLE_LABEL:
|
||||||
|
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
|
||||||
|
nlp.initialize(get_examples=lambda: train_examples)
|
||||||
|
nO = textcat.model.get_dim("nO")
|
||||||
|
|
||||||
|
doc = nlp("This is a test.")
|
||||||
|
assert len(list(doc.activations["textcat"].keys())) == 0
|
||||||
|
|
||||||
|
textcat.store_activations = True
|
||||||
|
doc = nlp("This is a test.")
|
||||||
|
assert list(doc.activations["textcat"].keys()) == ["probs"]
|
||||||
|
assert doc.activations["textcat"]["probs"].shape == (nO,)
|
||||||
|
|
||||||
|
textcat.store_activations = ["probs"]
|
||||||
|
doc = nlp("This is a test.")
|
||||||
|
assert list(doc.activations["textcat"].keys()) == ["probs"]
|
||||||
|
assert doc.activations["textcat"]["probs"].shape == (nO,)
|
||||||
|
|
||||||
|
|
||||||
|
def test_store_activations_multi():
|
||||||
|
fix_random_seed(0)
|
||||||
|
nlp = English()
|
||||||
|
textcat = cast(TrainablePipe, nlp.add_pipe("textcat_multilabel"))
|
||||||
|
|
||||||
|
train_examples = []
|
||||||
|
for text, annotations in TRAIN_DATA_MULTI_LABEL:
|
||||||
|
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
|
||||||
|
nlp.initialize(get_examples=lambda: train_examples)
|
||||||
|
nO = textcat.model.get_dim("nO")
|
||||||
|
|
||||||
|
doc = nlp("This is a test.")
|
||||||
|
assert len(list(doc.activations["textcat_multilabel"].keys())) == 0
|
||||||
|
|
||||||
|
textcat.store_activations = True
|
||||||
|
doc = nlp("This is a test.")
|
||||||
|
assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"]
|
||||||
|
assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,)
|
||||||
|
|
||||||
|
textcat.store_activations = ["probs"]
|
||||||
|
doc = nlp("This is a test.")
|
||||||
|
assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"]
|
||||||
|
assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user