Change type of store_activations to Union[bool, List[str]]

When the value is:

- A bool: all activations are stored when set to `True`.
- A List[str]: the activations named in the list are stored
This commit is contained in:
Daniël de Kok 2022-06-22 14:28:03 +02:00
parent b71c6043bc
commit c3da32b46b
8 changed files with 73 additions and 27 deletions

View File

@ -209,6 +209,7 @@ class Warnings(metaclass=ErrorsWithCodes):
"Only the last span group will be loaded under " "Only the last span group will be loaded under "
"Doc.spans['{group_name}']. Skipping span group with values: " "Doc.spans['{group_name}']. Skipping span group with values: "
"{group_values}") "{group_values}")
W121 = ("Activation '{activation}' is unknown for pipe '{pipe_name}'")
class Errors(metaclass=ErrorsWithCodes): class Errors(metaclass=ErrorsWithCodes):
@ -934,6 +935,7 @@ class Errors(metaclass=ErrorsWithCodes):
E1041 = ("Expected a string, Doc, or bytes as input, but got: {type}") E1041 = ("Expected a string, Doc, or bytes as input, but got: {type}")
E1042 = ("Function was called with `{arg1}`={arg1_values} and " E1042 = ("Function was called with `{arg1}`={arg1_values} and "
"`{arg2}`={arg2_values} but these arguments are conflicting.") "`{arg2}`={arg2_values} but these arguments are conflicting.")
E1043 = ("store_activations attribute must be set to List[str] or bool")
# Deprecated model shortcuts, only used in errors and warnings # Deprecated model shortcuts, only used in errors and warnings

View File

@ -217,7 +217,7 @@ class Morphologizer(Tagger):
assert len(label_sample) > 0, Errors.E923.format(name=self.name) assert len(label_sample) > 0, Errors.E923.format(name=self.name)
self.model.initialize(X=doc_sample, Y=label_sample) self.model.initialize(X=doc_sample, Y=label_sample)
def set_annotations(self, docs, scores_guesses): def set_annotations(self, docs, activations):
"""Modify a batch of documents, using pre-computed scores. """Modify a batch of documents, using pre-computed scores.
docs (Iterable[Doc]): The documents to modify. docs (Iterable[Doc]): The documents to modify.
@ -225,7 +225,7 @@ class Morphologizer(Tagger):
DOCS: https://spacy.io/api/morphologizer#set_annotations DOCS: https://spacy.io/api/morphologizer#set_annotations
""" """
_, batch_tag_ids = scores_guesses batch_tag_ids = activations["guesses"]
if isinstance(docs, Doc): if isinstance(docs, Doc):
docs = [docs] docs = [docs]
cdef Doc doc cdef Doc doc

View File

@ -1,5 +1,5 @@
# cython: infer_types=True, profile=True, binding=True # cython: infer_types=True, profile=True, binding=True
from typing import Optional, Callable from typing import Optional, Callable, List, Union
from itertools import islice from itertools import islice
import srsly import srsly
@ -46,7 +46,12 @@ DEFAULT_SENTER_MODEL = Config().from_str(default_model_config)["model"]
}, },
default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0}, default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
) )
def make_senter(nlp: Language, name: str, model: Model, overwrite: bool, scorer: Optional[Callable], store_activations: bool): def make_senter(nlp: Language,
name: str,
model: Model,
overwrite: bool,
scorer: Optional[Callable],
store_activations: Union[bool, List[str]]):
return SentenceRecognizer(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, store_activations=store_activations) return SentenceRecognizer(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, store_activations=store_activations)
@ -114,7 +119,7 @@ class SentenceRecognizer(Tagger):
def label_data(self): def label_data(self):
return None return None
def set_annotations(self, docs, scores_guesses): def set_annotations(self, docs, activations):
"""Modify a batch of documents, using pre-computed scores. """Modify a batch of documents, using pre-computed scores.
docs (Iterable[Doc]): The documents to modify. docs (Iterable[Doc]): The documents to modify.
@ -122,17 +127,15 @@ class SentenceRecognizer(Tagger):
DOCS: https://spacy.io/api/sentencerecognizer#set_annotations DOCS: https://spacy.io/api/sentencerecognizer#set_annotations
""" """
_, batch_tag_ids = scores_guesses batch_tag_ids = activations["guesses"]
if isinstance(docs, Doc): if isinstance(docs, Doc):
docs = [docs] docs = [docs]
cdef Doc doc cdef Doc doc
cdef bint overwrite = self.cfg["overwrite"] cdef bint overwrite = self.cfg["overwrite"]
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
if self.store_activations: doc.activations[self.name] = {}
doc.activations[self.name] = { for activation in self.store_activations:
"probs": scores_guesses[0][i], doc.activations[self.name][activation] = activations[activation][i]
"guesses": scores_guesses[1][i],
}
doc_tag_ids = batch_tag_ids[i] doc_tag_ids = batch_tag_ids[i]
if hasattr(doc_tag_ids, "get"): if hasattr(doc_tag_ids, "get"):
doc_tag_ids = doc_tag_ids.get() doc_tag_ids = doc_tag_ids.get()
@ -199,3 +202,7 @@ class SentenceRecognizer(Tagger):
def add_label(self, label, values=None): def add_label(self, label, values=None):
raise NotImplementedError raise NotImplementedError
@property
def activations(self):
return ["probs", "guesses"]

