Store activations in Doc when store_activations is enabled

This change adds the new `activations` attribute to `Doc`. This
attribute can be used by trainable pipes to store their activations,
probabilities, and guesses for downstream users.

As an example, this change modifies the `tagger` and `senter` pipes to
add a `store_activations` option. When this option is enabled,
`set_annotations` stores the probabilities and guesses in `Doc.activations`
under the name of the pipe.
This commit is contained in:
Daniël de Kok 2022-06-22 09:58:29 +02:00
parent 0271306f16
commit b71c6043bc
8 changed files with 86 additions and 10 deletions

View File

@ -217,7 +217,7 @@ class Morphologizer(Tagger):
assert len(label_sample) > 0, Errors.E923.format(name=self.name) assert len(label_sample) > 0, Errors.E923.format(name=self.name)
self.model.initialize(X=doc_sample, Y=label_sample) self.model.initialize(X=doc_sample, Y=label_sample)
def set_annotations(self, docs, batch_tag_ids): def set_annotations(self, docs, scores_guesses):
"""Modify a batch of documents, using pre-computed scores. """Modify a batch of documents, using pre-computed scores.
docs (Iterable[Doc]): The documents to modify. docs (Iterable[Doc]): The documents to modify.
@ -225,6 +225,7 @@ class Morphologizer(Tagger):
DOCS: https://spacy.io/api/morphologizer#set_annotations DOCS: https://spacy.io/api/morphologizer#set_annotations
""" """
_, batch_tag_ids = scores_guesses
if isinstance(docs, Doc): if isinstance(docs, Doc):
docs = [docs] docs = [docs]
cdef Doc doc cdef Doc doc

View File

@ -38,11 +38,16 @@ DEFAULT_SENTER_MODEL = Config().from_str(default_model_config)["model"]
@Language.factory( @Language.factory(
"senter", "senter",
assigns=["token.is_sent_start"], assigns=["token.is_sent_start"],
default_config={"model": DEFAULT_SENTER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.senter_scorer.v1"}}, default_config={
"model": DEFAULT_SENTER_MODEL,
"overwrite": False,
"scorer": {"@scorers": "spacy.senter_scorer.v1"},
"store_activations": False
},
default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0}, default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
) )
def make_senter(nlp: Language, name: str, model: Model, overwrite: bool, scorer: Optional[Callable]): def make_senter(nlp: Language, name: str, model: Model, overwrite: bool, scorer: Optional[Callable], store_activations: bool):
return SentenceRecognizer(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer) return SentenceRecognizer(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, store_activations=store_activations)
def senter_score(examples, **kwargs): def senter_score(examples, **kwargs):
@ -72,6 +77,7 @@ class SentenceRecognizer(Tagger):
*, *,
overwrite=BACKWARD_OVERWRITE, overwrite=BACKWARD_OVERWRITE,
scorer=senter_score, scorer=senter_score,
store_activations=False,
): ):
"""Initialize a sentence recognizer. """Initialize a sentence recognizer.
@ -90,6 +96,7 @@ class SentenceRecognizer(Tagger):
self._rehearsal_model = None self._rehearsal_model = None
self.cfg = {"overwrite": overwrite} self.cfg = {"overwrite": overwrite}
self.scorer = scorer self.scorer = scorer
self.store_activations = store_activations
@property @property
def labels(self): def labels(self):
@ -107,7 +114,7 @@ class SentenceRecognizer(Tagger):
def label_data(self): def label_data(self):
return None return None
def set_annotations(self, docs, batch_tag_ids): def set_annotations(self, docs, scores_guesses):
"""Modify a batch of documents, using pre-computed scores. """Modify a batch of documents, using pre-computed scores.
docs (Iterable[Doc]): The documents to modify. docs (Iterable[Doc]): The documents to modify.
@ -115,11 +122,17 @@ class SentenceRecognizer(Tagger):
DOCS: https://spacy.io/api/sentencerecognizer#set_annotations DOCS: https://spacy.io/api/sentencerecognizer#set_annotations
""" """
_, batch_tag_ids = scores_guesses
if isinstance(docs, Doc): if isinstance(docs, Doc):
docs = [docs] docs = [docs]
cdef Doc doc cdef Doc doc
cdef bint overwrite = self.cfg["overwrite"] cdef bint overwrite = self.cfg["overwrite"]
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
if self.store_activations:
doc.activations[self.name] = {
"probs": scores_guesses[0][i],
"guesses": scores_guesses[1][i],
}
doc_tag_ids = batch_tag_ids[i] doc_tag_ids = batch_tag_ids[i]
if hasattr(doc_tag_ids, "get"): if hasattr(doc_tag_ids, "get"):
doc_tag_ids = doc_tag_ids.get() doc_tag_ids = doc_tag_ids.get()

View File

