diff --git a/spacy/errors.py b/spacy/errors.py
index 14010565b..cda9d47f0 100644
--- a/spacy/errors.py
+++ b/spacy/errors.py
@@ -209,6 +209,7 @@ class Warnings(metaclass=ErrorsWithCodes):
             "Only the last span group will be loaded under "
             "Doc.spans['{group_name}']. Skipping span group with values: "
             "{group_values}")
+    W121 = ("Activation '{activation}' is unknown for pipe '{pipe_name}'")
 
 
 class Errors(metaclass=ErrorsWithCodes):
@@ -934,6 +935,7 @@ class Errors(metaclass=ErrorsWithCodes):
     E1041 = ("Expected a string, Doc, or bytes as input, but got: {type}")
     E1042 = ("Function was called with `{arg1}`={arg1_values} and "
              "`{arg2}`={arg2_values} but these arguments are conflicting.")
+    E1043 = ("store_activations attribute must be set to List[str] or bool")
 
 
 # Deprecated model shortcuts, only used in errors and warnings
diff --git a/spacy/pipeline/morphologizer.pyx b/spacy/pipeline/morphologizer.pyx
index d7a9569f1..d0c62649f 100644
--- a/spacy/pipeline/morphologizer.pyx
+++ b/spacy/pipeline/morphologizer.pyx
@@ -217,7 +217,7 @@ class Morphologizer(Tagger):
         assert len(label_sample) > 0, Errors.E923.format(name=self.name)
         self.model.initialize(X=doc_sample, Y=label_sample)
 
-    def set_annotations(self, docs, scores_guesses):
+    def set_annotations(self, docs, activations):
         """Modify a batch of documents, using pre-computed scores.
 
         docs (Iterable[Doc]): The documents to modify.
@@ -225,7 +225,7 @@ class Morphologizer(Tagger):
 
         DOCS: https://spacy.io/api/morphologizer#set_annotations
         """
-        _, batch_tag_ids = scores_guesses
+        batch_tag_ids = activations["guesses"]
         if isinstance(docs, Doc):
             docs = [docs]
         cdef Doc doc
diff --git a/spacy/pipeline/senter.pyx b/spacy/pipeline/senter.pyx
index 2caeee5c1..6f9b57196 100644
--- a/spacy/pipeline/senter.pyx
+++ b/spacy/pipeline/senter.pyx
@@ -1,5 +1,5 @@
 # cython: infer_types=True, profile=True, binding=True
-from typing import Optional, Callable
+from typing import Optional, Callable, List, Union
 from itertools import islice
 
 import srsly
@@ -46,7 +46,12 @@ DEFAULT_SENTER_MODEL = Config().from_str(default_model_config)["model"]
     },
     default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
 )
-def make_senter(nlp: Language, name: str, model: Model, overwrite: bool, scorer: Optional[Callable], store_activations: bool):
+def make_senter(nlp: Language,
+                name: str,
+                model: Model,
+                overwrite: bool,
+                scorer: Optional[Callable],
+                store_activations: Union[bool, List[str]]):
     return SentenceRecognizer(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer,
                               store_activations=store_activations)
 
@@ -114,7 +119,7 @@ class SentenceRecognizer(Tagger):
     def label_data(self):
         return None
 
-    def set_annotations(self, docs, scores_guesses):
+    def set_annotations(self, docs, activations):
         """Modify a batch of documents, using pre-computed scores.
 
         docs (Iterable[Doc]): The documents to modify.
@@ -122,17 +127,15 @@ class SentenceRecognizer(Tagger):
 
         DOCS: https://spacy.io/api/sentencerecognizer#set_annotations
         """
-        _, batch_tag_ids = scores_guesses
+        batch_tag_ids = activations["guesses"]
         if isinstance(docs, Doc):
             docs = [docs]
         cdef Doc doc
         cdef bint overwrite = self.cfg["overwrite"]
         for i, doc in enumerate(docs):
-            if self.store_activations:
-                doc.activations[self.name] = {
-                    "probs": scores_guesses[0][i],
-                    "guesses": scores_guesses[1][i],
-                }
+            doc.activations[self.name] = {}
+            for activation in self.store_activations:
+                doc.activations[self.name][activation] = activations[activation][i]
             doc_tag_ids = batch_tag_ids[i]
             if hasattr(doc_tag_ids, "get"):
                 doc_tag_ids = doc_tag_ids.get()
@@ -199,3 +202,7 @@ class SentenceRecognizer(Tagger):
 
     def add_label(self, label, values=None):
         raise NotImplementedError
+
+    @property
+    def activations(self):
+        return ["probs", "guesses"]
diff --git a/spacy/pipeline/tagger.pyx b/spacy/pipeline/tagger.pyx
index 3505cbaaf..d5374a6e3 100644
--- a/spacy/pipeline/tagger.pyx
+++ b/spacy/pipeline/tagger.pyx
@@ -1,5 +1,5 @@
 # cython: infer_types=True, profile=True, binding=True
-from typing import Callable, Optional
+from typing import Callable, List, Optional, Union
 import numpy
 import srsly
 from thinc.api import Model, set_dropout_rate, SequenceCategoricalCrossentropy, Config
@@ -61,7 +61,7 @@ def make_tagger(
     overwrite: bool,
     scorer: Optional[Callable],
     neg_prefix: str,
-    store_activations: bool,
+    store_activations: Union[bool, List[str]],
 ):
     """Construct a part-of-speech tagger component.
 
@@ -148,12 +148,12 @@ class Tagger(TrainablePipe):
             n_labels = len(self.labels)
             guesses = [self.model.ops.alloc((0, n_labels)) for doc in docs]
             assert len(guesses) == len(docs)
-            return guesses, guesses
+            return {"probs": guesses, "guesses": guesses}
         scores = self.model.predict(docs)
         assert len(scores) == len(docs), (len(scores), len(docs))
         guesses = self._scores2guesses(scores)
         assert len(guesses) == len(docs)
-        return scores, guesses
+        return {"probs": scores, "guesses": guesses}
 
     def _scores2guesses(self, scores):
         guesses = []
@@ -164,7 +164,7 @@ class Tagger(TrainablePipe):
             guesses.append(doc_guesses)
         return guesses
 
-    def set_annotations(self, docs, scores_guesses):
+    def set_annotations(self, docs, activations):
         """Modify a batch of documents, using pre-computed scores.
 
         docs (Iterable[Doc]): The documents to modify.
@@ -172,7 +172,7 @@ class Tagger(TrainablePipe):
 
         DOCS: https://spacy.io/api/tagger#set_annotations
         """
-        _, batch_tag_ids = scores_guesses
+        batch_tag_ids = activations["guesses"]
         if isinstance(docs, Doc):
             docs = [docs]
         cdef Doc doc
@@ -180,11 +180,9 @@ class Tagger(TrainablePipe):
         cdef bint overwrite = self.cfg["overwrite"]
         labels = self.labels
         for i, doc in enumerate(docs):
-            if self.store_activations:
-                doc.activations[self.name] = {
-                    "probs": scores_guesses[0][i],
-                    "guesses": scores_guesses[1][i],
-                }
+            doc.activations[self.name] = {}
+            for activation in self.store_activations:
+                doc.activations[self.name][activation] = activations[activation][i]
             doc_tag_ids = batch_tag_ids[i]
             if hasattr(doc_tag_ids, "get"):
                 doc_tag_ids = doc_tag_ids.get()
@@ -338,3 +336,7 @@ class Tagger(TrainablePipe):
         self.cfg["labels"].append(label)
         self.vocab.strings.add(label)
         return 1
+
+    @property
+    def activations(self):
+        return ["probs", "guesses"]
diff --git a/spacy/pipeline/trainable_pipe.pxd b/spacy/pipeline/trainable_pipe.pxd
index 411f0819d..40dab33d6 100644
--- a/spacy/pipeline/trainable_pipe.pxd
+++ b/spacy/pipeline/trainable_pipe.pxd
@@ -6,4 +6,4 @@ cdef class TrainablePipe(Pipe):
     cdef public object model
     cdef public object cfg
     cdef public object scorer
-    cdef public bint store_activations
+    cdef object _store_activations
diff --git a/spacy/pipeline/trainable_pipe.pyx b/spacy/pipeline/trainable_pipe.pyx
index 76b0733cf..499839c4a 100644
--- a/spacy/pipeline/trainable_pipe.pyx
+++ b/spacy/pipeline/trainable_pipe.pyx
@@ -2,11 +2,12 @@
 from typing import Iterable, Iterator, Optional, Dict, Tuple, Callable
 import srsly
 from thinc.api import set_dropout_rate, Model, Optimizer
+import warnings
 
 from ..tokens.doc cimport Doc
 from ..training import validate_examples
-from ..errors import Errors
+from ..errors import Errors, Warnings
 from .pipe import Pipe, deserialize_config
 from .. import util
 from ..vocab import Vocab
@@ -342,3 +343,29 @@ cdef class TrainablePipe(Pipe):
         deserialize["model"] = load_model
         util.from_disk(path, deserialize, exclude)
         return self
+
+    @property
+    def activations(self):
+        raise NotImplementedError(Errors.E931.format(parent="TrainablePipe", method="activations", name=self.name))
+
+    @property
+    def store_activations(self):
+        return self._store_activations
+
+    @store_activations.setter
+    def store_activations(self, activations):
+        known_activations = self.activations
+        if isinstance(activations, list):
+            self._store_activations = []
+            for activation in activations:
+                if activation in known_activations:
+                    self._store_activations.append(activation)
+                else:
+                    warnings.warn(Warnings.W121.format(activation=activation, pipe_name=self.name))
+        elif isinstance(activations, bool):
+            if activations:
+                self._store_activations = list(known_activations)
+            else:
+                self._store_activations = []
+        else:
+            raise ValueError(Errors.E1043)
diff --git a/spacy/tests/pipeline/test_senter.py b/spacy/tests/pipeline/test_senter.py
index 91ceacf00..cd194115f 100644
--- a/spacy/tests/pipeline/test_senter.py
+++ b/spacy/tests/pipeline/test_senter.py
@@ -117,9 +117,14 @@ def test_store_activations():
     nlp.initialize(get_examples=lambda: train_examples)
     senter.store_activations = True
-
     doc = nlp("This is a test.")
     assert "senter" in doc.activations
     assert set(doc.activations["senter"].keys()) == {"guesses", "probs"}
     assert doc.activations["senter"]["probs"].shape == (5, 2)
     assert doc.activations["senter"]["guesses"].shape == (5,)
+
+    senter.store_activations = ["probs"]
+    doc = nlp("This is a test.")
+    assert "senter" in doc.activations
+    assert set(doc.activations["senter"].keys()) == {"probs"}
+    assert doc.activations["senter"]["probs"].shape == (5, 2)
diff --git a/spacy/tests/pipeline/test_tagger.py b/spacy/tests/pipeline/test_tagger.py
index 6a8f75648..ac1050f48 100644
--- a/spacy/tests/pipeline/test_tagger.py
+++ b/spacy/tests/pipeline/test_tagger.py
@@ -223,14 +223,17 @@ def test_store_activations():
     nlp.initialize(get_examples=lambda: train_examples)
     tagger.store_activations = True
-
     doc = nlp("This is a test.")
-
     assert "tagger" in doc.activations
     assert set(doc.activations["tagger"].keys()) == {"guesses", "probs"}
     assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS))
     assert doc.activations["tagger"]["guesses"].shape == (5,)
 
+    tagger.store_activations = ["probs"]
+    doc = nlp("This is a test.")
+    assert set(doc.activations["tagger"].keys()) == {"probs"}
+    assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS))
+
 
 def test_tagger_requires_labels():
     nlp = English()
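
Illustrative usage sketch (not part of the patch): how the new store_activations setting behaves for a pipeline whose tagger has already been initialized and trained, mirroring the tests above. The pipeline variable nlp is assumed to exist.

    # Keep every activation the tagger exposes ("probs" and "guesses").
    tagger = nlp.get_pipe("tagger")
    tagger.store_activations = True
    doc = nlp("This is a test.")
    probs = doc.activations["tagger"]["probs"]      # per-token probability scores
    guesses = doc.activations["tagger"]["guesses"]  # predicted tag IDs

    # Keep only a subset of activations. Names not listed in Tagger.activations
    # trigger warning W121; any value other than a bool or a list of strings
    # raises E1043 in the store_activations setter.
    tagger.store_activations = ["probs"]
    doc = nlp("This is a test.")
    assert set(doc.activations["tagger"].keys()) == {"probs"}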