mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-29 11:26:28 +03:00
Replace store_activations setter by set_store_activations method
Setters that take a different type than what the getter returns are still problematic for MyPy. Replace the setter by a method, so that type inference works everywhere.
This commit is contained in:
parent
288d27e17e
commit
51f72e41ec
|
@ -97,7 +97,7 @@ class EditTreeLemmatizer(TrainablePipe):
|
|||
overwrite: bool = False,
|
||||
top_k: int = 1,
|
||||
scorer: Optional[Callable] = lemmatizer_score,
|
||||
store_activations=False,
|
||||
store_activations: Union[bool, List[str]] = False,
|
||||
):
|
||||
"""
|
||||
Construct an edit tree lemmatizer.
|
||||
|
@ -125,7 +125,7 @@ class EditTreeLemmatizer(TrainablePipe):
|
|||
|
||||
self.cfg: Dict[str, Any] = {"labels": []}
|
||||
self.scorer = scorer
|
||||
self.store_activations = store_activations # type: ignore
|
||||
self.set_store_activations(store_activations)
|
||||
|
||||
def get_loss(
|
||||
self, examples: Iterable[Example], scores: List[Floats2d]
|
||||
|
|
|
@ -174,7 +174,7 @@ class EntityLinker(TrainablePipe):
|
|||
scorer: Optional[Callable] = entity_linker_score,
|
||||
use_gold_ents: bool,
|
||||
threshold: Optional[float] = None,
|
||||
store_activations=False,
|
||||
store_activations: Union[bool, List[str]] = False,
|
||||
) -> None:
|
||||
"""Initialize an entity linker.
|
||||
|
||||
|
@ -223,7 +223,7 @@ class EntityLinker(TrainablePipe):
|
|||
self.scorer = scorer
|
||||
self.use_gold_ents = use_gold_ents
|
||||
self.threshold = threshold
|
||||
self.store_activations = store_activations
|
||||
self.set_store_activations(store_activations)
|
||||
|
||||
def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
|
||||
"""Define the KB of this pipe by providing a function that will
|
||||
|
|
|
@ -103,7 +103,7 @@ class Morphologizer(Tagger):
|
|||
overwrite: bool = BACKWARD_OVERWRITE,
|
||||
extend: bool = BACKWARD_EXTEND,
|
||||
scorer: Optional[Callable] = morphologizer_score,
|
||||
store_activations=False,
|
||||
store_activations: Union[bool, List[str]] = False,
|
||||
):
|
||||
"""Initialize a morphologizer.
|
||||
|
||||
|
@ -135,7 +135,7 @@ class Morphologizer(Tagger):
|
|||
}
|
||||
self.cfg = dict(sorted(cfg.items()))
|
||||
self.scorer = scorer
|
||||
self.store_activations = store_activations
|
||||
self.set_store_activations(store_activations)
|
||||
|
||||
@property
|
||||
def labels(self):
|
||||
|
|
|
@ -82,7 +82,7 @@ class SentenceRecognizer(Tagger):
|
|||
*,
|
||||
overwrite=BACKWARD_OVERWRITE,
|
||||
scorer=senter_score,
|
||||
store_activations=False,
|
||||
store_activations: Union[bool, List[str]] = False,
|
||||
):
|
||||
"""Initialize a sentence recognizer.
|
||||
|
||||
|
@ -103,7 +103,7 @@ class SentenceRecognizer(Tagger):
|
|||
self._rehearsal_model = None
|
||||
self.cfg = {"overwrite": overwrite}
|
||||
self.scorer = scorer
|
||||
self.store_activations = store_activations
|
||||
self.set_store_activations(store_activations)
|
||||
|
||||
@property
|
||||
def labels(self):
|
||||
|
|
|
@ -192,7 +192,7 @@ class SpanCategorizer(TrainablePipe):
|
|||
threshold: float = 0.5,
|
||||
max_positive: Optional[int] = None,
|
||||
scorer: Optional[Callable] = spancat_score,
|
||||
store_activations=False,
|
||||
store_activations: Union[bool, List[str]] = False,
|
||||
) -> None:
|
||||
"""Initialize the span categorizer.
|
||||
vocab (Vocab): The shared vocabulary.
|
||||
|
@ -225,7 +225,7 @@ class SpanCategorizer(TrainablePipe):
|
|||
self.model = model
|
||||
self.name = name
|
||||
self.scorer = scorer
|
||||
self.store_activations = store_activations
|
||||
self.set_store_activations(store_activations)
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
|
|
|
@ -97,7 +97,7 @@ class Tagger(TrainablePipe):
|
|||
overwrite=BACKWARD_OVERWRITE,
|
||||
scorer=tagger_score,
|
||||
neg_prefix="!",
|
||||
store_activations=False,
|
||||
store_activations: Union[bool, List[str]] = False,
|
||||
):
|
||||
"""Initialize a part-of-speech tagger.
|
||||
|
||||
|
@ -119,7 +119,7 @@ class Tagger(TrainablePipe):
|
|||
cfg = {"labels": [], "overwrite": overwrite, "neg_prefix": neg_prefix}
|
||||
self.cfg = dict(sorted(cfg.items()))
|
||||
self.scorer = scorer
|
||||
self.store_activations = store_activations
|
||||
self.set_store_activations(store_activations)
|
||||
|
||||
@property
|
||||
def labels(self):
|
||||
|
|
|
@ -148,7 +148,7 @@ class TextCategorizer(TrainablePipe):
|
|||
*,
|
||||
threshold: float,
|
||||
scorer: Optional[Callable] = textcat_score,
|
||||
store_activations=False,
|
||||
store_activations: Union[bool, List[str]] = False,
|
||||
) -> None:
|
||||
"""Initialize a text categorizer for single-label classification.
|
||||
|
||||
|
@ -169,7 +169,7 @@ class TextCategorizer(TrainablePipe):
|
|||
cfg = {"labels": [], "threshold": threshold, "positive_label": None}
|
||||
self.cfg = dict(cfg)
|
||||
self.scorer = scorer
|
||||
self.store_activations = store_activations
|
||||
self.set_store_activations(store_activations)
|
||||
|
||||
@property
|
||||
def support_missing_values(self):
|
||||
|
|
|
@ -146,7 +146,7 @@ class MultiLabel_TextCategorizer(TextCategorizer):
|
|||
*,
|
||||
threshold: float,
|
||||
scorer: Optional[Callable] = textcat_multilabel_score,
|
||||
store_activations=False,
|
||||
store_activations: Union[bool, List[str]] = False,
|
||||
) -> None:
|
||||
"""Initialize a text categorizer for multi-label classification.
|
||||
|
||||
|
@ -167,7 +167,7 @@ class MultiLabel_TextCategorizer(TextCategorizer):
|
|||
cfg = {"labels": [], "threshold": threshold}
|
||||
self.cfg = dict(cfg)
|
||||
self.scorer = scorer
|
||||
self.store_activations = store_activations
|
||||
self.set_store_activations(store_activations)
|
||||
|
||||
@property
|
||||
def support_missing_values(self):
|
||||
|
|
|
@ -352,8 +352,7 @@ cdef class TrainablePipe(Pipe):
|
|||
def store_activations(self):
|
||||
return self._store_activations
|
||||
|
||||
@store_activations.setter
|
||||
def store_activations(self, activations):
|
||||
def set_store_activations(self, activations):
|
||||
known_activations = self.activations
|
||||
if isinstance(activations, list):
|
||||
self._store_activations = []
|
||||
|
|
|
@ -295,13 +295,13 @@ def test_store_activations():
|
|||
doc = nlp("This is a test.")
|
||||
assert len(list(doc.activations["trainable_lemmatizer"].keys())) == 0
|
||||
|
||||
lemmatizer.store_activations = True
|
||||
lemmatizer.set_store_activations(True)
|
||||
doc = nlp("This is a test.")
|
||||
assert list(doc.activations["trainable_lemmatizer"].keys()) == ["probs", "guesses"]
|
||||
assert doc.activations["trainable_lemmatizer"]["probs"].shape == (5, nO)
|
||||
assert doc.activations["trainable_lemmatizer"]["guesses"].shape == (5,)
|
||||
|
||||
lemmatizer.store_activations = ["probs"]
|
||||
lemmatizer.set_store_activations(["probs"])
|
||||
doc = nlp("This is a test.")
|
||||
assert list(doc.activations["trainable_lemmatizer"].keys()) == ["probs"]
|
||||
assert doc.activations["trainable_lemmatizer"]["probs"].shape == (5, nO)
|
||||
|
|
|
@ -1227,7 +1227,7 @@ def test_store_activations():
|
|||
doc = nlp("Russ Cochran was a publisher")
|
||||
assert len(doc.activations["entity_linker"].keys()) == 0
|
||||
|
||||
entity_linker.store_activations = True
|
||||
entity_linker.set_store_activations(True)
|
||||
doc = nlp("Russ Cochran was a publisher")
|
||||
assert set(doc.activations["entity_linker"].keys()) == {"ents", "scores"}
|
||||
ents = doc.activations["entity_linker"]["ents"]
|
||||
|
@ -1241,7 +1241,7 @@ def test_store_activations():
|
|||
assert scores.data.dtype == "float32"
|
||||
assert scores.lengths.shape == (1,)
|
||||
|
||||
entity_linker.store_activations = ["scores"]
|
||||
entity_linker.set_store_activations(["scores"])
|
||||
doc = nlp("Russ Cochran was a publisher")
|
||||
assert set(doc.activations["entity_linker"].keys()) == {"scores"}
|
||||
scores = doc.activations["entity_linker"]["scores"]
|
||||
|
|
|
@ -213,14 +213,14 @@ def test_store_activations():
|
|||
doc = nlp("This is a test.")
|
||||
assert len(list(doc.activations["morphologizer"].keys())) == 0
|
||||
|
||||
morphologizer.store_activations = True
|
||||
morphologizer.set_store_activations(True)
|
||||
doc = nlp("This is a test.")
|
||||
assert "morphologizer" in doc.activations
|
||||
assert set(doc.activations["morphologizer"].keys()) == {"guesses", "probs"}
|
||||
assert doc.activations["morphologizer"]["probs"].shape == (5, 6)
|
||||
assert doc.activations["morphologizer"]["guesses"].shape == (5,)
|
||||
|
||||
morphologizer.store_activations = ["probs"]
|
||||
morphologizer.set_store_activations(["probs"])
|
||||
doc = nlp("This is a test.")
|
||||
assert "morphologizer" in doc.activations
|
||||
assert set(doc.activations["morphologizer"].keys()) == {"probs"}
|
||||
|
|
|
@ -120,14 +120,14 @@ def test_store_activations():
|
|||
doc = nlp("This is a test.")
|
||||
assert len(list(doc.activations["senter"].keys())) == 0
|
||||
|
||||
senter.store_activations = True
|
||||
senter.set_store_activations(True)
|
||||
doc = nlp("This is a test.")
|
||||
assert "senter" in doc.activations
|
||||
assert set(doc.activations["senter"].keys()) == {"guesses", "probs"}
|
||||
assert doc.activations["senter"]["probs"].shape == (5, nO)
|
||||
assert doc.activations["senter"]["guesses"].shape == (5,)
|
||||
|
||||
senter.store_activations = ["probs"]
|
||||
senter.set_store_activations(["probs"])
|
||||
doc = nlp("This is a test.")
|
||||
assert "senter" in doc.activations
|
||||
assert set(doc.activations["senter"].keys()) == {"probs"}
|
||||
|
|
|
@ -434,14 +434,13 @@ def test_store_activations():
|
|||
doc = nlp("This is a test.")
|
||||
assert len(list(doc.activations["spancat"].keys())) == 0
|
||||
|
||||
spancat.store_activations = True
|
||||
spancat.set_store_activations(True)
|
||||
doc = nlp("This is a test.")
|
||||
assert set(doc.activations["spancat"].keys()) == {"indices", "scores"}
|
||||
assert doc.activations["spancat"]["indices"].shape == (12, 2)
|
||||
assert doc.activations["spancat"]["scores"].shape == (12, nO)
|
||||
spancat.store_activations = True
|
||||
|
||||
spancat.store_activations = ["scores"]
|
||||
spancat.set_store_activations(["scores"])
|
||||
doc = nlp("This is a test.")
|
||||
assert set(doc.activations["spancat"].keys()) == {"scores"}
|
||||
assert doc.activations["spancat"]["scores"].shape == (12, nO)
|
||||
|
|
|
@ -225,14 +225,14 @@ def test_store_activations():
|
|||
doc = nlp("This is a test.")
|
||||
assert len(list(doc.activations["tagger"].keys())) == 0
|
||||
|
||||
tagger.store_activations = True
|
||||
tagger.set_store_activations(True)
|
||||
doc = nlp("This is a test.")
|
||||
assert "tagger" in doc.activations
|
||||
assert set(doc.activations["tagger"].keys()) == {"guesses", "probs"}
|
||||
assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS))
|
||||
assert doc.activations["tagger"]["guesses"].shape == (5,)
|
||||
|
||||
tagger.store_activations = ["probs"]
|
||||
tagger.set_store_activations(["probs"])
|
||||
doc = nlp("This is a test.")
|
||||
assert set(doc.activations["tagger"].keys()) == {"probs"}
|
||||
assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS))
|
||||
|
|
|
@ -888,12 +888,12 @@ def test_store_activations():
|
|||
doc = nlp("This is a test.")
|
||||
assert len(list(doc.activations["textcat"].keys())) == 0
|
||||
|
||||
textcat.store_activations = True
|
||||
textcat.set_store_activations(True)
|
||||
doc = nlp("This is a test.")
|
||||
assert list(doc.activations["textcat"].keys()) == ["probs"]
|
||||
assert doc.activations["textcat"]["probs"].shape == (nO,)
|
||||
|
||||
textcat.store_activations = ["probs"]
|
||||
textcat.set_store_activations(["probs"])
|
||||
doc = nlp("This is a test.")
|
||||
assert list(doc.activations["textcat"].keys()) == ["probs"]
|
||||
assert doc.activations["textcat"]["probs"].shape == (nO,)
|
||||
|
@ -913,12 +913,12 @@ def test_store_activations_multi():
|
|||
doc = nlp("This is a test.")
|
||||
assert len(list(doc.activations["textcat_multilabel"].keys())) == 0
|
||||
|
||||
textcat.store_activations = True
|
||||
textcat.set_store_activations(True)
|
||||
doc = nlp("This is a test.")
|
||||
assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"]
|
||||
assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,)
|
||||
|
||||
textcat.store_activations = ["probs"]
|
||||
textcat.set_store_activations(["probs"])
|
||||
doc = nlp("This is a test.")
|
||||
assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"]
|
||||
assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,)
|
||||
|
|
Loading…
Reference in New Issue
Block a user