@ -45,7 +45,13 @@ DEFAULT_TAGGER_MODEL = Config().from_str(default_model_config)["model"]
@Language.factory( @Language.factory(
"tagger", "tagger",
assigns=["token.tag"], assigns=["token.tag"],
default_config={"model": DEFAULT_TAGGER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.tagger_scorer.v1"}, "neg_prefix": "!"}, default_config={
"model": DEFAULT_TAGGER_MODEL,
"overwrite": False,
"scorer": {"@scorers": "spacy.tagger_scorer.v1"},
"neg_prefix": "!",
"store_activations": False
},
default_score_weights={"tag_acc": 1.0}, default_score_weights={"tag_acc": 1.0},
) )
def make_tagger( def make_tagger(
@ -55,6 +61,7 @@ def make_tagger(
overwrite: bool, overwrite: bool,
scorer: Optional[Callable], scorer: Optional[Callable],
neg_prefix: str, neg_prefix: str,
store_activations: bool,
): ):
"""Construct a part-of-speech tagger component. """Construct a part-of-speech tagger component.
@ -63,7 +70,7 @@ def make_tagger(
in size, and be normalized as probabilities (all scores between 0 and 1, in size, and be normalized as probabilities (all scores between 0 and 1,
with the rows summing to 1). with the rows summing to 1).
""" """
return Tagger(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, neg_prefix=neg_prefix) return Tagger(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, neg_prefix=neg_prefix, store_activations=store_activations)
def tagger_score(examples, **kwargs): def tagger_score(examples, **kwargs):
@ -89,6 +96,7 @@ class Tagger(TrainablePipe):
overwrite=BACKWARD_OVERWRITE, overwrite=BACKWARD_OVERWRITE,
scorer=tagger_score, scorer=tagger_score,
neg_prefix="!", neg_prefix="!",
store_activations=False,
): ):
"""Initialize a part-of-speech tagger. """Initialize a part-of-speech tagger.
@ -108,6 +116,7 @@ class Tagger(TrainablePipe):
cfg = {"labels": [], "overwrite": overwrite, "neg_prefix": neg_prefix} cfg = {"labels": [], "overwrite": overwrite, "neg_prefix": neg_prefix}
self.cfg = dict(sorted(cfg.items())) self.cfg = dict(sorted(cfg.items()))
self.scorer = scorer self.scorer = scorer
self.store_activations = store_activations
@property @property
def labels(self): def labels(self):
@ -139,12 +148,12 @@ class Tagger(TrainablePipe):
n_labels = len(self.labels) n_labels = len(self.labels)
guesses = [self.model.ops.alloc((0, n_labels)) for doc in docs] guesses = [self.model.ops.alloc((0, n_labels)) for doc in docs]
assert len(guesses) == len(docs) assert len(guesses) == len(docs)
return guesses return guesses, guesses
scores = self.model.predict(docs) scores = self.model.predict(docs)
assert len(scores) == len(docs), (len(scores), len(docs)) assert len(scores) == len(docs), (len(scores), len(docs))
guesses = self._scores2guesses(scores) guesses = self._scores2guesses(scores)
assert len(guesses) == len(docs) assert len(guesses) == len(docs)
return guesses return scores, guesses
def _scores2guesses(self, scores): def _scores2guesses(self, scores):
guesses = [] guesses = []
@ -155,7 +164,7 @@ class Tagger(TrainablePipe):
guesses.append(doc_guesses) guesses.append(doc_guesses)
return guesses return guesses
def set_annotations(self, docs, batch_tag_ids): def set_annotations(self, docs, scores_guesses):
"""Modify a batch of documents, using pre-computed scores. """Modify a batch of documents, using pre-computed scores.
docs (Iterable[Doc]): The documents to modify. docs (Iterable[Doc]): The documents to modify.
@ -163,6 +172,7 @@ class Tagger(TrainablePipe):
DOCS: https://spacy.io/api/tagger#set_annotations DOCS: https://spacy.io/api/tagger#set_annotations
""" """
_, batch_tag_ids = scores_guesses
if isinstance(docs, Doc): if isinstance(docs, Doc):
docs = [docs] docs = [docs]
cdef Doc doc cdef Doc doc
@ -170,6 +180,11 @@ class Tagger(TrainablePipe):
cdef bint overwrite = self.cfg["overwrite"] cdef bint overwrite = self.cfg["overwrite"]
labels = self.labels labels = self.labels
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
if self.store_activations:
doc.activations[self.name] = {
"probs": scores_guesses[0][i],
"guesses": scores_guesses[1][i],
}
doc_tag_ids = batch_tag_ids[i] doc_tag_ids = batch_tag_ids[i]
if hasattr(doc_tag_ids, "get"): if hasattr(doc_tag_ids, "get"):
doc_tag_ids = doc_tag_ids.get() doc_tag_ids = doc_tag_ids.get()

View File

