Replace store_activations setter by set_store_activations method

Setters that take a different type than what the getter returns are still
problematic for MyPy. Replace the setter by a method, so that type inference
works everywhere.
This commit is contained in:
Daniël de Kok 2022-08-04 15:11:40 +02:00
parent 288d27e17e
commit 51f72e41ec
16 changed files with 33 additions and 35 deletions

View File

@ -97,7 +97,7 @@ class EditTreeLemmatizer(TrainablePipe):
overwrite: bool = False,
top_k: int = 1,
scorer: Optional[Callable] = lemmatizer_score,
store_activations=False,
store_activations: Union[bool, List[str]] = False,
):
"""
Construct an edit tree lemmatizer.
@ -125,7 +125,7 @@ class EditTreeLemmatizer(TrainablePipe):
self.cfg: Dict[str, Any] = {"labels": []}
self.scorer = scorer
self.store_activations = store_activations # type: ignore
self.set_store_activations(store_activations)
def get_loss(
self, examples: Iterable[Example], scores: List[Floats2d]

View File

@ -174,7 +174,7 @@ class EntityLinker(TrainablePipe):
scorer: Optional[Callable] = entity_linker_score,
use_gold_ents: bool,
threshold: Optional[float] = None,
store_activations=False,
store_activations: Union[bool, List[str]] = False,
) -> None:
"""Initialize an entity linker.
@ -223,7 +223,7 @@ class EntityLinker(TrainablePipe):
self.scorer = scorer
self.use_gold_ents = use_gold_ents
self.threshold = threshold
self.store_activations = store_activations
self.set_store_activations(store_activations)
def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
"""Define the KB of this pipe by providing a function that will

View File

@ -103,7 +103,7 @@ class Morphologizer(Tagger):
overwrite: bool = BACKWARD_OVERWRITE,
extend: bool = BACKWARD_EXTEND,
scorer: Optional[Callable] = morphologizer_score,
store_activations=False,
store_activations: Union[bool, List[str]] = False,
):
"""Initialize a morphologizer.
@ -135,7 +135,7 @@ class Morphologizer(Tagger):
}
self.cfg = dict(sorted(cfg.items()))
self.scorer = scorer
self.store_activations = store_activations
self.set_store_activations(store_activations)
@property
def labels(self):

View File

@ -82,7 +82,7 @@ class SentenceRecognizer(Tagger):
*,
overwrite=BACKWARD_OVERWRITE,
scorer=senter_score,
store_activations=False,
store_activations: Union[bool, List[str]] = False,
):
"""Initialize a sentence recognizer.
@ -103,7 +103,7 @@ class SentenceRecognizer(Tagger):
self._rehearsal_model = None
self.cfg = {"overwrite": overwrite}
self.scorer = scorer
self.store_activations = store_activations
self.set_store_activations(store_activations)
@property
def labels(self):

View File

@ -192,7 +192,7 @@ class SpanCategorizer(TrainablePipe):
threshold: float = 0.5,
max_positive: Optional[int] = None,
scorer: Optional[Callable] = spancat_score,
store_activations=False,
store_activations: Union[bool, List[str]] = False,
) -> None:
"""Initialize the span categorizer.
vocab (Vocab): The shared vocabulary.
@ -225,7 +225,7 @@ class SpanCategorizer(TrainablePipe):
self.model = model
self.name = name
self.scorer = scorer
self.store_activations = store_activations
self.set_store_activations(store_activations)
@property
def key(self) -> str:

View File

@ -97,7 +97,7 @@ class Tagger(TrainablePipe):
overwrite=BACKWARD_OVERWRITE,
scorer=tagger_score,
neg_prefix="!",
store_activations=False,
store_activations: Union[bool, List[str]] = False,
):
"""Initialize a part-of-speech tagger.
@ -119,7 +119,7 @@ class Tagger(TrainablePipe):
cfg = {"labels": [], "overwrite": overwrite, "neg_prefix": neg_prefix}
self.cfg = dict(sorted(cfg.items()))
self.scorer = scorer
self.store_activations = store_activations
self.set_store_activations(store_activations)
@property
def labels(self):

View File

@ -148,7 +148,7 @@ class TextCategorizer(TrainablePipe):
*,
threshold: float,
scorer: Optional[Callable] = textcat_score,
store_activations=False,
store_activations: Union[bool, List[str]] = False,
) -> None:
"""Initialize a text categorizer for single-label classification.
@ -169,7 +169,7 @@ class TextCategorizer(TrainablePipe):
cfg = {"labels": [], "threshold": threshold, "positive_label": None}
self.cfg = dict(cfg)
self.scorer = scorer
self.store_activations = store_activations
self.set_store_activations(store_activations)
@property
def support_missing_values(self):

View File

