mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-04 06:16:33 +03:00
Replace store_activations setter by set_store_activations method
Setters that take a different type than what the getter returns are still problematic for MyPy. Replace the setter by a method, so that type inference works everywhere.
This commit is contained in:
parent
288d27e17e
commit
51f72e41ec
|
@ -97,7 +97,7 @@ class EditTreeLemmatizer(TrainablePipe):
|
||||||
overwrite: bool = False,
|
overwrite: bool = False,
|
||||||
top_k: int = 1,
|
top_k: int = 1,
|
||||||
scorer: Optional[Callable] = lemmatizer_score,
|
scorer: Optional[Callable] = lemmatizer_score,
|
||||||
store_activations=False,
|
store_activations: Union[bool, List[str]] = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Construct an edit tree lemmatizer.
|
Construct an edit tree lemmatizer.
|
||||||
|
@ -125,7 +125,7 @@ class EditTreeLemmatizer(TrainablePipe):
|
||||||
|
|
||||||
self.cfg: Dict[str, Any] = {"labels": []}
|
self.cfg: Dict[str, Any] = {"labels": []}
|
||||||
self.scorer = scorer
|
self.scorer = scorer
|
||||||
self.store_activations = store_activations # type: ignore
|
self.set_store_activations(store_activations)
|
||||||
|
|
||||||
def get_loss(
|
def get_loss(
|
||||||
self, examples: Iterable[Example], scores: List[Floats2d]
|
self, examples: Iterable[Example], scores: List[Floats2d]
|
||||||
|
|
|
@ -174,7 +174,7 @@ class EntityLinker(TrainablePipe):
|
||||||
scorer: Optional[Callable] = entity_linker_score,
|
scorer: Optional[Callable] = entity_linker_score,
|
||||||
use_gold_ents: bool,
|
use_gold_ents: bool,
|
||||||
threshold: Optional[float] = None,
|
threshold: Optional[float] = None,
|
||||||
store_activations=False,
|
store_activations: Union[bool, List[str]] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize an entity linker.
|
"""Initialize an entity linker.
|
||||||
|
|
||||||
|
@ -223,7 +223,7 @@ class EntityLinker(TrainablePipe):
|
||||||
self.scorer = scorer
|
self.scorer = scorer
|
||||||
self.use_gold_ents = use_gold_ents
|
self.use_gold_ents = use_gold_ents
|
||||||
self.threshold = threshold
|
self.threshold = threshold
|
||||||
self.store_activations = store_activations
|
self.set_store_activations(store_activations)
|
||||||
|
|
||||||
def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
|
def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
|
||||||
"""Define the KB of this pipe by providing a function that will
|
"""Define the KB of this pipe by providing a function that will
|
||||||
|
|
|
@ -103,7 +103,7 @@ class Morphologizer(Tagger):
|
||||||
overwrite: bool = BACKWARD_OVERWRITE,
|
overwrite: bool = BACKWARD_OVERWRITE,
|
||||||
extend: bool = BACKWARD_EXTEND,
|
extend: bool = BACKWARD_EXTEND,
|
||||||
scorer: Optional[Callable] = morphologizer_score,
|
scorer: Optional[Callable] = morphologizer_score,
|
||||||
store_activations=False,
|
store_activations: Union[bool, List[str]] = False,
|
||||||
):
|
):
|
||||||
"""Initialize a morphologizer.
|
"""Initialize a morphologizer.
|
||||||
|
|
||||||
|
@ -135,7 +135,7 @@ class Morphologizer(Tagger):
|
||||||
}
|
}
|
||||||
self.cfg = dict(sorted(cfg.items()))
|
self.cfg = dict(sorted(cfg.items()))
|
||||||
self.scorer = scorer
|
self.scorer = scorer
|
||||||
self.store_activations = store_activations
|
self.set_store_activations(store_activations)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def labels(self):
|
def labels(self):
|
||||||
|
|
|
@ -82,7 +82,7 @@ class SentenceRecognizer(Tagger):
|
||||||
*,
|
*,
|
||||||
overwrite=BACKWARD_OVERWRITE,
|
overwrite=BACKWARD_OVERWRITE,
|
||||||
scorer=senter_score,
|
scorer=senter_score,
|
||||||
store_activations=False,
|
store_activations: Union[bool, List[str]] = False,
|
||||||
):
|
):
|
||||||
"""Initialize a sentence recognizer.
|
"""Initialize a sentence recognizer.
|
||||||
|
|
||||||
|
@ -103,7 +103,7 @@ class SentenceRecognizer(Tagger):
|
||||||
self._rehearsal_model = None
|
self._rehearsal_model = None
|
||||||
self.cfg = {"overwrite": overwrite}
|
self.cfg = {"overwrite": overwrite}
|
||||||
self.scorer = scorer
|
self.scorer = scorer
|
||||||
self.store_activations = store_activations
|
self.set_store_activations(store_activations)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def labels(self):
|
def labels(self):
|
||||||
|
|
|
@ -192,7 +192,7 @@ class SpanCategorizer(TrainablePipe):
|
||||||
threshold: float = 0.5,
|
threshold: float = 0.5,
|
||||||
max_positive: Optional[int] = None,
|
max_positive: Optional[int] = None,
|
||||||
scorer: Optional[Callable] = spancat_score,
|
scorer: Optional[Callable] = spancat_score,
|
||||||
store_activations=False,
|
store_activations: Union[bool, List[str]] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the span categorizer.
|
"""Initialize the span categorizer.
|
||||||
vocab (Vocab): The shared vocabulary.
|
vocab (Vocab): The shared vocabulary.
|
||||||
|
@ -225,7 +225,7 @@ class SpanCategorizer(TrainablePipe):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.name = name
|
self.name = name
|
||||||
self.scorer = scorer
|
self.scorer = scorer
|
||||||
self.store_activations = store_activations
|
self.set_store_activations(store_activations)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def key(self) -> str:
|
def key(self) -> str:
|
||||||
|
|
|
@ -97,7 +97,7 @@ class Tagger(TrainablePipe):
|
||||||
overwrite=BACKWARD_OVERWRITE,
|
overwrite=BACKWARD_OVERWRITE,
|
||||||
scorer=tagger_score,
|
scorer=tagger_score,
|
||||||
neg_prefix="!",
|
neg_prefix="!",
|
||||||
store_activations=False,
|
store_activations: Union[bool, List[str]] = False,
|
||||||
):
|
):
|
||||||
"""Initialize a part-of-speech tagger.
|
"""Initialize a part-of-speech tagger.
|
||||||
|
|
||||||
|
@ -119,7 +119,7 @@ class Tagger(TrainablePipe):
|
||||||
cfg = {"labels": [], "overwrite": overwrite, "neg_prefix": neg_prefix}
|
cfg = {"labels": [], "overwrite": overwrite, "neg_prefix": neg_prefix}
|
||||||
self.cfg = dict(sorted(cfg.items()))
|
self.cfg = dict(sorted(cfg.items()))
|
||||||
self.scorer = scorer
|
self.scorer = scorer
|
||||||
self.store_activations = store_activations
|
self.set_store_activations(store_activations)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def labels(self):
|
def labels(self):
|
||||||
|
|
|
@ -148,7 +148,7 @@ class TextCategorizer(TrainablePipe):
|
||||||
*,
|
*,
|
||||||
threshold: float,
|
threshold: float,
|
||||||
scorer: Optional[Callable] = textcat_score,
|
scorer: Optional[Callable] = textcat_score,
|
||||||
store_activations=False,
|
store_activations: Union[bool, List[str]] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize a text categorizer for single-label classification.
|
"""Initialize a text categorizer for single-label classification.
|
||||||
|
|
||||||
|
@ -169,7 +169,7 @@ class TextCategorizer(TrainablePipe):
|
||||||
cfg = {"labels": [], "threshold": threshold, "positive_label": None}
|
cfg = {"labels": [], "threshold": threshold, "positive_label": None}
|
||||||
self.cfg = dict(cfg)
|
self.cfg = dict(cfg)
|
||||||
self.scorer = scorer
|
self.scorer = scorer
|
||||||
self.store_activations = store_activations
|
self.set_store_activations(store_activations)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def support_missing_values(self):
|
def support_missing_values(self):
|
||||||
|
|
|
@ -146,7 +146,7 @@ class MultiLabel_TextCategorizer(TextCategorizer):
|
||||||
*,
|
*,
|
||||||
threshold: float,
|
threshold: float,
|
||||||
scorer: Optional[Callable] = textcat_multilabel_score,
|
scorer: Optional[Callable] = textcat_multilabel_score,
|
||||||
store_activations=False,
|
store_activations: Union[bool, List[str]] = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize a text categorizer for multi-label classification.
|
"""Initialize a text categorizer for multi-label classification.
|
||||||
|
|
||||||
|
@ -167,7 +167,7 @@ class MultiLabel_TextCategorizer(TextCategorizer):
|
||||||
cfg = {"labels": [], "threshold": threshold}
|
cfg = {"labels": [], "threshold": threshold}
|
||||||
self.cfg = dict(cfg)
|
self.cfg = dict(cfg)
|
||||||
self.scorer = scorer
|
self.scorer = scorer
|
||||||
self.store_activations = store_activations
|
self.set_store_activations(store_activations)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def support_missing_values(self):
|
def support_missing_values(self):
|
||||||
|
|
|
@ -352,8 +352,7 @@ cdef class TrainablePipe(Pipe):
|
||||||
def store_activations(self):
|
def store_activations(self):
|
||||||
return self._store_activations
|
return self._store_activations
|
||||||
|
|
||||||
@store_activations.setter
|
def set_store_activations(self, activations):
|
||||||
def store_activations(self, activations):
|
|
||||||
known_activations = self.activations
|
known_activations = self.activations
|
||||||
if isinstance(activations, list):
|
if isinstance(activations, list):
|
||||||
self._store_activations = []
|
self._store_activations = []
|
||||||
|
|
|
@ -295,13 +295,13 @@ def test_store_activations():
|
||||||
doc = nlp("This is a test.")
|
doc = nlp("This is a test.")
|
||||||
assert len(list(doc.activations["trainable_lemmatizer"].keys())) == 0
|
assert len(list(doc.activations["trainable_lemmatizer"].keys())) == 0
|
||||||
|
|
||||||
lemmatizer.store_activations = True
|
lemmatizer.set_store_activations(True)
|
||||||
doc = nlp("This is a test.")
|
doc = nlp("This is a test.")
|
||||||
assert list(doc.activations["trainable_lemmatizer"].keys()) == ["probs", "guesses"]
|
assert list(doc.activations["trainable_lemmatizer"].keys()) == ["probs", "guesses"]
|
||||||
assert doc.activations["trainable_lemmatizer"]["probs"].shape == (5, nO)
|
assert doc.activations["trainable_lemmatizer"]["probs"].shape == (5, nO)
|
||||||
assert doc.activations["trainable_lemmatizer"]["guesses"].shape == (5,)
|
assert doc.activations["trainable_lemmatizer"]["guesses"].shape == (5,)
|
||||||
|
|
||||||
lemmatizer.store_activations = ["probs"]
|
lemmatizer.set_store_activations(["probs"])
|
||||||
doc = nlp("This is a test.")
|
doc = nlp("This is a test.")
|
||||||
assert list(doc.activations["trainable_lemmatizer"].keys()) == ["probs"]
|
assert list(doc.activations["trainable_lemmatizer"].keys()) == ["probs"]
|
||||||
assert doc.activations["trainable_lemmatizer"]["probs"].shape == (5, nO)
|
assert doc.activations["trainable_lemmatizer"]["probs"].shape == (5, nO)
|
||||||
|
|
|
@ -1227,7 +1227,7 @@ def test_store_activations():
|
||||||
doc = nlp("Russ Cochran was a publisher")
|
doc = nlp("Russ Cochran was a publisher")
|
||||||
assert len(doc.activations["entity_linker"].keys()) == 0
|
assert len(doc.activations["entity_linker"].keys()) == 0
|
||||||
|
|
||||||
entity_linker.store_activations = True
|
entity_linker.set_store_activations(True)
|
||||||
doc = nlp("Russ Cochran was a publisher")
|
doc = nlp("Russ Cochran was a publisher")
|
||||||
assert set(doc.activations["entity_linker"].keys()) == {"ents", "scores"}
|
assert set(doc.activations["entity_linker"].keys()) == {"ents", "scores"}
|
||||||
ents = doc.activations["entity_linker"]["ents"]
|
ents = doc.activations["entity_linker"]["ents"]
|
||||||
|
@ -1241,7 +1241,7 @@ def test_store_activations():
|
||||||
assert scores.data.dtype == "float32"
|
assert scores.data.dtype == "float32"
|
||||||
assert scores.lengths.shape == (1,)
|
assert scores.lengths.shape == (1,)
|
||||||
|
|
||||||
entity_linker.store_activations = ["scores"]
|
entity_linker.set_store_activations(["scores"])
|
||||||
doc = nlp("Russ Cochran was a publisher")
|
doc = nlp("Russ Cochran was a publisher")
|
||||||
assert set(doc.activations["entity_linker"].keys()) == {"scores"}
|
assert set(doc.activations["entity_linker"].keys()) == {"scores"}
|
||||||
scores = doc.activations["entity_linker"]["scores"]
|
scores = doc.activations["entity_linker"]["scores"]
|
||||||
|
|
|
@ -213,14 +213,14 @@ def test_store_activations():
|
||||||
doc = nlp("This is a test.")
|
doc = nlp("This is a test.")
|
||||||
assert len(list(doc.activations["morphologizer"].keys())) == 0
|
assert len(list(doc.activations["morphologizer"].keys())) == 0
|
||||||
|
|
||||||
morphologizer.store_activations = True
|
morphologizer.set_store_activations(True)
|
||||||
doc = nlp("This is a test.")
|
doc = nlp("This is a test.")
|
||||||
assert "morphologizer" in doc.activations
|
assert "morphologizer" in doc.activations
|
||||||
assert set(doc.activations["morphologizer"].keys()) == {"guesses", "probs"}
|
assert set(doc.activations["morphologizer"].keys()) == {"guesses", "probs"}
|
||||||
assert doc.activations["morphologizer"]["probs"].shape == (5, 6)
|
assert doc.activations["morphologizer"]["probs"].shape == (5, 6)
|
||||||
assert doc.activations["morphologizer"]["guesses"].shape == (5,)
|
assert doc.activations["morphologizer"]["guesses"].shape == (5,)
|
||||||
|
|
||||||
morphologizer.store_activations = ["probs"]
|
morphologizer.set_store_activations(["probs"])
|
||||||
doc = nlp("This is a test.")
|
doc = nlp("This is a test.")
|
||||||
assert "morphologizer" in doc.activations
|
assert "morphologizer" in doc.activations
|
||||||
assert set(doc.activations["morphologizer"].keys()) == {"probs"}
|
assert set(doc.activations["morphologizer"].keys()) == {"probs"}
|
||||||
|
|
|
@ -120,14 +120,14 @@ def test_store_activations():
|
||||||
doc = nlp("This is a test.")
|
doc = nlp("This is a test.")
|
||||||
assert len(list(doc.activations["senter"].keys())) == 0
|
assert len(list(doc.activations["senter"].keys())) == 0
|
||||||
|
|
||||||
senter.store_activations = True
|
senter.set_store_activations(True)
|
||||||
doc = nlp("This is a test.")
|
doc = nlp("This is a test.")
|
||||||
assert "senter" in doc.activations
|
assert "senter" in doc.activations
|
||||||
assert set(doc.activations["senter"].keys()) == {"guesses", "probs"}
|
assert set(doc.activations["senter"].keys()) == {"guesses", "probs"}
|
||||||
assert doc.activations["senter"]["probs"].shape == (5, nO)
|
assert doc.activations["senter"]["probs"].shape == (5, nO)
|
||||||
assert doc.activations["senter"]["guesses"].shape == (5,)
|
assert doc.activations["senter"]["guesses"].shape == (5,)
|
||||||
|
|
||||||
senter.store_activations = ["probs"]
|
senter.set_store_activations(["probs"])
|
||||||
doc = nlp("This is a test.")
|
doc = nlp("This is a test.")
|
||||||
assert "senter" in doc.activations
|
assert "senter" in doc.activations
|
||||||
assert set(doc.activations["senter"].keys()) == {"probs"}
|
assert set(doc.activations["senter"].keys()) == {"probs"}
|
||||||
|
|
|
@ -434,14 +434,13 @@ def test_store_activations():
|
||||||
doc = nlp("This is a test.")
|
doc = nlp("This is a test.")
|
||||||
assert len(list(doc.activations["spancat"].keys())) == 0
|
assert len(list(doc.activations["spancat"].keys())) == 0
|
||||||
|
|
||||||
spancat.store_activations = True
|
spancat.set_store_activations(True)
|
||||||
doc = nlp("This is a test.")
|
doc = nlp("This is a test.")
|
||||||
assert set(doc.activations["spancat"].keys()) == {"indices", "scores"}
|
assert set(doc.activations["spancat"].keys()) == {"indices", "scores"}
|
||||||
assert doc.activations["spancat"]["indices"].shape == (12, 2)
|
assert doc.activations["spancat"]["indices"].shape == (12, 2)
|
||||||
assert doc.activations["spancat"]["scores"].shape == (12, nO)
|
assert doc.activations["spancat"]["scores"].shape == (12, nO)
|
||||||
spancat.store_activations = True
|
|
||||||
|
|
||||||
spancat.store_activations = ["scores"]
|
spancat.set_store_activations(["scores"])
|
||||||
doc = nlp("This is a test.")
|
doc = nlp("This is a test.")
|
||||||
assert set(doc.activations["spancat"].keys()) == {"scores"}
|
assert set(doc.activations["spancat"].keys()) == {"scores"}
|
||||||
assert doc.activations["spancat"]["scores"].shape == (12, nO)
|
assert doc.activations["spancat"]["scores"].shape == (12, nO)
|
||||||
|
|
|
@ -225,14 +225,14 @@ def test_store_activations():
|
||||||
doc = nlp("This is a test.")
|
doc = nlp("This is a test.")
|
||||||
assert len(list(doc.activations["tagger"].keys())) == 0
|
assert len(list(doc.activations["tagger"].keys())) == 0
|
||||||
|
|
||||||
tagger.store_activations = True
|
tagger.set_store_activations(True)
|
||||||
doc = nlp("This is a test.")
|
doc = nlp("This is a test.")
|
||||||
assert "tagger" in doc.activations
|
assert "tagger" in doc.activations
|
||||||
assert set(doc.activations["tagger"].keys()) == {"guesses", "probs"}
|
assert set(doc.activations["tagger"].keys()) == {"guesses", "probs"}
|
||||||
assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS))
|
assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS))
|
||||||
assert doc.activations["tagger"]["guesses"].shape == (5,)
|
assert doc.activations["tagger"]["guesses"].shape == (5,)
|
||||||
|
|
||||||
tagger.store_activations = ["probs"]
|
tagger.set_store_activations(["probs"])
|
||||||
doc = nlp("This is a test.")
|
doc = nlp("This is a test.")
|
||||||
assert set(doc.activations["tagger"].keys()) == {"probs"}
|
assert set(doc.activations["tagger"].keys()) == {"probs"}
|
||||||
assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS))
|
assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS))
|
||||||
|
|
|
@ -888,12 +888,12 @@ def test_store_activations():
|
||||||
doc = nlp("This is a test.")
|
doc = nlp("This is a test.")
|
||||||
assert len(list(doc.activations["textcat"].keys())) == 0
|
assert len(list(doc.activations["textcat"].keys())) == 0
|
||||||
|
|
||||||
textcat.store_activations = True
|
textcat.set_store_activations(True)
|
||||||
doc = nlp("This is a test.")
|
doc = nlp("This is a test.")
|
||||||
assert list(doc.activations["textcat"].keys()) == ["probs"]
|
assert list(doc.activations["textcat"].keys()) == ["probs"]
|
||||||
assert doc.activations["textcat"]["probs"].shape == (nO,)
|
assert doc.activations["textcat"]["probs"].shape == (nO,)
|
||||||
|
|
||||||
textcat.store_activations = ["probs"]
|
textcat.set_store_activations(["probs"])
|
||||||
doc = nlp("This is a test.")
|
doc = nlp("This is a test.")
|
||||||
assert list(doc.activations["textcat"].keys()) == ["probs"]
|
assert list(doc.activations["textcat"].keys()) == ["probs"]
|
||||||
assert doc.activations["textcat"]["probs"].shape == (nO,)
|
assert doc.activations["textcat"]["probs"].shape == (nO,)
|
||||||
|
@ -913,12 +913,12 @@ def test_store_activations_multi():
|
||||||
doc = nlp("This is a test.")
|
doc = nlp("This is a test.")
|
||||||
assert len(list(doc.activations["textcat_multilabel"].keys())) == 0
|
assert len(list(doc.activations["textcat_multilabel"].keys())) == 0
|
||||||
|
|
||||||
textcat.store_activations = True
|
textcat.set_store_activations(True)
|
||||||
doc = nlp("This is a test.")
|
doc = nlp("This is a test.")
|
||||||
assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"]
|
assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"]
|
||||||
assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,)
|
assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,)
|
||||||
|
|
||||||
textcat.store_activations = ["probs"]
|
textcat.set_store_activations(["probs"])
|
||||||
doc = nlp("This is a test.")
|
doc = nlp("This is a test.")
|
||||||
assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"]
|
assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"]
|
||||||
assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,)
|
assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user