mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 09:57:26 +03:00 
			
		
		
		
	Replace store_activations setter by set_store_activations method
Setters that take a different type than what the getter returns are still problematic for MyPy. Replace the setter by a method, so that type inference works everywhere.
This commit is contained in:
		
							parent
							
								
									288d27e17e
								
							
						
					
					
						commit
						51f72e41ec
					
				| 
						 | 
					@ -97,7 +97,7 @@ class EditTreeLemmatizer(TrainablePipe):
 | 
				
			||||||
        overwrite: bool = False,
 | 
					        overwrite: bool = False,
 | 
				
			||||||
        top_k: int = 1,
 | 
					        top_k: int = 1,
 | 
				
			||||||
        scorer: Optional[Callable] = lemmatizer_score,
 | 
					        scorer: Optional[Callable] = lemmatizer_score,
 | 
				
			||||||
        store_activations=False,
 | 
					        store_activations: Union[bool, List[str]] = False,
 | 
				
			||||||
    ):
 | 
					    ):
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        Construct an edit tree lemmatizer.
 | 
					        Construct an edit tree lemmatizer.
 | 
				
			||||||
| 
						 | 
					@ -125,7 +125,7 @@ class EditTreeLemmatizer(TrainablePipe):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.cfg: Dict[str, Any] = {"labels": []}
 | 
					        self.cfg: Dict[str, Any] = {"labels": []}
 | 
				
			||||||
        self.scorer = scorer
 | 
					        self.scorer = scorer
 | 
				
			||||||
        self.store_activations = store_activations  # type: ignore
 | 
					        self.set_store_activations(store_activations)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_loss(
 | 
					    def get_loss(
 | 
				
			||||||
        self, examples: Iterable[Example], scores: List[Floats2d]
 | 
					        self, examples: Iterable[Example], scores: List[Floats2d]
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -174,7 +174,7 @@ class EntityLinker(TrainablePipe):
 | 
				
			||||||
        scorer: Optional[Callable] = entity_linker_score,
 | 
					        scorer: Optional[Callable] = entity_linker_score,
 | 
				
			||||||
        use_gold_ents: bool,
 | 
					        use_gold_ents: bool,
 | 
				
			||||||
        threshold: Optional[float] = None,
 | 
					        threshold: Optional[float] = None,
 | 
				
			||||||
        store_activations=False,
 | 
					        store_activations: Union[bool, List[str]] = False,
 | 
				
			||||||
    ) -> None:
 | 
					    ) -> None:
 | 
				
			||||||
        """Initialize an entity linker.
 | 
					        """Initialize an entity linker.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -223,7 +223,7 @@ class EntityLinker(TrainablePipe):
 | 
				
			||||||
        self.scorer = scorer
 | 
					        self.scorer = scorer
 | 
				
			||||||
        self.use_gold_ents = use_gold_ents
 | 
					        self.use_gold_ents = use_gold_ents
 | 
				
			||||||
        self.threshold = threshold
 | 
					        self.threshold = threshold
 | 
				
			||||||
        self.store_activations = store_activations
 | 
					        self.set_store_activations(store_activations)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
 | 
					    def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
 | 
				
			||||||
        """Define the KB of this pipe by providing a function that will
 | 
					        """Define the KB of this pipe by providing a function that will
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -103,7 +103,7 @@ class Morphologizer(Tagger):
 | 
				
			||||||
        overwrite: bool = BACKWARD_OVERWRITE,
 | 
					        overwrite: bool = BACKWARD_OVERWRITE,
 | 
				
			||||||
        extend: bool = BACKWARD_EXTEND,
 | 
					        extend: bool = BACKWARD_EXTEND,
 | 
				
			||||||
        scorer: Optional[Callable] = morphologizer_score,
 | 
					        scorer: Optional[Callable] = morphologizer_score,
 | 
				
			||||||
        store_activations=False,
 | 
					        store_activations: Union[bool, List[str]] = False,
 | 
				
			||||||
    ):
 | 
					    ):
 | 
				
			||||||
        """Initialize a morphologizer.
 | 
					        """Initialize a morphologizer.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -135,7 +135,7 @@ class Morphologizer(Tagger):
 | 
				
			||||||
        }
 | 
					        }
 | 
				
			||||||
        self.cfg = dict(sorted(cfg.items()))
 | 
					        self.cfg = dict(sorted(cfg.items()))
 | 
				
			||||||
        self.scorer = scorer
 | 
					        self.scorer = scorer
 | 
				
			||||||
        self.store_activations = store_activations
 | 
					        self.set_store_activations(store_activations)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def labels(self):
 | 
					    def labels(self):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -82,7 +82,7 @@ class SentenceRecognizer(Tagger):
 | 
				
			||||||
        *,
 | 
					        *,
 | 
				
			||||||
        overwrite=BACKWARD_OVERWRITE,
 | 
					        overwrite=BACKWARD_OVERWRITE,
 | 
				
			||||||
        scorer=senter_score,
 | 
					        scorer=senter_score,
 | 
				
			||||||
        store_activations=False,
 | 
					        store_activations: Union[bool, List[str]] = False,
 | 
				
			||||||
    ):
 | 
					    ):
 | 
				
			||||||
        """Initialize a sentence recognizer.
 | 
					        """Initialize a sentence recognizer.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -103,7 +103,7 @@ class SentenceRecognizer(Tagger):
 | 
				
			||||||
        self._rehearsal_model = None
 | 
					        self._rehearsal_model = None
 | 
				
			||||||
        self.cfg = {"overwrite": overwrite}
 | 
					        self.cfg = {"overwrite": overwrite}
 | 
				
			||||||
        self.scorer = scorer
 | 
					        self.scorer = scorer
 | 
				
			||||||
        self.store_activations = store_activations
 | 
					        self.set_store_activations(store_activations)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def labels(self):
 | 
					    def labels(self):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -192,7 +192,7 @@ class SpanCategorizer(TrainablePipe):
 | 
				
			||||||
        threshold: float = 0.5,
 | 
					        threshold: float = 0.5,
 | 
				
			||||||
        max_positive: Optional[int] = None,
 | 
					        max_positive: Optional[int] = None,
 | 
				
			||||||
        scorer: Optional[Callable] = spancat_score,
 | 
					        scorer: Optional[Callable] = spancat_score,
 | 
				
			||||||
        store_activations=False,
 | 
					        store_activations: Union[bool, List[str]] = False,
 | 
				
			||||||
    ) -> None:
 | 
					    ) -> None:
 | 
				
			||||||
        """Initialize the span categorizer.
 | 
					        """Initialize the span categorizer.
 | 
				
			||||||
        vocab (Vocab): The shared vocabulary.
 | 
					        vocab (Vocab): The shared vocabulary.
 | 
				
			||||||
