Make the TrainablePipe.store_activations property a bool

This means that we can also bring back `store_activations` setter.
This commit is contained in:
Daniël de Kok 2022-08-29 16:33:08 +02:00
parent 1cfbb934ed
commit aea53378dc
17 changed files with 81 additions and 144 deletions

View File

@ -65,7 +65,7 @@ def make_edit_tree_lemmatizer(
overwrite: bool, overwrite: bool,
top_k: int, top_k: int,
scorer: Optional[Callable], scorer: Optional[Callable],
store_activations: Union[bool, List[str]], store_activations: bool,
): ):
"""Construct an EditTreeLemmatizer component.""" """Construct an EditTreeLemmatizer component."""
return EditTreeLemmatizer( return EditTreeLemmatizer(
@ -97,7 +97,7 @@ class EditTreeLemmatizer(TrainablePipe):
overwrite: bool = False, overwrite: bool = False,
top_k: int = 1, top_k: int = 1,
scorer: Optional[Callable] = lemmatizer_score, scorer: Optional[Callable] = lemmatizer_score,
store_activations: Union[bool, List[str]] = False, store_activations: bool = False,
): ):
""" """
Construct an edit tree lemmatizer. Construct an edit tree lemmatizer.
@ -109,8 +109,7 @@ class EditTreeLemmatizer(TrainablePipe):
frequency in the training data. frequency in the training data.
overwrite (bool): overwrite existing lemma annotations. overwrite (bool): overwrite existing lemma annotations.
top_k (int): try to apply at most the k most probable edit trees. top_k (int): try to apply at most the k most probable edit trees.
store_activations (Union[bool, List[str]]): Model activations to store in store_activations (bool): store model activations in Doc when annotating.
Doc when annotating. supported activations are: "probs" and "guesses".
""" """
self.vocab = vocab self.vocab = vocab
self.model = model self.model = model
@ -125,7 +124,7 @@ class EditTreeLemmatizer(TrainablePipe):
self.cfg: Dict[str, Any] = {"labels": []} self.cfg: Dict[str, Any] = {"labels": []}
self.scorer = scorer self.scorer = scorer
self.set_store_activations(store_activations) self.store_activations = store_activations
def get_loss( def get_loss(
self, examples: Iterable[Example], scores: List[Floats2d] self, examples: Iterable[Example], scores: List[Floats2d]
@ -202,9 +201,10 @@ class EditTreeLemmatizer(TrainablePipe):
def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT): def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT):
batch_tree_ids = activations["guesses"] batch_tree_ids = activations["guesses"]
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
doc.activations[self.name] = {} if self.store_activations:
for activation in self.store_activations: doc.activations[self.name] = {}
doc.activations[self.name][activation] = activations[activation][i] for act_name, acts in activations.items():
doc.activations[self.name][act_name] = acts[i]
doc_tree_ids = batch_tree_ids[i] doc_tree_ids = batch_tree_ids[i]
if hasattr(doc_tree_ids, "get"): if hasattr(doc_tree_ids, "get"):
doc_tree_ids = doc_tree_ids.get() doc_tree_ids = doc_tree_ids.get()

View File

