From 51f72e41ecffec723c7ea1a8deb30cd5bb732c1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Thu, 4 Aug 2022 15:11:40 +0200 Subject: [PATCH] Replace store_activations setter by set_store_activations method Setters that take a different type than what the getter returns are still problematic for MyPy. Replace the setter by a method, so that type inference works everywhere. --- spacy/pipeline/edit_tree_lemmatizer.py | 4 ++-- spacy/pipeline/entity_linker.py | 4 ++-- spacy/pipeline/morphologizer.pyx | 4 ++-- spacy/pipeline/senter.pyx | 4 ++-- spacy/pipeline/spancat.py | 4 ++-- spacy/pipeline/tagger.pyx | 4 ++-- spacy/pipeline/textcat.py | 4 ++-- spacy/pipeline/textcat_multilabel.py | 4 ++-- spacy/pipeline/trainable_pipe.pyx | 3 +-- spacy/tests/pipeline/test_edit_tree_lemmatizer.py | 4 ++-- spacy/tests/pipeline/test_entity_linker.py | 4 ++-- spacy/tests/pipeline/test_morphologizer.py | 4 ++-- spacy/tests/pipeline/test_senter.py | 4 ++-- spacy/tests/pipeline/test_spancat.py | 5 ++--- spacy/tests/pipeline/test_tagger.py | 4 ++-- spacy/tests/pipeline/test_textcat.py | 8 ++++---- 16 files changed, 33 insertions(+), 35 deletions(-) diff --git a/spacy/pipeline/edit_tree_lemmatizer.py b/spacy/pipeline/edit_tree_lemmatizer.py index 8d7d7a1d0..3af39b1d1 100644 --- a/spacy/pipeline/edit_tree_lemmatizer.py +++ b/spacy/pipeline/edit_tree_lemmatizer.py @@ -97,7 +97,7 @@ class EditTreeLemmatizer(TrainablePipe): overwrite: bool = False, top_k: int = 1, scorer: Optional[Callable] = lemmatizer_score, - store_activations=False, + store_activations: Union[bool, List[str]] = False, ): """ Construct an edit tree lemmatizer. @@ -125,7 +125,7 @@ class EditTreeLemmatizer(TrainablePipe): self.cfg: Dict[str, Any] = {"labels": []} self.scorer = scorer - self.store_activations = store_activations # type: ignore + self.set_store_activations(store_activations) def get_loss( self, examples: Iterable[Example], scores: List[Floats2d] diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index 4350bd5aa..266c1f07f 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -174,7 +174,7 @@ class EntityLinker(TrainablePipe): scorer: Optional[Callable] = entity_linker_score, use_gold_ents: bool, threshold: Optional[float] = None, - store_activations=False, + store_activations: Union[bool, List[str]] = False, ) -> None: """Initialize an entity linker. @@ -223,7 +223,7 @@ class EntityLinker(TrainablePipe): self.scorer = scorer self.use_gold_ents = use_gold_ents self.threshold = threshold - self.store_activations = store_activations + self.set_store_activations(store_activations) def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]): """Define the KB of this pipe by providing a function that will diff --git a/spacy/pipeline/morphologizer.pyx b/spacy/pipeline/morphologizer.pyx index 75a3ff3d6..0c7eacd12 100644 --- a/spacy/pipeline/morphologizer.pyx +++ b/spacy/pipeline/morphologizer.pyx @@ -103,7 +103,7 @@ class Morphologizer(Tagger): overwrite: bool = BACKWARD_OVERWRITE, extend: bool = BACKWARD_EXTEND, scorer: Optional[Callable] = morphologizer_score, - store_activations=False, + store_activations: Union[bool, List[str]] = False, ): """Initialize a morphologizer. @@ -135,7 +135,7 @@ class Morphologizer(Tagger): } self.cfg = dict(sorted(cfg.items())) self.scorer = scorer - self.store_activations = store_activations + self.set_store_activations(store_activations) @property def labels(self): diff --git a/spacy/pipeline/senter.pyx b/spacy/pipeline/senter.pyx index 53664d7ac..1cfd6c4b1 100644 --- a/spacy/pipeline/senter.pyx +++ b/spacy/pipeline/senter.pyx @@ -82,7 +82,7 @@ class SentenceRecognizer(Tagger): *, overwrite=BACKWARD_OVERWRITE, scorer=senter_score, - store_activations=False, + store_activations: Union[bool, List[str]] = False, ): """Initialize a sentence recognizer. @@ -103,7 +103,7 @@ class SentenceRecognizer(Tagger): self._rehearsal_model = None self.cfg = {"overwrite": overwrite} self.scorer = scorer - self.store_activations = store_activations + self.set_store_activations(store_activations) @property def labels(self): diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py index de8375303..ccd49bbac 100644 --- a/spacy/pipeline/spancat.py +++ b/spacy/pipeline/spancat.py @@ -192,7 +192,7 @@ class SpanCategorizer(TrainablePipe): threshold: float = 0.5, max_positive: Optional[int] = None, scorer: Optional[Callable] = spancat_score, - store_activations=False, + store_activations: Union[bool, List[str]] = False, ) -> None: """Initialize the span categorizer. vocab (Vocab): The shared vocabulary. @@ -225,7 +225,7 @@ class SpanCategorizer(TrainablePipe): self.model = model self.name = name self.scorer = scorer - self.store_activations = store_activations + self.set_store_activations(store_activations) @property def key(self) -> str: diff --git a/spacy/pipeline/tagger.pyx b/spacy/pipeline/tagger.pyx index 3430070da..498b3de08 100644 --- a/spacy/pipeline/tagger.pyx +++ b/spacy/pipeline/tagger.pyx @@ -97,7 +97,7 @@ class Tagger(TrainablePipe): overwrite=BACKWARD_OVERWRITE, scorer=tagger_score, neg_prefix="!", - store_activations=False, + store_activations: Union[bool, List[str]] = False, ): """Initialize a part-of-speech tagger. @@ -119,7 +119,7 @@ class Tagger(TrainablePipe): cfg = {"labels": [], "overwrite": overwrite, "neg_prefix": neg_prefix} self.cfg = dict(sorted(cfg.items())) self.scorer = scorer - self.store_activations = store_activations + self.set_store_activations(store_activations) @property def labels(self): diff --git a/spacy/pipeline/textcat.py b/spacy/pipeline/textcat.py index bd35cfad5..1ca112060 100644 --- a/spacy/pipeline/textcat.py +++ b/spacy/pipeline/textcat.py @@ -148,7 +148,7 @@ class TextCategorizer(TrainablePipe): *, threshold: float, scorer: Optional[Callable] = textcat_score, - store_activations=False, + store_activations: Union[bool, List[str]] = False, ) -> None: """Initialize a text categorizer for single-label classification. @@ -169,7 +169,7 @@ class TextCategorizer(TrainablePipe): cfg = {"labels": [], "threshold": threshold, "positive_label": None} self.cfg = dict(cfg) self.scorer = scorer - self.store_activations = store_activations + self.set_store_activations(store_activations) @property def support_missing_values(self): diff --git a/spacy/pipeline/textcat_multilabel.py b/spacy/pipeline/textcat_multilabel.py index 620a12f02..db523d024 100644 --- a/spacy/pipeline/textcat_multilabel.py +++ b/spacy/pipeline/textcat_multilabel.py @@ -146,7 +146,7 @@ class MultiLabel_TextCategorizer(TextCategorizer): *, threshold: float, scorer: Optional[Callable] = textcat_multilabel_score, - store_activations=False, + store_activations: Union[bool, List[str]] = False, ) -> None: """Initialize a text categorizer for multi-label classification. @@ -167,7 +167,7 @@ class MultiLabel_TextCategorizer(TextCategorizer): cfg = {"labels": [], "threshold": threshold} self.cfg = dict(cfg) self.scorer = scorer - self.store_activations = store_activations + self.set_store_activations(store_activations) @property def support_missing_values(self): diff --git a/spacy/pipeline/trainable_pipe.pyx b/spacy/pipeline/trainable_pipe.pyx index 23d65f0c7..69fc6ca4f 100644 --- a/spacy/pipeline/trainable_pipe.pyx +++ b/spacy/pipeline/trainable_pipe.pyx @@ -352,8 +352,7 @@ cdef class TrainablePipe(Pipe): def store_activations(self): return self._store_activations - @store_activations.setter - def store_activations(self, activations): + def set_store_activations(self, activations): known_activations = self.activations if isinstance(activations, list): self._store_activations = [] diff --git a/spacy/tests/pipeline/test_edit_tree_lemmatizer.py b/spacy/tests/pipeline/test_edit_tree_lemmatizer.py index 8af56756e..5905b4583 100644 --- a/spacy/tests/pipeline/test_edit_tree_lemmatizer.py +++ b/spacy/tests/pipeline/test_edit_tree_lemmatizer.py @@ -295,13 +295,13 @@ def test_store_activations(): doc = nlp("This is a test.") assert len(list(doc.activations["trainable_lemmatizer"].keys())) == 0 - lemmatizer.store_activations = True + lemmatizer.set_store_activations(True) doc = nlp("This is a test.") assert list(doc.activations["trainable_lemmatizer"].keys()) == ["probs", "guesses"] assert doc.activations["trainable_lemmatizer"]["probs"].shape == (5, nO) assert doc.activations["trainable_lemmatizer"]["guesses"].shape == (5,) - lemmatizer.store_activations = ["probs"] + lemmatizer.set_store_activations(["probs"]) doc = nlp("This is a test.") assert list(doc.activations["trainable_lemmatizer"].keys()) == ["probs"] assert doc.activations["trainable_lemmatizer"]["probs"].shape == (5, nO) diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py index 82a9750dd..2e5ae3621 100644 --- a/spacy/tests/pipeline/test_entity_linker.py +++ b/spacy/tests/pipeline/test_entity_linker.py @@ -1227,7 +1227,7 @@ def test_store_activations(): doc = nlp("Russ Cochran was a publisher") assert len(doc.activations["entity_linker"].keys()) == 0 - entity_linker.store_activations = True + entity_linker.set_store_activations(True) doc = nlp("Russ Cochran was a publisher") assert set(doc.activations["entity_linker"].keys()) == {"ents", "scores"} ents = doc.activations["entity_linker"]["ents"] @@ -1241,7 +1241,7 @@ def test_store_activations(): assert scores.data.dtype == "float32" assert scores.lengths.shape == (1,) - entity_linker.store_activations = ["scores"] + entity_linker.set_store_activations(["scores"]) doc = nlp("Russ Cochran was a publisher") assert set(doc.activations["entity_linker"].keys()) == {"scores"} scores = doc.activations["entity_linker"]["scores"] diff --git a/spacy/tests/pipeline/test_morphologizer.py b/spacy/tests/pipeline/test_morphologizer.py index fc2f18730..a5907d8cf 100644 --- a/spacy/tests/pipeline/test_morphologizer.py +++ b/spacy/tests/pipeline/test_morphologizer.py @@ -213,14 +213,14 @@ def test_store_activations(): doc = nlp("This is a test.") assert len(list(doc.activations["morphologizer"].keys())) == 0 - morphologizer.store_activations = True + morphologizer.set_store_activations(True) doc = nlp("This is a test.") assert "morphologizer" in doc.activations assert set(doc.activations["morphologizer"].keys()) == {"guesses", "probs"} assert doc.activations["morphologizer"]["probs"].shape == (5, 6) assert doc.activations["morphologizer"]["guesses"].shape == (5,) - morphologizer.store_activations = ["probs"] + morphologizer.set_store_activations(["probs"]) doc = nlp("This is a test.") assert "morphologizer" in doc.activations assert set(doc.activations["morphologizer"].keys()) == {"probs"} diff --git a/spacy/tests/pipeline/test_senter.py b/spacy/tests/pipeline/test_senter.py index db8d220bd..2a9fc77f3 100644 --- a/spacy/tests/pipeline/test_senter.py +++ b/spacy/tests/pipeline/test_senter.py @@ -120,14 +120,14 @@ def test_store_activations(): doc = nlp("This is a test.") assert len(list(doc.activations["senter"].keys())) == 0 - senter.store_activations = True + senter.set_store_activations(True) doc = nlp("This is a test.") assert "senter" in doc.activations assert set(doc.activations["senter"].keys()) == {"guesses", "probs"} assert doc.activations["senter"]["probs"].shape == (5, nO) assert doc.activations["senter"]["guesses"].shape == (5,) - senter.store_activations = ["probs"] + senter.set_store_activations(["probs"]) doc = nlp("This is a test.") assert "senter" in doc.activations assert set(doc.activations["senter"].keys()) == {"probs"} diff --git a/spacy/tests/pipeline/test_spancat.py b/spacy/tests/pipeline/test_spancat.py index d88de1d09..e34a6e69f 100644 --- a/spacy/tests/pipeline/test_spancat.py +++ b/spacy/tests/pipeline/test_spancat.py @@ -434,14 +434,13 @@ def test_store_activations(): doc = nlp("This is a test.") assert len(list(doc.activations["spancat"].keys())) == 0 - spancat.store_activations = True + spancat.set_store_activations(True) doc = nlp("This is a test.") assert set(doc.activations["spancat"].keys()) == {"indices", "scores"} assert doc.activations["spancat"]["indices"].shape == (12, 2) assert doc.activations["spancat"]["scores"].shape == (12, nO) - spancat.store_activations = True - spancat.store_activations = ["scores"] + spancat.set_store_activations(["scores"]) doc = nlp("This is a test.") assert set(doc.activations["spancat"].keys()) == {"scores"} assert doc.activations["spancat"]["scores"].shape == (12, nO) diff --git a/spacy/tests/pipeline/test_tagger.py b/spacy/tests/pipeline/test_tagger.py index e1aef969a..fca8bc262 100644 --- a/spacy/tests/pipeline/test_tagger.py +++ b/spacy/tests/pipeline/test_tagger.py @@ -225,14 +225,14 @@ def test_store_activations(): doc = nlp("This is a test.") assert len(list(doc.activations["tagger"].keys())) == 0 - tagger.store_activations = True + tagger.set_store_activations(True) doc = nlp("This is a test.") assert "tagger" in doc.activations assert set(doc.activations["tagger"].keys()) == {"guesses", "probs"} assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS)) assert doc.activations["tagger"]["guesses"].shape == (5,) - tagger.store_activations = ["probs"] + tagger.set_store_activations(["probs"]) doc = nlp("This is a test.") assert set(doc.activations["tagger"].keys()) == {"probs"} assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS)) diff --git a/spacy/tests/pipeline/test_textcat.py b/spacy/tests/pipeline/test_textcat.py index 79d0945a5..5024c1fd5 100644 --- a/spacy/tests/pipeline/test_textcat.py +++ b/spacy/tests/pipeline/test_textcat.py @@ -888,12 +888,12 @@ def test_store_activations(): doc = nlp("This is a test.") assert len(list(doc.activations["textcat"].keys())) == 0 - textcat.store_activations = True + textcat.set_store_activations(True) doc = nlp("This is a test.") assert list(doc.activations["textcat"].keys()) == ["probs"] assert doc.activations["textcat"]["probs"].shape == (nO,) - textcat.store_activations = ["probs"] + textcat.set_store_activations(["probs"]) doc = nlp("This is a test.") assert list(doc.activations["textcat"].keys()) == ["probs"] assert doc.activations["textcat"]["probs"].shape == (nO,) @@ -913,12 +913,12 @@ def test_store_activations_multi(): doc = nlp("This is a test.") assert len(list(doc.activations["textcat_multilabel"].keys())) == 0 - textcat.store_activations = True + textcat.set_store_activations(True) doc = nlp("This is a test.") assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"] assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,) - textcat.store_activations = ["probs"] + textcat.set_store_activations(["probs"]) doc = nlp("This is a test.") assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"] assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,)