| 
						 | 
					@ -225,7 +225,7 @@ class SpanCategorizer(TrainablePipe):
 | 
				
			||||||
        self.model = model
 | 
					        self.model = model
 | 
				
			||||||
        self.name = name
 | 
					        self.name = name
 | 
				
			||||||
        self.scorer = scorer
 | 
					        self.scorer = scorer
 | 
				
			||||||
        self.store_activations = store_activations
 | 
					        self.set_store_activations(store_activations)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def key(self) -> str:
 | 
					    def key(self) -> str:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -97,7 +97,7 @@ class Tagger(TrainablePipe):
 | 
				
			||||||
        overwrite=BACKWARD_OVERWRITE,
 | 
					        overwrite=BACKWARD_OVERWRITE,
 | 
				
			||||||
        scorer=tagger_score,
 | 
					        scorer=tagger_score,
 | 
				
			||||||
        neg_prefix="!",
 | 
					        neg_prefix="!",
 | 
				
			||||||
        store_activations=False,
 | 
					        store_activations: Union[bool, List[str]] = False,
 | 
				
			||||||
    ):
 | 
					    ):
 | 
				
			||||||
        """Initialize a part-of-speech tagger.
 | 
					        """Initialize a part-of-speech tagger.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -119,7 +119,7 @@ class Tagger(TrainablePipe):
 | 
				
			||||||
        cfg = {"labels": [], "overwrite": overwrite, "neg_prefix": neg_prefix}
 | 
					        cfg = {"labels": [], "overwrite": overwrite, "neg_prefix": neg_prefix}
 | 
				
			||||||
        self.cfg = dict(sorted(cfg.items()))
 | 
					        self.cfg = dict(sorted(cfg.items()))
 | 
				
			||||||
        self.scorer = scorer
 | 
					        self.scorer = scorer
 | 
				
			||||||
        self.store_activations = store_activations
 | 
					        self.set_store_activations(store_activations)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def labels(self):
 | 
					    def labels(self):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -148,7 +148,7 @@ class TextCategorizer(TrainablePipe):
 | 
				
			||||||
        *,
 | 
					        *,
 | 
				
			||||||
        threshold: float,
 | 
					        threshold: float,
 | 
				
			||||||
        scorer: Optional[Callable] = textcat_score,
 | 
					        scorer: Optional[Callable] = textcat_score,
 | 
				
			||||||
        store_activations=False,
 | 
					        store_activations: Union[bool, List[str]] = False,
 | 
				
			||||||
    ) -> None:
 | 
					    ) -> None:
 | 
				
			||||||
        """Initialize a text categorizer for single-label classification.
 | 
					        """Initialize a text categorizer for single-label classification.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -169,7 +169,7 @@ class TextCategorizer(TrainablePipe):
 | 
				
			||||||
        cfg = {"labels": [], "threshold": threshold, "positive_label": None}
 | 
					        cfg = {"labels": [], "threshold": threshold, "positive_label": None}
 | 
				
			||||||
        self.cfg = dict(cfg)
 | 
					        self.cfg = dict(cfg)
 | 
				
			||||||
        self.scorer = scorer
 | 
					        self.scorer = scorer
 | 
				
			||||||
        self.store_activations = store_activations
 | 
					        self.set_store_activations(store_activations)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def support_missing_values(self):
 | 
					    def support_missing_values(self):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -146,7 +146,7 @@ class MultiLabel_TextCategorizer(TextCategorizer):
 | 
				
			||||||
        *,
 | 
					        *,
 | 
				
			||||||
        threshold: float,
 | 
					        threshold: float,
 | 
				
			||||||
        scorer: Optional[Callable] = textcat_multilabel_score,
 | 
					        scorer: Optional[Callable] = textcat_multilabel_score,
 | 
				
			||||||
        store_activations=False,
 | 
					        store_activations: Union[bool, List[str]] = False,
 | 
				
			||||||
    ) -> None:
 | 
					    ) -> None:
 | 
				
			||||||
        """Initialize a text categorizer for multi-label classification.
 | 
					        """Initialize a text categorizer for multi-label classification.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -167,7 +167,7 @@ class MultiLabel_TextCategorizer(TextCategorizer):
 | 
				
			||||||
        cfg = {"labels": [], "threshold": threshold}
 | 
					        cfg = {"labels": [], "threshold": threshold}
 | 
				
			||||||
        self.cfg = dict(cfg)
 | 
					        self.cfg = dict(cfg)
 | 
				
			||||||
        self.scorer = scorer
 | 
					        self.scorer = scorer
 | 
				
			||||||
        self.store_activations = store_activations
 | 
					        self.set_store_activations(store_activations)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def support_missing_values(self):
 | 
					    def support_missing_values(self):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -352,8 +352,7 @@ cdef class TrainablePipe(Pipe):
 | 
				
			||||||
    def store_activations(self):
 | 
					    def store_activations(self):
 | 
				
			||||||
        return self._store_activations
 | 
					        return self._store_activations
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @store_activations.setter
 | 
					    def set_store_activations(self, activations):
 | 
				
			||||||
    def store_activations(self, activations):
 | 
					 | 
				
			||||||
        known_activations = self.activations
 | 
					        known_activations = self.activations
 | 
				
			||||||
        if isinstance(activations, list):
 | 
					        if isinstance(activations, list):
 | 
				
			||||||
            self._store_activations = []
 | 
					            self._store_activations = []
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -295,13 +295,13 @@ def test_store_activations():
 | 
				
			||||||
    doc = nlp("This is a test.")
 | 
					    doc = nlp("This is a test.")
 | 
				
			||||||
    assert len(list(doc.activations["trainable_lemmatizer"].keys())) == 0
 | 
					    assert len(list(doc.activations["trainable_lemmatizer"].keys())) == 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    lemmatizer.store_activations = True
 | 
					    lemmatizer.set_store_activations(True)
 | 
				
			||||||
    doc = nlp("This is a test.")
 | 
					    doc = nlp("This is a test.")
 | 
				
			||||||
    assert list(doc.activations["trainable_lemmatizer"].keys()) == ["probs", "guesses"]
 | 
					    assert list(doc.activations["trainable_lemmatizer"].keys()) == ["probs", "guesses"]
 | 
				
			||||||
    assert doc.activations["trainable_lemmatizer"]["probs"].shape == (5, nO)
 | 
					    assert doc.activations["trainable_lemmatizer"]["probs"].shape == (5, nO)
 | 
				
			||||||
    assert doc.activations["trainable_lemmatizer"]["guesses"].shape == (5,)
 | 
					    assert doc.activations["trainable_lemmatizer"]["guesses"].shape == (5,)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    lemmatizer.store_activations = ["probs"]
 | 
					    lemmatizer.set_store_activations(["probs"])
 | 
				
			||||||
    doc = nlp("This is a test.")
 | 
					    doc = nlp("This is a test.")
 | 
				
			||||||
    assert list(doc.activations["trainable_lemmatizer"].keys()) == ["probs"]
 | 
					    assert list(doc.activations["trainable_lemmatizer"].keys()) == ["probs"]
 | 
				
			||||||
    assert doc.activations["trainable_lemmatizer"]["probs"].shape == (5, nO)
 | 
					    assert doc.activations["trainable_lemmatizer"]["probs"].shape == (5, nO)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1227,7 +1227,7 @@ def test_store_activations():
 | 
				
			||||||
    doc = nlp("Russ Cochran was a publisher")
 | 
					    doc = nlp("Russ Cochran was a publisher")
 | 
				
			||||||
    assert len(doc.activations["entity_linker"].keys()) == 0
 | 
					    assert len(doc.activations["entity_linker"].keys()) == 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    entity_linker.store_activations = True
 | 
					    entity_linker.set_store_activations(True)
 | 
				
			||||||
    doc = nlp("Russ Cochran was a publisher")
 | 
					    doc = nlp("Russ Cochran was a publisher")
 | 
				
			||||||
    assert set(doc.activations["entity_linker"].keys()) == {"ents", "scores"}
 | 
					    assert set(doc.activations["entity_linker"].keys()) == {"ents", "scores"}
 | 
				
			||||||
    ents = doc.activations["entity_linker"]["ents"]
 | 
					    ents = doc.activations["entity_linker"]["ents"]
 | 
				
			||||||