@ -146,7 +146,7 @@ class MultiLabel_TextCategorizer(TextCategorizer):
*,
threshold: float,
scorer: Optional[Callable] = textcat_multilabel_score,
store_activations=False,
store_activations: Union[bool, List[str]] = False,
) -> None:
"""Initialize a text categorizer for multi-label classification.
@ -167,7 +167,7 @@ class MultiLabel_TextCategorizer(TextCategorizer):
cfg = {"labels": [], "threshold": threshold}
self.cfg = dict(cfg)
self.scorer = scorer
self.store_activations = store_activations
self.set_store_activations(store_activations)
@property
def support_missing_values(self):

View File

@ -352,8 +352,7 @@ cdef class TrainablePipe(Pipe):
def store_activations(self):
return self._store_activations
@store_activations.setter
def store_activations(self, activations):
def set_store_activations(self, activations):
known_activations = self.activations
if isinstance(activations, list):
self._store_activations = []

View File

@ -295,13 +295,13 @@ def test_store_activations():
doc = nlp("This is a test.")
assert len(list(doc.activations["trainable_lemmatizer"].keys())) == 0
lemmatizer.store_activations = True
lemmatizer.set_store_activations(True)
doc = nlp("This is a test.")
assert list(doc.activations["trainable_lemmatizer"].keys()) == ["probs", "guesses"]
assert doc.activations["trainable_lemmatizer"]["probs"].shape == (5, nO)
assert doc.activations["trainable_lemmatizer"]["guesses"].shape == (5,)
lemmatizer.store_activations = ["probs"]
lemmatizer.set_store_activations(["probs"])
doc = nlp("This is a test.")
assert list(doc.activations["trainable_lemmatizer"].keys()) == ["probs"]
assert doc.activations["trainable_lemmatizer"]["probs"].shape == (5, nO)

View File

@ -1227,7 +1227,7 @@ def test_store_activations():
doc = nlp("Russ Cochran was a publisher")
assert len(doc.activations["entity_linker"].keys()) == 0
entity_linker.store_activations = True
entity_linker.set_store_activations(True)
doc = nlp("Russ Cochran was a publisher")
assert set(doc.activations["entity_linker"].keys()) == {"ents", "scores"}
ents = doc.activations["entity_linker"]["ents"]
@ -1241,7 +1241,7 @@ def test_store_activations():
assert scores.data.dtype == "float32"
assert scores.lengths.shape == (1,)
entity_linker.store_activations = ["scores"]
entity_linker.set_store_activations(["scores"])
doc = nlp("Russ Cochran was a publisher")
assert set(doc.activations["entity_linker"].keys()) == {"scores"}
scores = doc.activations["entity_linker"]["scores"]

View File

@ -213,14 +213,14 @@ def test_store_activations():
doc = nlp("This is a test.")
assert len(list(doc.activations["morphologizer"].keys())) == 0
morphologizer.store_activations = True
morphologizer.set_store_activations(True)
doc = nlp("This is a test.")
assert "morphologizer" in doc.activations
assert set(doc.activations["morphologizer"].keys()) == {"guesses", "probs"}
assert doc.activations["morphologizer"]["probs"].shape == (5, 6)
assert doc.activations["morphologizer"]["guesses"].shape == (5,)
morphologizer.store_activations = ["probs"]
morphologizer.set_store_activations(["probs"])
doc = nlp("This is a test.")
assert "morphologizer" in doc.activations
assert set(doc.activations["morphologizer"].keys()) == {"probs"}

View File

@ -120,14 +120,14 @@ def test_store_activations():
doc = nlp("This is a test.")
assert len(list(doc.activations["senter"].keys())) == 0
senter.store_activations = True
senter.set_store_activations(True)
doc = nlp("This is a test.")
assert "senter" in doc.activations
assert set(doc.activations["senter"].keys()) == {"guesses", "probs"}
assert doc.activations["senter"]["probs"].shape == (5, nO)
assert doc.activations["senter"]["guesses"].shape == (5,)
senter.store_activations = ["probs"]
senter.set_store_activations(["probs"])
doc = nlp("This is a test.")
assert "senter" in doc.activations
assert set(doc.activations["senter"].keys()) == {"probs"}

View File

@ -434,14 +434,13 @@ def test_store_activations():
doc = nlp("This is a test.")
assert len(list(doc.activations["spancat"].keys())) == 0
spancat.store_activations = True
spancat.set_store_activations(True)
doc = nlp("This is a test.")
assert set(doc.activations["spancat"].keys()) == {"indices", "scores"}
assert doc.activations["spancat"]["indices"].shape == (12, 2)
assert doc.activations["spancat"]["scores"].shape == (12, nO)
spancat.store_activations = True
spancat.store_activations = ["scores"]
spancat.set_store_activations(["scores"])
doc = nlp("This is a test.")
assert set(doc.activations["spancat"].keys()) == {"scores"}
assert doc.activations["spancat"]["scores"].shape == (12, nO)

View File

@ -225,14 +225,14 @@ def test_store_activations():
doc = nlp("This is a test.")
assert len(list(doc.activations["tagger"].keys())) == 0
tagger.store_activations = True
tagger.set_store_activations(True)
doc = nlp("This is a test.")
assert "tagger" in doc.activations
assert set(doc.activations["tagger"].keys()) == {"guesses", "probs"}
assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS))
assert doc.activations["tagger"]["guesses"].shape == (5,)
tagger.store_activations = ["probs"]
tagger.set_store_activations(["probs"])
doc = nlp("This is a test.")
assert set(doc.activations["tagger"].keys()) == {"probs"}
assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS))

View File

@ -888,12 +888,12 @@ def test_store_activations():
doc = nlp("This is a test.")
assert len(list(doc.activations["textcat"].keys())) == 0
textcat.store_activations = True
textcat.set_store_activations(True)
doc = nlp("This is a test.")
assert list(doc.activations["textcat"].keys()) == ["probs"]
assert doc.activations["textcat"]["probs"].shape == (nO,)
textcat.store_activations = ["probs"]
textcat.set_store_activations(["probs"])
doc = nlp("This is a test.")
assert list(doc.activations["textcat"].keys()) == ["probs"]
assert doc.activations["textcat"]["probs"].shape == (nO,)
@ -913,12 +913,12 @@ def test_store_activations_multi():
doc = nlp("This is a test.")
assert len(list(doc.activations["textcat_multilabel"].keys())) == 0
textcat.store_activations = True
textcat.set_store_activations(True)
doc = nlp("This is a test.")
assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"]
assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,)
textcat.store_activations = ["probs"]
textcat.set_store_activations(["probs"])
doc = nlp("This is a test.")
assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"]
assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,)