@ -85,7 +85,7 @@ def make_entity_linker(
scorer: Optional[Callable], scorer: Optional[Callable],
use_gold_ents: bool, use_gold_ents: bool,
threshold: Optional[float] = None, threshold: Optional[float] = None,
store_activations: Union[bool, List[str]], store_activations: bool,
): ):
"""Construct an EntityLinker component. """Construct an EntityLinker component.
@ -104,8 +104,7 @@ def make_entity_linker(
component must provide entity annotations. component must provide entity annotations.
threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the threshold, threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the threshold,
prediction is discarded. If None, predictions are not filtered by any threshold. prediction is discarded. If None, predictions are not filtered by any threshold.
store_activations (Union[bool, List[str]]): Model activations to store in store_activations (bool): store model activations in Doc when annotating.
Doc when annotating. supported activations are: "ents" and "scores".
""" """
if not model.attrs.get("include_span_maker", False): if not model.attrs.get("include_span_maker", False):
@ -174,7 +173,7 @@ class EntityLinker(TrainablePipe):
scorer: Optional[Callable] = entity_linker_score, scorer: Optional[Callable] = entity_linker_score,
use_gold_ents: bool, use_gold_ents: bool,
threshold: Optional[float] = None, threshold: Optional[float] = None,
store_activations: Union[bool, List[str]] = False, store_activations: bool = False,
) -> None: ) -> None:
"""Initialize an entity linker. """Initialize an entity linker.
@ -223,7 +222,7 @@ class EntityLinker(TrainablePipe):
self.scorer = scorer self.scorer = scorer
self.use_gold_ents = use_gold_ents self.use_gold_ents = use_gold_ents
self.threshold = threshold self.threshold = threshold
self.set_store_activations(store_activations) self.store_activations = store_activations
def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]): def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
"""Define the KB of this pipe by providing a function that will """Define the KB of this pipe by providing a function that will
@ -551,12 +550,13 @@ class EntityLinker(TrainablePipe):
i = 0 i = 0
overwrite = self.cfg["overwrite"] overwrite = self.cfg["overwrite"]
for j, doc in enumerate(docs): for j, doc in enumerate(docs):
doc.activations[self.name] = {} if self.store_activations:
for activation in self.store_activations: doc.activations[self.name] = {}
# We only copy activations that are Ragged. for act_name, acts in activations.items():
doc.activations[self.name][activation] = cast( if act_name != "kb_ids":
Ragged, activations[activation][j] # We only copy activations that are Ragged.
) doc.activations[self.name][act_name] = cast(Ragged, acts[j])
for ent in doc.ents: for ent in doc.ents:
kb_id = kb_ids[i] kb_id = kb_ids[i]
i += 1 i += 1
@ -668,7 +668,7 @@ class EntityLinker(TrainablePipe):
doc_scores: List[Floats1d], doc_scores: List[Floats1d],
doc_ents: List[Ints1d], doc_ents: List[Ints1d],
): ):
if len(self.store_activations) == 0: if not self.store_activations:
return return
ops = self.model.ops ops = self.model.ops
lengths = ops.asarray1i([s.shape[0] for s in doc_scores]) lengths = ops.asarray1i([s.shape[0] for s in doc_scores])
@ -683,7 +683,7 @@ class EntityLinker(TrainablePipe):
scores: Sequence[float], scores: Sequence[float],
ents: Sequence[int], ents: Sequence[int],
): ):
if len(self.store_activations) == 0: if not self.store_activations:
return return
ops = self.model.ops ops = self.model.ops
doc_scores.append(ops.asarray1f(scores)) doc_scores.append(ops.asarray1f(scores))

View File