| 
						 | 
					@ -1241,7 +1241,7 @@ def test_store_activations():
 | 
				
			||||||
    assert scores.data.dtype == "float32"
 | 
					    assert scores.data.dtype == "float32"
 | 
				
			||||||
    assert scores.lengths.shape == (1,)
 | 
					    assert scores.lengths.shape == (1,)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    entity_linker.store_activations = ["scores"]
 | 
					    entity_linker.set_store_activations(["scores"])
 | 
				
			||||||
    doc = nlp("Russ Cochran was a publisher")
 | 
					    doc = nlp("Russ Cochran was a publisher")
 | 
				
			||||||
    assert set(doc.activations["entity_linker"].keys()) == {"scores"}
 | 
					    assert set(doc.activations["entity_linker"].keys()) == {"scores"}
 | 
				
			||||||
    scores = doc.activations["entity_linker"]["scores"]
 | 
					    scores = doc.activations["entity_linker"]["scores"]
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -213,14 +213,14 @@ def test_store_activations():
 | 
				
			||||||
    doc = nlp("This is a test.")
 | 
					    doc = nlp("This is a test.")
 | 
				
			||||||
    assert len(list(doc.activations["morphologizer"].keys())) == 0
 | 
					    assert len(list(doc.activations["morphologizer"].keys())) == 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    morphologizer.store_activations = True
 | 
					    morphologizer.set_store_activations(True)
 | 
				
			||||||
    doc = nlp("This is a test.")
 | 
					    doc = nlp("This is a test.")
 | 
				
			||||||
    assert "morphologizer" in doc.activations
 | 
					    assert "morphologizer" in doc.activations
 | 
				
			||||||
    assert set(doc.activations["morphologizer"].keys()) == {"guesses", "probs"}
 | 
					    assert set(doc.activations["morphologizer"].keys()) == {"guesses", "probs"}
 | 
				
			||||||
    assert doc.activations["morphologizer"]["probs"].shape == (5, 6)
 | 
					    assert doc.activations["morphologizer"]["probs"].shape == (5, 6)
 | 
				
			||||||
    assert doc.activations["morphologizer"]["guesses"].shape == (5,)
 | 
					    assert doc.activations["morphologizer"]["guesses"].shape == (5,)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    morphologizer.store_activations = ["probs"]
 | 
					    morphologizer.set_store_activations(["probs"])
 | 
				
			||||||
    doc = nlp("This is a test.")
 | 
					    doc = nlp("This is a test.")
 | 
				
			||||||
    assert "morphologizer" in doc.activations
 | 
					    assert "morphologizer" in doc.activations
 | 
				
			||||||
    assert set(doc.activations["morphologizer"].keys()) == {"probs"}
 | 
					    assert set(doc.activations["morphologizer"].keys()) == {"probs"}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -120,14 +120,14 @@ def test_store_activations():
 | 
				
			||||||
    doc = nlp("This is a test.")
 | 
					    doc = nlp("This is a test.")
 | 
				
			||||||
    assert len(list(doc.activations["senter"].keys())) == 0
 | 
					    assert len(list(doc.activations["senter"].keys())) == 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    senter.store_activations = True
 | 
					    senter.set_store_activations(True)
 | 
				
			||||||
    doc = nlp("This is a test.")
 | 
					    doc = nlp("This is a test.")
 | 
				
			||||||
    assert "senter" in doc.activations
 | 
					    assert "senter" in doc.activations
 | 
				
			||||||
    assert set(doc.activations["senter"].keys()) == {"guesses", "probs"}
 | 
					    assert set(doc.activations["senter"].keys()) == {"guesses", "probs"}
 | 
				
			||||||
    assert doc.activations["senter"]["probs"].shape == (5, nO)
 | 
					    assert doc.activations["senter"]["probs"].shape == (5, nO)
 | 
				
			||||||
    assert doc.activations["senter"]["guesses"].shape == (5,)
 | 
					    assert doc.activations["senter"]["guesses"].shape == (5,)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    senter.store_activations = ["probs"]
 | 
					    senter.set_store_activations(["probs"])
 | 
				
			||||||
    doc = nlp("This is a test.")
 | 
					    doc = nlp("This is a test.")
 | 
				
			||||||
    assert "senter" in doc.activations
 | 
					    assert "senter" in doc.activations
 | 
				
			||||||
    assert set(doc.activations["senter"].keys()) == {"probs"}
 | 
					    assert set(doc.activations["senter"].keys()) == {"probs"}
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -434,14 +434,13 @@ def test_store_activations():
 | 
				
			||||||
    doc = nlp("This is a test.")
 | 
					    doc = nlp("This is a test.")
 | 
				
			||||||
    assert len(list(doc.activations["spancat"].keys())) == 0
 | 
					    assert len(list(doc.activations["spancat"].keys())) == 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    spancat.store_activations = True
 | 
					    spancat.set_store_activations(True)
 | 
				
			||||||
    doc = nlp("This is a test.")
 | 
					    doc = nlp("This is a test.")
 | 
				
			||||||
    assert set(doc.activations["spancat"].keys()) == {"indices", "scores"}
 | 
					    assert set(doc.activations["spancat"].keys()) == {"indices", "scores"}
 | 
				
			||||||
    assert doc.activations["spancat"]["indices"].shape == (12, 2)
 | 
					    assert doc.activations["spancat"]["indices"].shape == (12, 2)
 | 
				
			||||||
    assert doc.activations["spancat"]["scores"].shape == (12, nO)
 | 
					    assert doc.activations["spancat"]["scores"].shape == (12, nO)
 | 
				
			||||||
    spancat.store_activations = True
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    spancat.store_activations = ["scores"]
 | 
					    spancat.set_store_activations(["scores"])
 | 
				
			||||||
    doc = nlp("This is a test.")
 | 
					    doc = nlp("This is a test.")
 | 
				
			||||||
    assert set(doc.activations["spancat"].keys()) == {"scores"}
 | 
					    assert set(doc.activations["spancat"].keys()) == {"scores"}
 | 
				
			||||||
    assert doc.activations["spancat"]["scores"].shape == (12, nO)
 | 
					    assert doc.activations["spancat"]["scores"].shape == (12, nO)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -225,14 +225,14 @@ def test_store_activations():
 | 
				
			||||||
    doc = nlp("This is a test.")
 | 
					    doc = nlp("This is a test.")
 | 
				
			||||||
    assert len(list(doc.activations["tagger"].keys())) == 0
 | 
					    assert len(list(doc.activations["tagger"].keys())) == 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    tagger.store_activations = True
 | 
					    tagger.set_store_activations(True)
 | 
				
			||||||
    doc = nlp("This is a test.")
 | 
					    doc = nlp("This is a test.")
 | 
				
			||||||
    assert "tagger" in doc.activations
 | 
					    assert "tagger" in doc.activations
 | 
				
			||||||
    assert set(doc.activations["tagger"].keys()) == {"guesses", "probs"}
 | 
					    assert set(doc.activations["tagger"].keys()) == {"guesses", "probs"}
 | 
				
			||||||
    assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS))
 | 
					    assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS))
 | 
				
			||||||
    assert doc.activations["tagger"]["guesses"].shape == (5,)
 | 
					    assert doc.activations["tagger"]["guesses"].shape == (5,)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    tagger.store_activations = ["probs"]
 | 
					    tagger.set_store_activations(["probs"])
 | 
				
			||||||
    doc = nlp("This is a test.")
 | 
					    doc = nlp("This is a test.")
 | 
				
			||||||
    assert set(doc.activations["tagger"].keys()) == {"probs"}
 | 
					    assert set(doc.activations["tagger"].keys()) == {"probs"}
 | 
				
			||||||
    assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS))
 | 
					    assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS))
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -888,12 +888,12 @@ def test_store_activations():
 | 
				
			||||||
    doc = nlp("This is a test.")
 | 
					    doc = nlp("This is a test.")
 | 
				
			||||||
    assert len(list(doc.activations["textcat"].keys())) == 0
 | 
					    assert len(list(doc.activations["textcat"].keys())) == 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    textcat.store_activations = True
 | 
					    textcat.set_store_activations(True)
 | 
				
			||||||
    doc = nlp("This is a test.")
 | 
					    doc = nlp("This is a test.")
 | 
				
			||||||
    assert list(doc.activations["textcat"].keys()) == ["probs"]
 | 
					    assert list(doc.activations["textcat"].keys()) == ["probs"]
 | 
				
			||||||
    assert doc.activations["textcat"]["probs"].shape == (nO,)
 | 
					    assert doc.activations["textcat"]["probs"].shape == (nO,)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    textcat.store_activations = ["probs"]
 | 
					    textcat.set_store_activations(["probs"])
 | 
				
			||||||
    doc = nlp("This is a test.")
 | 
					    doc = nlp("This is a test.")
 | 
				
			||||||
    assert list(doc.activations["textcat"].keys()) == ["probs"]
 | 
					    assert list(doc.activations["textcat"].keys()) == ["probs"]
 | 
				
			||||||
    assert doc.activations["textcat"]["probs"].shape == (nO,)
 | 
					    assert doc.activations["textcat"]["probs"].shape == (nO,)
 | 
				
			||||||
| 
						 | 
					@ -913,12 +913,12 @@ def test_store_activations_multi():
 | 
				
			||||||
    doc = nlp("This is a test.")
 | 
					    doc = nlp("This is a test.")
 | 
				
			||||||
    assert len(list(doc.activations["textcat_multilabel"].keys())) == 0
 | 
					    assert len(list(doc.activations["textcat_multilabel"].keys())) == 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    textcat.store_activations = True
 | 
					    textcat.set_store_activations(True)
 | 
				
			||||||
    doc = nlp("This is a test.")
 | 
					    doc = nlp("This is a test.")
 | 
				
			||||||
    assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"]
 | 
					    assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"]
 | 
				
			||||||
    assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,)
 | 
					    assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    textcat.store_activations = ["probs"]
 | 
					    textcat.set_store_activations(["probs"])
 | 
				
			||||||
    doc = nlp("This is a test.")
 | 
					    doc = nlp("This is a test.")
 | 
				
			||||||
    assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"]
 | 
					    assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"]
 | 
				
			||||||
    assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,)
 | 
					    assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user