@ -6,3 +6,4 @@ cdef class TrainablePipe(Pipe):
cdef public object model cdef public object model
cdef public object cfg cdef public object cfg
cdef public object scorer cdef public object scorer
cdef public bint store_activations

View File

@ -1,3 +1,4 @@
from typing import cast
import pytest import pytest
from numpy.testing import assert_equal from numpy.testing import assert_equal
from spacy.attrs import SENT_START from spacy.attrs import SENT_START
@ -6,6 +7,7 @@ from spacy import util
from spacy.training import Example from spacy.training import Example
from spacy.lang.en import English from spacy.lang.en import English
from spacy.language import Language from spacy.language import Language
from spacy.pipeline import TrainablePipe
from spacy.tests.util import make_tempdir from spacy.tests.util import make_tempdir
@ -101,3 +103,23 @@ def test_overfitting_IO():
# test internal pipe labels vs. Language.pipe_labels with hidden labels # test internal pipe labels vs. Language.pipe_labels with hidden labels
assert nlp.get_pipe("senter").labels == ("I", "S") assert nlp.get_pipe("senter").labels == ("I", "S")
assert "senter" not in nlp.pipe_labels assert "senter" not in nlp.pipe_labels
def test_store_activations():
    """Check that the senter records its activations on the `Doc`.

    The pipeline is only initialized here, never trained, so the assertions
    look at which keys are stored and at the array shapes, not at the values.
    """
    nlp = English()
    senter = cast(TrainablePipe, nlp.add_pipe("senter"))

    train_examples = [
        Example.from_dict(nlp.make_doc(entry[0]), entry[1]) for entry in TRAIN_DATA
    ]
    nlp.initialize(get_examples=lambda: train_examples)

    # The pipe was added with its default config, so storing activations has
    # to be switched on explicitly before running the pipeline.
    senter.store_activations = True

    doc = nlp("This is a test.")

    assert "senter" in doc.activations
    stored = doc.activations["senter"]
    assert set(stored) == {"guesses", "probs"}
    # Expected: 5 rows for this text (presumably one per token) and 2 score
    # columns per row; the guesses hold a single value per row.
    assert stored["probs"].shape == (5, 2)
    assert stored["guesses"].shape == (5,)

View File

@ -1,3 +1,4 @@
from typing import cast
import pytest import pytest
from numpy.testing import assert_equal from numpy.testing import assert_equal
from spacy.attrs import TAG from spacy.attrs import TAG
@ -6,6 +7,7 @@ from spacy import util
from spacy.training import Example from spacy.training import Example
from spacy.lang.en import English from spacy.lang.en import English
from spacy.language import Language from spacy.language import Language
from spacy.pipeline import TrainablePipe
from thinc.api import compounding from thinc.api import compounding
from ..util import make_tempdir from ..util import make_tempdir
@ -211,6 +213,25 @@ def test_overfitting_IO():
assert doc3[0].tag_ != "N" assert doc3[0].tag_ != "N"
def test_store_activations():
    """Check that the tagger records its activations on the `Doc`.

    The pipeline is only initialized here, never trained, so the assertions
    look at which keys are stored and at the array shapes, not at the values.
    """
    nlp = English()
    tagger = cast(TrainablePipe, nlp.add_pipe("tagger"))

    train_examples = [
        Example.from_dict(nlp.make_doc(entry[0]), entry[1]) for entry in TRAIN_DATA
    ]
    nlp.initialize(get_examples=lambda: train_examples)

    # The pipe was added with its default config, so storing activations has
    # to be switched on explicitly before running the pipeline.
    tagger.store_activations = True

    doc = nlp("This is a test.")

    assert "tagger" in doc.activations
    stored = doc.activations["tagger"]
    assert set(stored) == {"guesses", "probs"}
    # Expected: 5 rows for this text (presumably one per token) and one score
    # column per entry in TAGS; the guesses hold a single value per row.
    assert stored["probs"].shape == (5, len(TAGS))
    assert stored["guesses"].shape == (5,)
def test_tagger_requires_labels(): def test_tagger_requires_labels():
nlp = English() nlp = English()
nlp.add_pipe("tagger") nlp.add_pipe("tagger")

View File

@ -50,6 +50,8 @@ cdef class Doc:
cdef public float sentiment cdef public float sentiment
cdef public dict activations
cdef public dict user_hooks cdef public dict user_hooks
cdef public dict user_token_hooks cdef public dict user_token_hooks
cdef public dict user_span_hooks cdef public dict user_span_hooks

View File

@ -245,6 +245,7 @@ cdef class Doc:
self.length = 0 self.length = 0
self.sentiment = 0.0 self.sentiment = 0.0
self.cats = {} self.cats = {}
self.activations = {}
self.user_hooks = {} self.user_hooks = {}
self.user_token_hooks = {} self.user_token_hooks = {}
self.user_span_hooks = {} self.user_span_hooks = {}