@ -69,7 +69,7 @@ def make_morphologizer(
overwrite: bool, overwrite: bool,
extend: bool, extend: bool,
scorer: Optional[Callable], scorer: Optional[Callable],
store_activations: Union[bool, List[str]], store_activations: bool,
): ):
return Morphologizer(nlp.vocab, model, name, overwrite=overwrite, extend=extend, scorer=scorer, return Morphologizer(nlp.vocab, model, name, overwrite=overwrite, extend=extend, scorer=scorer,
store_activations=store_activations) store_activations=store_activations)
@ -104,7 +104,7 @@ class Morphologizer(Tagger):
overwrite: bool = BACKWARD_OVERWRITE, overwrite: bool = BACKWARD_OVERWRITE,
extend: bool = BACKWARD_EXTEND, extend: bool = BACKWARD_EXTEND,
scorer: Optional[Callable] = morphologizer_score, scorer: Optional[Callable] = morphologizer_score,
store_activations: Union[bool, List[str]] = False, store_activations: bool = False,
): ):
"""Initialize a morphologizer. """Initialize a morphologizer.
@ -115,8 +115,7 @@ class Morphologizer(Tagger):
scorer (Optional[Callable]): The scoring method. Defaults to scorer (Optional[Callable]): The scoring method. Defaults to
Scorer.score_token_attr for the attributes "pos" and "morph" and Scorer.score_token_attr for the attributes "pos" and "morph" and
Scorer.score_token_attr_per_feat for the attribute "morph". Scorer.score_token_attr_per_feat for the attribute "morph".
store_activations (Union[bool, List[str]]): Model activations to store in store_activations (bool): store model activations in Doc when annotating.
Doc when annotating. supported activations are: "probs" and "guesses".
DOCS: https://spacy.io/api/morphologizer#init DOCS: https://spacy.io/api/morphologizer#init
""" """
@ -136,7 +135,7 @@ class Morphologizer(Tagger):
} }
self.cfg = dict(sorted(cfg.items())) self.cfg = dict(sorted(cfg.items()))
self.scorer = scorer self.scorer = scorer
self.set_store_activations(store_activations) self.store_activations = store_activations
@property @property
def labels(self): def labels(self):
@ -250,9 +249,10 @@ class Morphologizer(Tagger):
# to allocate a compatible container out of the iterable. # to allocate a compatible container out of the iterable.
labels = tuple(self.labels) labels = tuple(self.labels)
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
doc.activations[self.name] = {} if self.store_activations:
for activation in self.store_activations: doc.activations[self.name] = {}
doc.activations[self.name][activation] = activations[activation][i] for act_name, acts in activations.items():
doc.activations[self.name][act_name] = acts[i]
doc_tag_ids = batch_tag_ids[i] doc_tag_ids = batch_tag_ids[i]
if hasattr(doc_tag_ids, "get"): if hasattr(doc_tag_ids, "get"):
doc_tag_ids = doc_tag_ids.get() doc_tag_ids = doc_tag_ids.get()

View File

@ -52,7 +52,7 @@ def make_senter(nlp: Language,
model: Model, model: Model,
overwrite: bool, overwrite: bool,
scorer: Optional[Callable], scorer: Optional[Callable],
store_activations: Union[bool, List[str]]): store_activations: bool):
return SentenceRecognizer(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, store_activations=store_activations) return SentenceRecognizer(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, store_activations=store_activations)
@ -83,7 +83,7 @@ class SentenceRecognizer(Tagger):
*, *,
overwrite=BACKWARD_OVERWRITE, overwrite=BACKWARD_OVERWRITE,
scorer=senter_score, scorer=senter_score,
store_activations: Union[bool, List[str]] = False, store_activations: bool = False,
): ):
"""Initialize a sentence recognizer. """Initialize a sentence recognizer.
@ -93,8 +93,7 @@ class SentenceRecognizer(Tagger):
losses during training. losses during training.
scorer (Optional[Callable]): The scoring method. Defaults to scorer (Optional[Callable]): The scoring method. Defaults to
Scorer.score_spans for the attribute "sents". Scorer.score_spans for the attribute "sents".
store_activations (Union[bool, List[str]]): Model activations to store in store_activations (bool): store model activations in Doc when annotating.
Doc when annotating. supported activations are: "probs" and "guesses".
DOCS: https://spacy.io/api/sentencerecognizer#init DOCS: https://spacy.io/api/sentencerecognizer#init
""" """
@ -104,7 +103,7 @@ class SentenceRecognizer(Tagger):
self._rehearsal_model = None self._rehearsal_model = None
self.cfg = {"overwrite": overwrite} self.cfg = {"overwrite": overwrite}
self.scorer = scorer self.scorer = scorer
self.set_store_activations(store_activations) self.store_activations = store_activations
@property @property
def labels(self): def labels(self):
@ -136,9 +135,10 @@ class SentenceRecognizer(Tagger):
cdef Doc doc cdef Doc doc
cdef bint overwrite = self.cfg["overwrite"] cdef bint overwrite = self.cfg["overwrite"]
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
doc.activations[self.name] = {} if self.store_activations:
for activation in self.store_activations: doc.activations[self.name] = {}
doc.activations[self.name][activation] = activations[activation][i] for act_name, acts in activations.items():
doc.activations[self.name][act_name] = acts[i]
doc_tag_ids = batch_tag_ids[i] doc_tag_ids = batch_tag_ids[i]
if hasattr(doc_tag_ids, "get"): if hasattr(doc_tag_ids, "get"):
doc_tag_ids = doc_tag_ids.get() doc_tag_ids = doc_tag_ids.get()