View File

@ -1,5 +1,5 @@
# cython: infer_types=True, profile=True, binding=True # cython: infer_types=True, profile=True, binding=True
from typing import Callable, Optional from typing import Callable, List, Optional, Union
import numpy import numpy
import srsly import srsly
from thinc.api import Model, set_dropout_rate, SequenceCategoricalCrossentropy, Config from thinc.api import Model, set_dropout_rate, SequenceCategoricalCrossentropy, Config
@ -61,7 +61,7 @@ def make_tagger(
overwrite: bool, overwrite: bool,
scorer: Optional[Callable], scorer: Optional[Callable],
neg_prefix: str, neg_prefix: str,
store_activations: bool, store_activations: Union[bool, List[str]],
): ):
"""Construct a part-of-speech tagger component. """Construct a part-of-speech tagger component.
@ -148,12 +148,12 @@ class Tagger(TrainablePipe):
n_labels = len(self.labels) n_labels = len(self.labels)
guesses = [self.model.ops.alloc((0, n_labels)) for doc in docs] guesses = [self.model.ops.alloc((0, n_labels)) for doc in docs]
assert len(guesses) == len(docs) assert len(guesses) == len(docs)
return guesses, guesses return {"probs": guesses, "guesses": guesses}
scores = self.model.predict(docs) scores = self.model.predict(docs)
assert len(scores) == len(docs), (len(scores), len(docs)) assert len(scores) == len(docs), (len(scores), len(docs))
guesses = self._scores2guesses(scores) guesses = self._scores2guesses(scores)
assert len(guesses) == len(docs) assert len(guesses) == len(docs)
return scores, guesses return {"probs": scores, "guesses": guesses}
def _scores2guesses(self, scores): def _scores2guesses(self, scores):
guesses = [] guesses = []
@ -164,7 +164,7 @@ class Tagger(TrainablePipe):
guesses.append(doc_guesses) guesses.append(doc_guesses)
return guesses return guesses
def set_annotations(self, docs, scores_guesses): def set_annotations(self, docs, activations):
"""Modify a batch of documents, using pre-computed scores. """Modify a batch of documents, using pre-computed scores.
docs (Iterable[Doc]): The documents to modify. docs (Iterable[Doc]): The documents to modify.
@ -172,7 +172,7 @@ class Tagger(TrainablePipe):
DOCS: https://spacy.io/api/tagger#set_annotations DOCS: https://spacy.io/api/tagger#set_annotations
""" """
_, batch_tag_ids = scores_guesses batch_tag_ids = activations["guesses"]
if isinstance(docs, Doc): if isinstance(docs, Doc):
docs = [docs] docs = [docs]
cdef Doc doc cdef Doc doc
@ -180,11 +180,9 @@ class Tagger(TrainablePipe):
cdef bint overwrite = self.cfg["overwrite"] cdef bint overwrite = self.cfg["overwrite"]
labels = self.labels labels = self.labels
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
if self.store_activations: doc.activations[self.name] = {}
doc.activations[self.name] = { for activation in self.store_activations:
"probs": scores_guesses[0][i], doc.activations[self.name][activation] = activations[activation][i]
"guesses": scores_guesses[1][i],
}
doc_tag_ids = batch_tag_ids[i] doc_tag_ids = batch_tag_ids[i]
if hasattr(doc_tag_ids, "get"): if hasattr(doc_tag_ids, "get"):
doc_tag_ids = doc_tag_ids.get() doc_tag_ids = doc_tag_ids.get()
@ -338,3 +336,7 @@ class Tagger(TrainablePipe):
self.cfg["labels"].append(label) self.cfg["labels"].append(label)
self.vocab.strings.add(label) self.vocab.strings.add(label)
return 1 return 1
@property
def activations(self):
return ["probs", "guesses"]

