From aea53378dc819c572cc53974f202d93fcf2a27aa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Mon, 29 Aug 2022 16:33:08 +0200 Subject: [PATCH] Make the `TrainablePipe.store_activations` property a bool This means that we can also bring back the `store_activations` setter. --- spacy/pipeline/edit_tree_lemmatizer.py | 16 ++++++------ spacy/pipeline/entity_linker.py | 26 +++++++++---------- spacy/pipeline/morphologizer.pyx | 16 ++++++------ spacy/pipeline/senter.pyx | 16 ++++++------ spacy/pipeline/spancat.py | 14 +++++----- spacy/pipeline/tagger.pyx | 16 ++++++------ spacy/pipeline/textcat.py | 13 +++++----- spacy/pipeline/textcat_multilabel.py | 9 +++---- spacy/pipeline/trainable_pipe.pxd | 2 +- spacy/pipeline/trainable_pipe.pyx | 19 +++----------- .../pipeline/test_edit_tree_lemmatizer.py | 9 ++----- spacy/tests/pipeline/test_entity_linker.py | 13 ++-------- spacy/tests/pipeline/test_morphologizer.py | 10 ++----- spacy/tests/pipeline/test_senter.py | 10 ++----- spacy/tests/pipeline/test_spancat.py | 9 ++----- spacy/tests/pipeline/test_tagger.py | 9 ++----- spacy/tests/pipeline/test_textcat.py | 18 +++---------- 17 files changed, 81 insertions(+), 144 deletions(-) diff --git a/spacy/pipeline/edit_tree_lemmatizer.py b/spacy/pipeline/edit_tree_lemmatizer.py index 3af39b1d1..f97ff875f 100644 --- a/spacy/pipeline/edit_tree_lemmatizer.py +++ b/spacy/pipeline/edit_tree_lemmatizer.py @@ -65,7 +65,7 @@ def make_edit_tree_lemmatizer( overwrite: bool, top_k: int, scorer: Optional[Callable], - store_activations: Union[bool, List[str]], + store_activations: bool, ): """Construct an EditTreeLemmatizer component.""" return EditTreeLemmatizer( @@ -97,7 +97,7 @@ class EditTreeLemmatizer(TrainablePipe): overwrite: bool = False, top_k: int = 1, scorer: Optional[Callable] = lemmatizer_score, - store_activations: Union[bool, List[str]] = False, + store_activations: bool = False, ): """ Construct an edit tree lemmatizer. 
@@ -109,8 +109,7 @@ class EditTreeLemmatizer(TrainablePipe): frequency in the training data. overwrite (bool): overwrite existing lemma annotations. top_k (int): try to apply at most the k most probable edit trees. - store_activations (Union[bool, List[str]]): Model activations to store in - Doc when annotating. supported activations are: "probs" and "guesses". + store_activations (bool): store model activations in Doc when annotating. """ self.vocab = vocab self.model = model @@ -125,7 +124,7 @@ class EditTreeLemmatizer(TrainablePipe): self.cfg: Dict[str, Any] = {"labels": []} self.scorer = scorer - self.set_store_activations(store_activations) + self.store_activations = store_activations def get_loss( self, examples: Iterable[Example], scores: List[Floats2d] @@ -202,9 +201,10 @@ class EditTreeLemmatizer(TrainablePipe): def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT): batch_tree_ids = activations["guesses"] for i, doc in enumerate(docs): - doc.activations[self.name] = {} - for activation in self.store_activations: - doc.activations[self.name][activation] = activations[activation][i] + if self.store_activations: + doc.activations[self.name] = {} + for act_name, acts in activations.items(): + doc.activations[self.name][act_name] = acts[i] doc_tree_ids = batch_tree_ids[i] if hasattr(doc_tree_ids, "get"): doc_tree_ids = doc_tree_ids.get() diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index 844ab6185..f6aca2487 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -85,7 +85,7 @@ def make_entity_linker( scorer: Optional[Callable], use_gold_ents: bool, threshold: Optional[float] = None, - store_activations: Union[bool, List[str]], + store_activations: bool, ): """Construct an EntityLinker component. @@ -104,8 +104,7 @@ def make_entity_linker( component must provide entity annotations. threshold (Optional[float]): Confidence threshold for entity predictions. 
If confidence is below the threshold, prediction is discarded. If None, predictions are not filtered by any threshold. - store_activations (Union[bool, List[str]]): Model activations to store in - Doc when annotating. supported activations are: "ents" and "scores". + store_activations (bool): store model activations in Doc when annotating. """ if not model.attrs.get("include_span_maker", False): @@ -174,7 +173,7 @@ class EntityLinker(TrainablePipe): scorer: Optional[Callable] = entity_linker_score, use_gold_ents: bool, threshold: Optional[float] = None, - store_activations: Union[bool, List[str]] = False, + store_activations: bool = False, ) -> None: """Initialize an entity linker. @@ -223,7 +222,7 @@ class EntityLinker(TrainablePipe): self.scorer = scorer self.use_gold_ents = use_gold_ents self.threshold = threshold - self.set_store_activations(store_activations) + self.store_activations = store_activations def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]): """Define the KB of this pipe by providing a function that will @@ -551,12 +550,13 @@ class EntityLinker(TrainablePipe): i = 0 overwrite = self.cfg["overwrite"] for j, doc in enumerate(docs): - doc.activations[self.name] = {} - for activation in self.store_activations: - # We only copy activations that are Ragged. - doc.activations[self.name][activation] = cast( - Ragged, activations[activation][j] - ) + if self.store_activations: + doc.activations[self.name] = {} + for act_name, acts in activations.items(): + if act_name != "kb_ids": + # We only copy activations that are Ragged. 
+ doc.activations[self.name][act_name] = cast(Ragged, acts[j]) + for ent in doc.ents: kb_id = kb_ids[i] i += 1 @@ -668,7 +668,7 @@ class EntityLinker(TrainablePipe): doc_scores: List[Floats1d], doc_ents: List[Ints1d], ): - if len(self.store_activations) == 0: + if not self.store_activations: return ops = self.model.ops lengths = ops.asarray1i([s.shape[0] for s in doc_scores]) @@ -683,7 +683,7 @@ class EntityLinker(TrainablePipe): scores: Sequence[float], ents: Sequence[int], ): - if len(self.store_activations) == 0: + if not self.store_activations: return ops = self.model.ops doc_scores.append(ops.asarray1f(scores)) diff --git a/spacy/pipeline/morphologizer.pyx b/spacy/pipeline/morphologizer.pyx index 68bc76ad7..73b19dd0d 100644 --- a/spacy/pipeline/morphologizer.pyx +++ b/spacy/pipeline/morphologizer.pyx @@ -69,7 +69,7 @@ def make_morphologizer( overwrite: bool, extend: bool, scorer: Optional[Callable], - store_activations: Union[bool, List[str]], + store_activations: bool, ): return Morphologizer(nlp.vocab, model, name, overwrite=overwrite, extend=extend, scorer=scorer, store_activations=store_activations) @@ -104,7 +104,7 @@ class Morphologizer(Tagger): overwrite: bool = BACKWARD_OVERWRITE, extend: bool = BACKWARD_EXTEND, scorer: Optional[Callable] = morphologizer_score, - store_activations: Union[bool, List[str]] = False, + store_activations: bool = False, ): """Initialize a morphologizer. @@ -115,8 +115,7 @@ class Morphologizer(Tagger): scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_token_attr for the attributes "pos" and "morph" and Scorer.score_token_attr_per_feat for the attribute "morph". - store_activations (Union[bool, List[str]]): Model activations to store in - Doc when annotating. supported activations are: "probs" and "guesses". + store_activations (bool): store model activations in Doc when annotating. 
DOCS: https://spacy.io/api/morphologizer#init """ @@ -136,7 +135,7 @@ class Morphologizer(Tagger): } self.cfg = dict(sorted(cfg.items())) self.scorer = scorer - self.set_store_activations(store_activations) + self.store_activations = store_activations @property def labels(self): @@ -250,9 +249,10 @@ class Morphologizer(Tagger): # to allocate a compatible container out of the iterable. labels = tuple(self.labels) for i, doc in enumerate(docs): - doc.activations[self.name] = {} - for activation in self.store_activations: - doc.activations[self.name][activation] = activations[activation][i] + if self.store_activations: + doc.activations[self.name] = {} + for act_name, acts in activations.items(): + doc.activations[self.name][act_name] = acts[i] doc_tag_ids = batch_tag_ids[i] if hasattr(doc_tag_ids, "get"): doc_tag_ids = doc_tag_ids.get() diff --git a/spacy/pipeline/senter.pyx b/spacy/pipeline/senter.pyx index a80bcb0f3..499f2866d 100644 --- a/spacy/pipeline/senter.pyx +++ b/spacy/pipeline/senter.pyx @@ -52,7 +52,7 @@ def make_senter(nlp: Language, model: Model, overwrite: bool, scorer: Optional[Callable], - store_activations: Union[bool, List[str]]): + store_activations: bool): return SentenceRecognizer(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, store_activations=store_activations) @@ -83,7 +83,7 @@ class SentenceRecognizer(Tagger): *, overwrite=BACKWARD_OVERWRITE, scorer=senter_score, - store_activations: Union[bool, List[str]] = False, + store_activations: bool = False, ): """Initialize a sentence recognizer. @@ -93,8 +93,7 @@ class SentenceRecognizer(Tagger): losses during training. scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_spans for the attribute "sents". - store_activations (Union[bool, List[str]]): Model activations to store in - Doc when annotating. supported activations are: "probs" and "guesses". + store_activations (bool): store model activations in Doc when annotating. 
DOCS: https://spacy.io/api/sentencerecognizer#init """ @@ -104,7 +103,7 @@ class SentenceRecognizer(Tagger): self._rehearsal_model = None self.cfg = {"overwrite": overwrite} self.scorer = scorer - self.set_store_activations(store_activations) + self.store_activations = store_activations @property def labels(self): @@ -136,9 +135,10 @@ class SentenceRecognizer(Tagger): cdef Doc doc cdef bint overwrite = self.cfg["overwrite"] for i, doc in enumerate(docs): - doc.activations[self.name] = {} - for activation in self.store_activations: - doc.activations[self.name][activation] = activations[activation][i] + if self.store_activations: + doc.activations[self.name] = {} + for act_name, acts in activations.items(): + doc.activations[self.name][act_name] = acts[i] doc_tag_ids = batch_tag_ids[i] if hasattr(doc_tag_ids, "get"): doc_tag_ids = doc_tag_ids.get() diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py index 4484f7577..2179c3c23 100644 --- a/spacy/pipeline/spancat.py +++ b/spacy/pipeline/spancat.py @@ -120,7 +120,7 @@ def make_spancat( scorer: Optional[Callable], threshold: float, max_positive: Optional[int], - store_activations: Union[bool, List[str]], + store_activations: bool, ) -> "SpanCategorizer": """Create a SpanCategorizer component. The span categorizer consists of two parts: a suggester function that proposes candidate spans, and a labeller @@ -141,8 +141,7 @@ def make_spancat( 0.5. max_positive (Optional[int]): Maximum number of labels to consider positive per span. Defaults to None, indicating no limit. - store_activations (Union[bool, List[str]]): Model activations to store in - Doc when annotating. supported activations are: "indices" and "scores". + store_activations (bool): store model activations in Doc when annotating. 
""" return SpanCategorizer( nlp.vocab, @@ -192,7 +191,7 @@ class SpanCategorizer(TrainablePipe): threshold: float = 0.5, max_positive: Optional[int] = None, scorer: Optional[Callable] = spancat_score, - store_activations: Union[bool, List[str]] = False, + store_activations: bool = False, ) -> None: """Initialize the span categorizer. vocab (Vocab): The shared vocabulary. @@ -225,7 +224,7 @@ class SpanCategorizer(TrainablePipe): self.model = model self.name = name self.scorer = scorer - self.set_store_activations(store_activations) + self.store_activations = store_activations @property def key(self) -> str: @@ -317,10 +316,9 @@ class SpanCategorizer(TrainablePipe): offset = 0 for i, doc in enumerate(docs): indices_i = indices[i].dataXd - doc.activations[self.name] = {} - if "indices" in self.store_activations: + if self.store_activations: + doc.activations[self.name] = {} doc.activations[self.name]["indices"] = indices_i - if "scores" in self.store_activations: doc.activations[self.name]["scores"] = scores[ offset : offset + indices.lengths[i] ] diff --git a/spacy/pipeline/tagger.pyx b/spacy/pipeline/tagger.pyx index 13f80409b..13513f6de 100644 --- a/spacy/pipeline/tagger.pyx +++ b/spacy/pipeline/tagger.pyx @@ -61,7 +61,7 @@ def make_tagger( overwrite: bool, scorer: Optional[Callable], neg_prefix: str, - store_activations: Union[bool, List[str]], + store_activations: bool, ): """Construct a part-of-speech tagger component. @@ -97,7 +97,7 @@ class Tagger(TrainablePipe): overwrite=BACKWARD_OVERWRITE, scorer=tagger_score, neg_prefix="!", - store_activations: Union[bool, List[str]] = False, + store_activations: bool = False, ): """Initialize a part-of-speech tagger. @@ -107,8 +107,7 @@ class Tagger(TrainablePipe): losses during training. scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_token_attr for the attribute "tag". - store_activations (Union[bool, List[str]]): Model activations to store in - Doc when annotating. 
supported activations are: "probs" and "guesses". + store_activations (bool): store model activations in Doc when annotating. DOCS: https://spacy.io/api/tagger#init """ @@ -119,7 +118,7 @@ class Tagger(TrainablePipe): cfg = {"labels": [], "overwrite": overwrite, "neg_prefix": neg_prefix} self.cfg = dict(sorted(cfg.items())) self.scorer = scorer - self.set_store_activations(store_activations) + self.store_activations = store_activations @property def labels(self): @@ -183,9 +182,10 @@ class Tagger(TrainablePipe): cdef bint overwrite = self.cfg["overwrite"] labels = self.labels for i, doc in enumerate(docs): - doc.activations[self.name] = {} - for activation in self.store_activations: - doc.activations[self.name][activation] = activations[activation][i] + if self.store_activations: + doc.activations[self.name] = {} + for act_name, acts in activations.items(): + doc.activations[self.name][act_name] = acts[i] doc_tag_ids = batch_tag_ids[i] if hasattr(doc_tag_ids, "get"): doc_tag_ids = doc_tag_ids.get() diff --git a/spacy/pipeline/textcat.py b/spacy/pipeline/textcat.py index fc883fb68..7789c88bb 100644 --- a/spacy/pipeline/textcat.py +++ b/spacy/pipeline/textcat.py @@ -97,7 +97,7 @@ def make_textcat( model: Model[List[Doc], List[Floats2d]], threshold: float, scorer: Optional[Callable], - store_activations: Union[bool, List[str]], + store_activations: bool, ) -> "TextCategorizer": """Create a TextCategorizer component. The text categorizer predicts categories over a whole document. It can learn one or more labels, and the labels are considered @@ -107,8 +107,7 @@ def make_textcat( scores for each category. threshold (float): Cutoff to consider a prediction "positive". scorer (Optional[Callable]): The scoring method. - store_activations (Union[bool, List[str]]): Model activations to store in - Doc when annotating. supported activations is: "probs". + store_activations (bool): store model activations in Doc when annotating. 
""" return TextCategorizer( nlp.vocab, @@ -148,7 +147,7 @@ class TextCategorizer(TrainablePipe): *, threshold: float, scorer: Optional[Callable] = textcat_score, - store_activations: Union[bool, List[str]] = False, + store_activations: bool = False, ) -> None: """Initialize a text categorizer for single-label classification. @@ -169,7 +168,7 @@ class TextCategorizer(TrainablePipe): cfg = {"labels": [], "threshold": threshold, "positive_label": None} self.cfg = dict(cfg) self.scorer = scorer - self.set_store_activations(store_activations) + self.store_activations = store_activations @property def support_missing_values(self): @@ -224,8 +223,8 @@ class TextCategorizer(TrainablePipe): """ probs = activations["probs"] for i, doc in enumerate(docs): - doc.activations[self.name] = {} - if "probs" in self.store_activations: + if self.store_activations: + doc.activations[self.name] = {} doc.activations[self.name]["probs"] = probs[i] for j, label in enumerate(self.labels): doc.cats[label] = float(probs[i, j]) diff --git a/spacy/pipeline/textcat_multilabel.py b/spacy/pipeline/textcat_multilabel.py index db523d024..7ac56fba3 100644 --- a/spacy/pipeline/textcat_multilabel.py +++ b/spacy/pipeline/textcat_multilabel.py @@ -97,7 +97,7 @@ def make_multilabel_textcat( model: Model[List[Doc], List[Floats2d]], threshold: float, scorer: Optional[Callable], - store_activations: Union[bool, List[str]], + store_activations: bool, ) -> "TextCategorizer": """Create a TextCategorizer component. The text categorizer predicts categories over a whole document. It can learn one or more labels, and the labels are considered @@ -146,7 +146,7 @@ class MultiLabel_TextCategorizer(TextCategorizer): *, threshold: float, scorer: Optional[Callable] = textcat_multilabel_score, - store_activations: Union[bool, List[str]] = False, + store_activations: bool = False, ) -> None: """Initialize a text categorizer for multi-label classification. 
@@ -155,8 +155,7 @@ class MultiLabel_TextCategorizer(TextCategorizer): name (str): The component instance name, used to add entries to the losses during training. threshold (float): Cutoff to consider a prediction "positive". - store_activations (Union[bool, List[str]]): Model activations to store in - Doc when annotating. supported activations is: "probs". + store_activations (bool): store model activations in Doc when annotating. DOCS: https://spacy.io/api/textcategorizer#init """ @@ -167,7 +166,7 @@ class MultiLabel_TextCategorizer(TextCategorizer): cfg = {"labels": [], "threshold": threshold} self.cfg = dict(cfg) self.scorer = scorer - self.set_store_activations(store_activations) + self.store_activations = store_activations @property def support_missing_values(self): diff --git a/spacy/pipeline/trainable_pipe.pxd b/spacy/pipeline/trainable_pipe.pxd index 40dab33d6..6ca4a8c89 100644 --- a/spacy/pipeline/trainable_pipe.pxd +++ b/spacy/pipeline/trainable_pipe.pxd @@ -6,4 +6,4 @@ cdef class TrainablePipe(Pipe): cdef public object model cdef public object cfg cdef public object scorer - cdef object _store_activations + cdef bint _store_activations diff --git a/spacy/pipeline/trainable_pipe.pyx b/spacy/pipeline/trainable_pipe.pyx index 69fc6ca4f..2bd081b07 100644 --- a/spacy/pipeline/trainable_pipe.pyx +++ b/spacy/pipeline/trainable_pipe.pyx @@ -352,19 +352,6 @@ cdef class TrainablePipe(Pipe): def store_activations(self): return self._store_activations - def set_store_activations(self, activations): - known_activations = self.activations - if isinstance(activations, list): - self._store_activations = [] - for activation in activations: - if activation in known_activations: - self._store_activations.append(activation) - else: - warnings.warn(Warnings.W400.format(activation=activation, pipe_name=self.name)) - elif isinstance(activations, bool): - if activations: - self._store_activations = list(known_activations) - else: - self._store_activations = [] - else: - raise 
ValueError(Errors.E1400) + @store_activations.setter + def store_activations(self, store_activations: bool): + self._store_activations = store_activations diff --git a/spacy/tests/pipeline/test_edit_tree_lemmatizer.py b/spacy/tests/pipeline/test_edit_tree_lemmatizer.py index 5905b4583..9c49e6bcf 100644 --- a/spacy/tests/pipeline/test_edit_tree_lemmatizer.py +++ b/spacy/tests/pipeline/test_edit_tree_lemmatizer.py @@ -293,15 +293,10 @@ def test_store_activations(): nO = lemmatizer.model.get_dim("nO") doc = nlp("This is a test.") - assert len(list(doc.activations["trainable_lemmatizer"].keys())) == 0 + assert "trainable_lemmatizer" not in doc.activations - lemmatizer.set_store_activations(True) + lemmatizer.store_activations = True doc = nlp("This is a test.") assert list(doc.activations["trainable_lemmatizer"].keys()) == ["probs", "guesses"] assert doc.activations["trainable_lemmatizer"]["probs"].shape == (5, nO) assert doc.activations["trainable_lemmatizer"]["guesses"].shape == (5,) - - lemmatizer.set_store_activations(["probs"]) - doc = nlp("This is a test.") - assert list(doc.activations["trainable_lemmatizer"].keys()) == ["probs"] - assert doc.activations["trainable_lemmatizer"]["probs"].shape == (5, nO) diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py index 2e5ae3621..c33b213fa 100644 --- a/spacy/tests/pipeline/test_entity_linker.py +++ b/spacy/tests/pipeline/test_entity_linker.py @@ -1225,9 +1225,9 @@ def test_store_activations(): ruler.add_patterns(patterns) doc = nlp("Russ Cochran was a publisher") - assert len(doc.activations["entity_linker"].keys()) == 0 + assert "entity_linker" not in doc.activations - entity_linker.set_store_activations(True) + entity_linker.store_activations = True doc = nlp("Russ Cochran was a publisher") assert set(doc.activations["entity_linker"].keys()) == {"ents", "scores"} ents = doc.activations["entity_linker"]["ents"] @@ -1240,12 +1240,3 @@ def test_store_activations(): assert 
scores.data.shape == (2, 1) assert scores.data.dtype == "float32" assert scores.lengths.shape == (1,) - - entity_linker.set_store_activations(["scores"]) - doc = nlp("Russ Cochran was a publisher") - assert set(doc.activations["entity_linker"].keys()) == {"scores"} - scores = doc.activations["entity_linker"]["scores"] - assert isinstance(scores, Ragged) - assert scores.data.shape == (2, 1) - assert scores.data.dtype == "float32" - assert scores.lengths.shape == (1,) diff --git a/spacy/tests/pipeline/test_morphologizer.py b/spacy/tests/pipeline/test_morphologizer.py index a5907d8cf..a92db4fdf 100644 --- a/spacy/tests/pipeline/test_morphologizer.py +++ b/spacy/tests/pipeline/test_morphologizer.py @@ -211,17 +211,11 @@ def test_store_activations(): nlp.initialize(get_examples=lambda: train_examples) doc = nlp("This is a test.") - assert len(list(doc.activations["morphologizer"].keys())) == 0 + assert "morphologizer" not in doc.activations - morphologizer.set_store_activations(True) + morphologizer.store_activations = True doc = nlp("This is a test.") assert "morphologizer" in doc.activations assert set(doc.activations["morphologizer"].keys()) == {"guesses", "probs"} assert doc.activations["morphologizer"]["probs"].shape == (5, 6) assert doc.activations["morphologizer"]["guesses"].shape == (5,) - - morphologizer.set_store_activations(["probs"]) - doc = nlp("This is a test.") - assert "morphologizer" in doc.activations - assert set(doc.activations["morphologizer"].keys()) == {"probs"} - assert doc.activations["morphologizer"]["probs"].shape == (5, 6) diff --git a/spacy/tests/pipeline/test_senter.py b/spacy/tests/pipeline/test_senter.py index 2a9fc77f3..34b4e60f9 100644 --- a/spacy/tests/pipeline/test_senter.py +++ b/spacy/tests/pipeline/test_senter.py @@ -118,17 +118,11 @@ def test_store_activations(): nO = senter.model.get_dim("nO") doc = nlp("This is a test.") - assert len(list(doc.activations["senter"].keys())) == 0 + assert "senter" not in doc.activations - 
senter.set_store_activations(True) + senter.store_activations = True doc = nlp("This is a test.") assert "senter" in doc.activations assert set(doc.activations["senter"].keys()) == {"guesses", "probs"} assert doc.activations["senter"]["probs"].shape == (5, nO) assert doc.activations["senter"]["guesses"].shape == (5,) - - senter.set_store_activations(["probs"]) - doc = nlp("This is a test.") - assert "senter" in doc.activations - assert set(doc.activations["senter"].keys()) == {"probs"} - assert doc.activations["senter"]["probs"].shape == (5, 2) diff --git a/spacy/tests/pipeline/test_spancat.py b/spacy/tests/pipeline/test_spancat.py index e34a6e69f..0fab5a9c4 100644 --- a/spacy/tests/pipeline/test_spancat.py +++ b/spacy/tests/pipeline/test_spancat.py @@ -432,15 +432,10 @@ def test_store_activations(): assert set(spancat.labels) == {"LOC", "PERSON"} doc = nlp("This is a test.") - assert len(list(doc.activations["spancat"].keys())) == 0 + assert "spancat" not in doc.activations - spancat.set_store_activations(True) + spancat.store_activations = True doc = nlp("This is a test.") assert set(doc.activations["spancat"].keys()) == {"indices", "scores"} assert doc.activations["spancat"]["indices"].shape == (12, 2) assert doc.activations["spancat"]["scores"].shape == (12, nO) - - spancat.set_store_activations(["scores"]) - doc = nlp("This is a test.") - assert set(doc.activations["spancat"].keys()) == {"scores"} - assert doc.activations["spancat"]["scores"].shape == (12, nO) diff --git a/spacy/tests/pipeline/test_tagger.py b/spacy/tests/pipeline/test_tagger.py index fca8bc262..fa698eac6 100644 --- a/spacy/tests/pipeline/test_tagger.py +++ b/spacy/tests/pipeline/test_tagger.py @@ -223,20 +223,15 @@ def test_store_activations(): nlp.initialize(get_examples=lambda: train_examples) doc = nlp("This is a test.") - assert len(list(doc.activations["tagger"].keys())) == 0 + assert "tagger" not in doc.activations - tagger.set_store_activations(True) + tagger.store_activations = True 
doc = nlp("This is a test.") assert "tagger" in doc.activations assert set(doc.activations["tagger"].keys()) == {"guesses", "probs"} assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS)) assert doc.activations["tagger"]["guesses"].shape == (5,) - tagger.set_store_activations(["probs"]) - doc = nlp("This is a test.") - assert set(doc.activations["tagger"].keys()) == {"probs"} - assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS)) - def test_tagger_requires_labels(): nlp = English() diff --git a/spacy/tests/pipeline/test_textcat.py b/spacy/tests/pipeline/test_textcat.py index be94e18c8..97edd7a6c 100644 --- a/spacy/tests/pipeline/test_textcat.py +++ b/spacy/tests/pipeline/test_textcat.py @@ -886,14 +886,9 @@ def test_store_activations(): nO = textcat.model.get_dim("nO") doc = nlp("This is a test.") - assert len(list(doc.activations["textcat"].keys())) == 0 + assert "textcat" not in doc.activations - textcat.set_store_activations(True) - doc = nlp("This is a test.") - assert list(doc.activations["textcat"].keys()) == ["probs"] - assert doc.activations["textcat"]["probs"].shape == (nO,) - - textcat.set_store_activations(["probs"]) + textcat.store_activations = True doc = nlp("This is a test.") assert list(doc.activations["textcat"].keys()) == ["probs"] assert doc.activations["textcat"]["probs"].shape == (nO,) @@ -911,14 +906,9 @@ def test_store_activations_multi(): nO = textcat.model.get_dim("nO") doc = nlp("This is a test.") - assert len(list(doc.activations["textcat_multilabel"].keys())) == 0 + assert "textcat_multilabel" not in doc.activations - textcat.set_store_activations(True) - doc = nlp("This is a test.") - assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"] - assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,) - - textcat.set_store_activations(["probs"]) + textcat.store_activations = True doc = nlp("This is a test.") assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"] assert 
doc.activations["textcat_multilabel"]["probs"].shape == (nO,)