View File

@ -120,7 +120,7 @@ def make_spancat(
scorer: Optional[Callable], scorer: Optional[Callable],
threshold: float, threshold: float,
max_positive: Optional[int], max_positive: Optional[int],
store_activations: Union[bool, List[str]], store_activations: bool,
) -> "SpanCategorizer": ) -> "SpanCategorizer":
"""Create a SpanCategorizer component. The span categorizer consists of two """Create a SpanCategorizer component. The span categorizer consists of two
parts: a suggester function that proposes candidate spans, and a labeller parts: a suggester function that proposes candidate spans, and a labeller
@ -141,8 +141,7 @@ def make_spancat(
0.5. 0.5.
max_positive (Optional[int]): Maximum number of labels to consider positive max_positive (Optional[int]): Maximum number of labels to consider positive
per span. Defaults to None, indicating no limit. per span. Defaults to None, indicating no limit.
store_activations (Union[bool, List[str]]): Model activations to store in store_activations (bool): store model activations in Doc when annotating.
Doc when annotating. supported activations are: "indices" and "scores".
""" """
return SpanCategorizer( return SpanCategorizer(
nlp.vocab, nlp.vocab,
@ -192,7 +191,7 @@ class SpanCategorizer(TrainablePipe):
threshold: float = 0.5, threshold: float = 0.5,
max_positive: Optional[int] = None, max_positive: Optional[int] = None,
scorer: Optional[Callable] = spancat_score, scorer: Optional[Callable] = spancat_score,
store_activations: Union[bool, List[str]] = False, store_activations: bool = False,
) -> None: ) -> None:
"""Initialize the span categorizer. """Initialize the span categorizer.
vocab (Vocab): The shared vocabulary. vocab (Vocab): The shared vocabulary.
@ -225,7 +224,7 @@ class SpanCategorizer(TrainablePipe):
self.model = model self.model = model
self.name = name self.name = name
self.scorer = scorer self.scorer = scorer
self.set_store_activations(store_activations) self.store_activations = store_activations
@property @property
def key(self) -> str: def key(self) -> str:
@ -317,10 +316,9 @@ class SpanCategorizer(TrainablePipe):
offset = 0 offset = 0
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
indices_i = indices[i].dataXd indices_i = indices[i].dataXd
doc.activations[self.name] = {} if self.store_activations:
if "indices" in self.store_activations: doc.activations[self.name] = {}
doc.activations[self.name]["indices"] = indices_i doc.activations[self.name]["indices"] = indices_i
if "scores" in self.store_activations:
doc.activations[self.name]["scores"] = scores[ doc.activations[self.name]["scores"] = scores[
offset : offset + indices.lengths[i] offset : offset + indices.lengths[i]
] ]

View File

@ -61,7 +61,7 @@ def make_tagger(
overwrite: bool, overwrite: bool,
scorer: Optional[Callable], scorer: Optional[Callable],
neg_prefix: str, neg_prefix: str,
store_activations: Union[bool, List[str]], store_activations: bool,
): ):
"""Construct a part-of-speech tagger component. """Construct a part-of-speech tagger component.
@ -97,7 +97,7 @@ class Tagger(TrainablePipe):
overwrite=BACKWARD_OVERWRITE, overwrite=BACKWARD_OVERWRITE,
scorer=tagger_score, scorer=tagger_score,
neg_prefix="!", neg_prefix="!",
store_activations: Union[bool, List[str]] = False, store_activations: bool = False,
): ):
"""Initialize a part-of-speech tagger. """Initialize a part-of-speech tagger.
@ -107,8 +107,7 @@ class Tagger(TrainablePipe):
losses during training. losses during training.
scorer (Optional[Callable]): The scoring method. Defaults to scorer (Optional[Callable]): The scoring method. Defaults to
Scorer.score_token_attr for the attribute "tag". Scorer.score_token_attr for the attribute "tag".
store_activations (Union[bool, List[str]]): Model activations to store in store_activations (bool): store model activations in Doc when annotating.
Doc when annotating. supported activations are: "probs" and "guesses".
DOCS: https://spacy.io/api/tagger#init DOCS: https://spacy.io/api/tagger#init
""" """
@ -119,7 +118,7 @@ class Tagger(TrainablePipe):
cfg = {"labels": [], "overwrite": overwrite, "neg_prefix": neg_prefix} cfg = {"labels": [], "overwrite": overwrite, "neg_prefix": neg_prefix}
self.cfg = dict(sorted(cfg.items())) self.cfg = dict(sorted(cfg.items()))
self.scorer = scorer self.scorer = scorer
self.set_store_activations(store_activations) self.store_activations = store_activations
@property @property
def labels(self): def labels(self):
@ -183,9 +182,10 @@ class Tagger(TrainablePipe):
cdef bint overwrite = self.cfg["overwrite"] cdef bint overwrite = self.cfg["overwrite"]
labels = self.labels labels = self.labels
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
doc.activations[self.name] = {} if self.store_activations:
for activation in self.store_activations: doc.activations[self.name] = {}
doc.activations[self.name][activation] = activations[activation][i] for act_name, acts in activations.items():
doc.activations[self.name][act_name] = acts[i]
doc_tag_ids = batch_tag_ids[i] doc_tag_ids = batch_tag_ids[i]
if hasattr(doc_tag_ids, "get"): if hasattr(doc_tag_ids, "get"):
doc_tag_ids = doc_tag_ids.get() doc_tag_ids = doc_tag_ids.get()