View File

@ -6,4 +6,4 @@ cdef class TrainablePipe(Pipe):
cdef public object model cdef public object model
cdef public object cfg cdef public object cfg
cdef public object scorer cdef public object scorer
cdef public bint store_activations cdef object _store_activations

View File

@ -2,11 +2,12 @@
from typing import Iterable, Iterator, Optional, Dict, Tuple, Callable from typing import Iterable, Iterator, Optional, Dict, Tuple, Callable
import srsly import srsly
from thinc.api import set_dropout_rate, Model, Optimizer from thinc.api import set_dropout_rate, Model, Optimizer
import warnings
from ..tokens.doc cimport Doc from ..tokens.doc cimport Doc
from ..training import validate_examples from ..training import validate_examples
from ..errors import Errors from ..errors import Errors, Warnings
from .pipe import Pipe, deserialize_config from .pipe import Pipe, deserialize_config
from .. import util from .. import util
from ..vocab import Vocab from ..vocab import Vocab
@ -342,3 +343,29 @@ cdef class TrainablePipe(Pipe):
deserialize["model"] = load_model deserialize["model"] = load_model
util.from_disk(path, deserialize, exclude) util.from_disk(path, deserialize, exclude)
return self return self
@property
def activations(self):
raise NotImplementedError(Errors.E931.format(parent="TrainablePipe", method="activations", name=self.name))
@property
def store_activations(self):
return self._store_activations
@store_activations.setter
def store_activations(self, activations):
known_activations = self.activations
if isinstance(activations, list):
self._store_activations = []
for activation in activations:
if activation in known_activations:
self._store_activations.append(activation)
else:
warnings.warn(Warnings.W121.format(activation=activation, pipe_name=self.name))
elif isinstance(activations, bool):
if activations:
self._store_activations = list(known_activations)
else:
self._store_activations = []
else:
raise ValueError(Errors.E1043)

View File

@ -117,9 +117,14 @@ def test_store_activations():
nlp.initialize(get_examples=lambda: train_examples) nlp.initialize(get_examples=lambda: train_examples)
senter.store_activations = True senter.store_activations = True
doc = nlp("This is a test.") doc = nlp("This is a test.")
assert "senter" in doc.activations assert "senter" in doc.activations
assert set(doc.activations["senter"].keys()) == {"guesses", "probs"} assert set(doc.activations["senter"].keys()) == {"guesses", "probs"}
assert doc.activations["senter"]["probs"].shape == (5, 2) assert doc.activations["senter"]["probs"].shape == (5, 2)
assert doc.activations["senter"]["guesses"].shape == (5,) assert doc.activations["senter"]["guesses"].shape == (5,)
senter.store_activations = ["probs"]
doc = nlp("This is a test.")
assert "senter" in doc.activations
assert set(doc.activations["senter"].keys()) == {"probs"}
assert doc.activations["senter"]["probs"].shape == (5, 2)

View File

@ -223,14 +223,17 @@ def test_store_activations():
nlp.initialize(get_examples=lambda: train_examples) nlp.initialize(get_examples=lambda: train_examples)
tagger.store_activations = True tagger.store_activations = True
doc = nlp("This is a test.") doc = nlp("This is a test.")
assert "tagger" in doc.activations assert "tagger" in doc.activations
assert set(doc.activations["tagger"].keys()) == {"guesses", "probs"} assert set(doc.activations["tagger"].keys()) == {"guesses", "probs"}
assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS)) assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS))
assert doc.activations["tagger"]["guesses"].shape == (5,) assert doc.activations["tagger"]["guesses"].shape == (5,)
tagger.store_activations = ["probs"]
doc = nlp("This is a test.")
assert set(doc.activations["tagger"].keys()) == {"probs"}
assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS))
def test_tagger_requires_labels(): def test_tagger_requires_labels():
nlp = English() nlp = English()