View File

@ -97,7 +97,7 @@ def make_textcat(
model: Model[List[Doc], List[Floats2d]], model: Model[List[Doc], List[Floats2d]],
threshold: float, threshold: float,
scorer: Optional[Callable], scorer: Optional[Callable],
store_activations: Union[bool, List[str]], store_activations: bool,
) -> "TextCategorizer": ) -> "TextCategorizer":
"""Create a TextCategorizer component. The text categorizer predicts categories """Create a TextCategorizer component. The text categorizer predicts categories
over a whole document. It can learn one or more labels, and the labels are considered over a whole document. It can learn one or more labels, and the labels are considered
@ -107,8 +107,7 @@ def make_textcat(
scores for each category. scores for each category.
threshold (float): Cutoff to consider a prediction "positive". threshold (float): Cutoff to consider a prediction "positive".
scorer (Optional[Callable]): The scoring method. scorer (Optional[Callable]): The scoring method.
store_activations (Union[bool, List[str]]): Model activations to store in store_activations (bool): store model activations in Doc when annotating.
Doc when annotating. supported activations is: "probs".
""" """
return TextCategorizer( return TextCategorizer(
nlp.vocab, nlp.vocab,
@ -148,7 +147,7 @@ class TextCategorizer(TrainablePipe):
*, *,
threshold: float, threshold: float,
scorer: Optional[Callable] = textcat_score, scorer: Optional[Callable] = textcat_score,
store_activations: Union[bool, List[str]] = False, store_activations: bool = False,
) -> None: ) -> None:
"""Initialize a text categorizer for single-label classification. """Initialize a text categorizer for single-label classification.
@ -169,7 +168,7 @@ class TextCategorizer(TrainablePipe):
cfg = {"labels": [], "threshold": threshold, "positive_label": None} cfg = {"labels": [], "threshold": threshold, "positive_label": None}
self.cfg = dict(cfg) self.cfg = dict(cfg)
self.scorer = scorer self.scorer = scorer
self.set_store_activations(store_activations) self.store_activations = store_activations
@property @property
def support_missing_values(self): def support_missing_values(self):
@ -224,8 +223,8 @@ class TextCategorizer(TrainablePipe):
""" """
probs = activations["probs"] probs = activations["probs"]
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
doc.activations[self.name] = {} if self.store_activations:
if "probs" in self.store_activations: doc.activations[self.name] = {}
doc.activations[self.name]["probs"] = probs[i] doc.activations[self.name]["probs"] = probs[i]
for j, label in enumerate(self.labels): for j, label in enumerate(self.labels):
doc.cats[label] = float(probs[i, j]) doc.cats[label] = float(probs[i, j])

View File

@ -97,7 +97,7 @@ def make_multilabel_textcat(
model: Model[List[Doc], List[Floats2d]], model: Model[List[Doc], List[Floats2d]],
threshold: float, threshold: float,
scorer: Optional[Callable], scorer: Optional[Callable],
store_activations: Union[bool, List[str]], store_activations: bool,
) -> "TextCategorizer": ) -> "TextCategorizer":
"""Create a TextCategorizer component. The text categorizer predicts categories """Create a TextCategorizer component. The text categorizer predicts categories
over a whole document. It can learn one or more labels, and the labels are considered over a whole document. It can learn one or more labels, and the labels are considered
@ -146,7 +146,7 @@ class MultiLabel_TextCategorizer(TextCategorizer):
*, *,
threshold: float, threshold: float,
scorer: Optional[Callable] = textcat_multilabel_score, scorer: Optional[Callable] = textcat_multilabel_score,
store_activations: Union[bool, List[str]] = False, store_activations: bool = False,
) -> None: ) -> None:
"""Initialize a text categorizer for multi-label classification. """Initialize a text categorizer for multi-label classification.
@ -155,8 +155,7 @@ class MultiLabel_TextCategorizer(TextCategorizer):
name (str): The component instance name, used to add entries to the name (str): The component instance name, used to add entries to the
losses during training. losses during training.
threshold (float): Cutoff to consider a prediction "positive". threshold (float): Cutoff to consider a prediction "positive".
store_activations (Union[bool, List[str]]): Model activations to store in store_activations (bool): store model activations in Doc when annotating.
Doc when annotating. supported activations is: "probs".
DOCS: https://spacy.io/api/textcategorizer#init DOCS: https://spacy.io/api/textcategorizer#init
""" """
@ -167,7 +166,7 @@ class MultiLabel_TextCategorizer(TextCategorizer):
cfg = {"labels": [], "threshold": threshold} cfg = {"labels": [], "threshold": threshold}
self.cfg = dict(cfg) self.cfg = dict(cfg)
self.scorer = scorer self.scorer = scorer
self.set_store_activations(store_activations) self.store_activations = store_activations
@property @property
def support_missing_values(self): def support_missing_values(self):

View File

@ -6,4 +6,4 @@ cdef class TrainablePipe(Pipe):
cdef public object model cdef public object model
cdef public object cfg cdef public object cfg
cdef public object scorer cdef public object scorer
cdef object _store_activations cdef bint _store_activations

View File

@ -352,19 +352,6 @@ cdef class TrainablePipe(Pipe):
def store_activations(self): def store_activations(self):
return self._store_activations return self._store_activations
def set_store_activations(self, activations): @store_activations.setter
known_activations = self.activations def store_activations(self, store_activations: bool):
if isinstance(activations, list): self._store_activations = store_activations
self._store_activations = []
for activation in activations:
if activation in known_activations:
self._store_activations.append(activation)
else:
warnings.warn(Warnings.W400.format(activation=activation, pipe_name=self.name))
elif isinstance(activations, bool):
if activations:
self._store_activations = list(known_activations)
else:
self._store_activations = []
else:
raise ValueError(Errors.E1400)

View File

@ -293,15 +293,10 @@ def test_store_activations():
nO = lemmatizer.model.get_dim("nO") nO = lemmatizer.model.get_dim("nO")
doc = nlp("This is a test.") doc = nlp("This is a test.")
assert len(list(doc.activations["trainable_lemmatizer"].keys())) == 0 assert "trainable_lemmatizer" not in doc.activations
lemmatizer.set_store_activations(True) lemmatizer.store_activations = True
doc = nlp("This is a test.") doc = nlp("This is a test.")
assert list(doc.activations["trainable_lemmatizer"].keys()) == ["probs", "guesses"] assert list(doc.activations["trainable_lemmatizer"].keys()) == ["probs", "guesses"]
assert doc.activations["trainable_lemmatizer"]["probs"].shape == (5, nO) assert doc.activations["trainable_lemmatizer"]["probs"].shape == (5, nO)
assert doc.activations["trainable_lemmatizer"]["guesses"].shape == (5,) assert doc.activations["trainable_lemmatizer"]["guesses"].shape == (5,)
lemmatizer.set_store_activations(["probs"])
doc = nlp("This is a test.")
assert list(doc.activations["trainable_lemmatizer"].keys()) == ["probs"]
assert doc.activations["trainable_lemmatizer"]["probs"].shape == (5, nO)

View File

@ -1225,9 +1225,9 @@ def test_store_activations():
ruler.add_patterns(patterns) ruler.add_patterns(patterns)
doc = nlp("Russ Cochran was a publisher") doc = nlp("Russ Cochran was a publisher")
assert len(doc.activations["entity_linker"].keys()) == 0 assert "entity_linker" not in doc.activations
entity_linker.set_store_activations(True) entity_linker.store_activations = True
doc = nlp("Russ Cochran was a publisher") doc = nlp("Russ Cochran was a publisher")
assert set(doc.activations["entity_linker"].keys()) == {"ents", "scores"} assert set(doc.activations["entity_linker"].keys()) == {"ents", "scores"}
ents = doc.activations["entity_linker"]["ents"] ents = doc.activations["entity_linker"]["ents"]
@ -1240,12 +1240,3 @@ def test_store_activations():
assert scores.data.shape == (2, 1) assert scores.data.shape == (2, 1)
assert scores.data.dtype == "float32" assert scores.data.dtype == "float32"
assert scores.lengths.shape == (1,) assert scores.lengths.shape == (1,)
entity_linker.set_store_activations(["scores"])
doc = nlp("Russ Cochran was a publisher")
assert set(doc.activations["entity_linker"].keys()) == {"scores"}
scores = doc.activations["entity_linker"]["scores"]
assert isinstance(scores, Ragged)
assert scores.data.shape == (2, 1)
assert scores.data.dtype == "float32"
assert scores.lengths.shape == (1,)

View File

@ -211,17 +211,11 @@ def test_store_activations():
nlp.initialize(get_examples=lambda: train_examples) nlp.initialize(get_examples=lambda: train_examples)
doc = nlp("This is a test.") doc = nlp("This is a test.")
assert len(list(doc.activations["morphologizer"].keys())) == 0 assert "morphologizer" not in doc.activations
morphologizer.set_store_activations(True) morphologizer.store_activations = True
doc = nlp("This is a test.") doc = nlp("This is a test.")
assert "morphologizer" in doc.activations assert "morphologizer" in doc.activations
assert set(doc.activations["morphologizer"].keys()) == {"guesses", "probs"} assert set(doc.activations["morphologizer"].keys()) == {"guesses", "probs"}
assert doc.activations["morphologizer"]["probs"].shape == (5, 6) assert doc.activations["morphologizer"]["probs"].shape == (5, 6)
assert doc.activations["morphologizer"]["guesses"].shape == (5,) assert doc.activations["morphologizer"]["guesses"].shape == (5,)
morphologizer.set_store_activations(["probs"])
doc = nlp("This is a test.")
assert "morphologizer" in doc.activations
assert set(doc.activations["morphologizer"].keys()) == {"probs"}
assert doc.activations["morphologizer"]["probs"].shape == (5, 6)

View File

@ -118,17 +118,11 @@ def test_store_activations():
nO = senter.model.get_dim("nO") nO = senter.model.get_dim("nO")
doc = nlp("This is a test.") doc = nlp("This is a test.")
assert len(list(doc.activations["senter"].keys())) == 0 assert "senter" not in doc.activations
senter.set_store_activations(True) senter.store_activations = True
doc = nlp("This is a test.") doc = nlp("This is a test.")
assert "senter" in doc.activations assert "senter" in doc.activations
assert set(doc.activations["senter"].keys()) == {"guesses", "probs"} assert set(doc.activations["senter"].keys()) == {"guesses", "probs"}
assert doc.activations["senter"]["probs"].shape == (5, nO) assert doc.activations["senter"]["probs"].shape == (5, nO)
assert doc.activations["senter"]["guesses"].shape == (5,) assert doc.activations["senter"]["guesses"].shape == (5,)
senter.set_store_activations(["probs"])
doc = nlp("This is a test.")
assert "senter" in doc.activations
assert set(doc.activations["senter"].keys()) == {"probs"}
assert doc.activations["senter"]["probs"].shape == (5, 2)

View File

@ -432,15 +432,10 @@ def test_store_activations():
assert set(spancat.labels) == {"LOC", "PERSON"} assert set(spancat.labels) == {"LOC", "PERSON"}
doc = nlp("This is a test.") doc = nlp("This is a test.")
assert len(list(doc.activations["spancat"].keys())) == 0 assert "spancat" not in doc.activations
spancat.set_store_activations(True) spancat.store_activations = True
doc = nlp("This is a test.") doc = nlp("This is a test.")
assert set(doc.activations["spancat"].keys()) == {"indices", "scores"} assert set(doc.activations["spancat"].keys()) == {"indices", "scores"}
assert doc.activations["spancat"]["indices"].shape == (12, 2) assert doc.activations["spancat"]["indices"].shape == (12, 2)
assert doc.activations["spancat"]["scores"].shape == (12, nO) assert doc.activations["spancat"]["scores"].shape == (12, nO)
spancat.set_store_activations(["scores"])
doc = nlp("This is a test.")
assert set(doc.activations["spancat"].keys()) == {"scores"}
assert doc.activations["spancat"]["scores"].shape == (12, nO)

View File

@ -223,20 +223,15 @@ def test_store_activations():
nlp.initialize(get_examples=lambda: train_examples) nlp.initialize(get_examples=lambda: train_examples)
doc = nlp("This is a test.") doc = nlp("This is a test.")
assert len(list(doc.activations["tagger"].keys())) == 0 assert "tagger" not in doc.activations
tagger.set_store_activations(True) tagger.store_activations = True
doc = nlp("This is a test.") doc = nlp("This is a test.")
assert "tagger" in doc.activations assert "tagger" in doc.activations
assert set(doc.activations["tagger"].keys()) == {"guesses", "probs"} assert set(doc.activations["tagger"].keys()) == {"guesses", "probs"}
assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS)) assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS))
assert doc.activations["tagger"]["guesses"].shape == (5,) assert doc.activations["tagger"]["guesses"].shape == (5,)
tagger.set_store_activations(["probs"])
doc = nlp("This is a test.")
assert set(doc.activations["tagger"].keys()) == {"probs"}
assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS))
def test_tagger_requires_labels(): def test_tagger_requires_labels():
nlp = English() nlp = English()

View File

@ -886,14 +886,9 @@ def test_store_activations():
nO = textcat.model.get_dim("nO") nO = textcat.model.get_dim("nO")
doc = nlp("This is a test.") doc = nlp("This is a test.")
assert len(list(doc.activations["textcat"].keys())) == 0 assert "textcat" not in doc.activations
textcat.set_store_activations(True) textcat.store_activations = True
doc = nlp("This is a test.")
assert list(doc.activations["textcat"].keys()) == ["probs"]
assert doc.activations["textcat"]["probs"].shape == (nO,)
textcat.set_store_activations(["probs"])
doc = nlp("This is a test.") doc = nlp("This is a test.")
assert list(doc.activations["textcat"].keys()) == ["probs"] assert list(doc.activations["textcat"].keys()) == ["probs"]
assert doc.activations["textcat"]["probs"].shape == (nO,) assert doc.activations["textcat"]["probs"].shape == (nO,)
@ -911,14 +906,9 @@ def test_store_activations_multi():
nO = textcat.model.get_dim("nO") nO = textcat.model.get_dim("nO")
doc = nlp("This is a test.") doc = nlp("This is a test.")
assert len(list(doc.activations["textcat_multilabel"].keys())) == 0 assert "textcat_multilabel" not in doc.activations
textcat.set_store_activations(True) textcat.store_activations = True
doc = nlp("This is a test.")
assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"]
assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,)
textcat.set_store_activations(["probs"])
doc = nlp("This is a test.") doc = nlp("This is a test.")
assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"] assert list(doc.activations["textcat_multilabel"].keys()) == ["probs"]
assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,) assert doc.activations["textcat_multilabel"]["probs"].shape == (nO,)