mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-01 04:46:38 +03:00
Store activations in Doc when store_activations
is enabled
This change adds the new `activations` attribute to `Doc`. This attribute can be used by trainable pipes to store their activations, probabilities, and guesses for downstream users. As an example, this change modifies the `tagger` and `senter` pipes to add an `store_activations` option. When this option is enabled, the probabilities and guesses are stored in `set_annotations`.
This commit is contained in:
parent
0271306f16
commit
b71c6043bc
|
@ -217,7 +217,7 @@ class Morphologizer(Tagger):
|
|||
assert len(label_sample) > 0, Errors.E923.format(name=self.name)
|
||||
self.model.initialize(X=doc_sample, Y=label_sample)
|
||||
|
||||
def set_annotations(self, docs, batch_tag_ids):
|
||||
def set_annotations(self, docs, scores_guesses):
|
||||
"""Modify a batch of documents, using pre-computed scores.
|
||||
|
||||
docs (Iterable[Doc]): The documents to modify.
|
||||
|
@ -225,6 +225,7 @@ class Morphologizer(Tagger):
|
|||
|
||||
DOCS: https://spacy.io/api/morphologizer#set_annotations
|
||||
"""
|
||||
_, batch_tag_ids = scores_guesses
|
||||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
cdef Doc doc
|
||||
|
|
|
@ -38,11 +38,16 @@ DEFAULT_SENTER_MODEL = Config().from_str(default_model_config)["model"]
|
|||
@Language.factory(
|
||||
"senter",
|
||||
assigns=["token.is_sent_start"],
|
||||
default_config={"model": DEFAULT_SENTER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.senter_scorer.v1"}},
|
||||
default_config={
|
||||
"model": DEFAULT_SENTER_MODEL,
|
||||
"overwrite": False,
|
||||
"scorer": {"@scorers": "spacy.senter_scorer.v1"},
|
||||
"store_activations": False
|
||||
},
|
||||
default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
|
||||
)
|
||||
def make_senter(nlp: Language, name: str, model: Model, overwrite: bool, scorer: Optional[Callable]):
|
||||
return SentenceRecognizer(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer)
|
||||
def make_senter(nlp: Language, name: str, model: Model, overwrite: bool, scorer: Optional[Callable], store_activations: bool):
|
||||
return SentenceRecognizer(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, store_activations=store_activations)
|
||||
|
||||
|
||||
def senter_score(examples, **kwargs):
|
||||
|
@ -72,6 +77,7 @@ class SentenceRecognizer(Tagger):
|
|||
*,
|
||||
overwrite=BACKWARD_OVERWRITE,
|
||||
scorer=senter_score,
|
||||
store_activations=False,
|
||||
):
|
||||
"""Initialize a sentence recognizer.
|
||||
|
||||
|
@ -90,6 +96,7 @@ class SentenceRecognizer(Tagger):
|
|||
self._rehearsal_model = None
|
||||
self.cfg = {"overwrite": overwrite}
|
||||
self.scorer = scorer
|
||||
self.store_activations = store_activations
|
||||
|
||||
@property
|
||||
def labels(self):
|
||||
|
@ -107,7 +114,7 @@ class SentenceRecognizer(Tagger):
|
|||
def label_data(self):
|
||||
return None
|
||||
|
||||
def set_annotations(self, docs, batch_tag_ids):
|
||||
def set_annotations(self, docs, scores_guesses):
|
||||
"""Modify a batch of documents, using pre-computed scores.
|
||||
|
||||
docs (Iterable[Doc]): The documents to modify.
|
||||
|
@ -115,11 +122,17 @@ class SentenceRecognizer(Tagger):
|
|||
|
||||
DOCS: https://spacy.io/api/sentencerecognizer#set_annotations
|
||||
"""
|
||||
_, batch_tag_ids = scores_guesses
|
||||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
cdef Doc doc
|
||||
cdef bint overwrite = self.cfg["overwrite"]
|
||||
for i, doc in enumerate(docs):
|
||||
if self.store_activations:
|
||||
doc.activations[self.name] = {
|
||||
"probs": scores_guesses[0][i],
|
||||
"guesses": scores_guesses[1][i],
|
||||
}
|
||||
doc_tag_ids = batch_tag_ids[i]
|
||||
if hasattr(doc_tag_ids, "get"):
|
||||
doc_tag_ids = doc_tag_ids.get()
|
||||
|
|
|
@ -45,7 +45,13 @@ DEFAULT_TAGGER_MODEL = Config().from_str(default_model_config)["model"]
|
|||
@Language.factory(
|
||||
"tagger",
|
||||
assigns=["token.tag"],
|
||||
default_config={"model": DEFAULT_TAGGER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.tagger_scorer.v1"}, "neg_prefix": "!"},
|
||||
default_config={
|
||||
"model": DEFAULT_TAGGER_MODEL,
|
||||
"overwrite": False,
|
||||
"scorer": {"@scorers": "spacy.tagger_scorer.v1"},
|
||||
"neg_prefix": "!",
|
||||
"store_activations": False
|
||||
},
|
||||
default_score_weights={"tag_acc": 1.0},
|
||||
)
|
||||
def make_tagger(
|
||||
|
@ -55,6 +61,7 @@ def make_tagger(
|
|||
overwrite: bool,
|
||||
scorer: Optional[Callable],
|
||||
neg_prefix: str,
|
||||
store_activations: bool,
|
||||
):
|
||||
"""Construct a part-of-speech tagger component.
|
||||
|
||||
|
@ -63,7 +70,7 @@ def make_tagger(
|
|||
in size, and be normalized as probabilities (all scores between 0 and 1,
|
||||
with the rows summing to 1).
|
||||
"""
|
||||
return Tagger(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, neg_prefix=neg_prefix)
|
||||
return Tagger(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, neg_prefix=neg_prefix, store_activations=store_activations)
|
||||
|
||||
|
||||
def tagger_score(examples, **kwargs):
|
||||
|
@ -89,6 +96,7 @@ class Tagger(TrainablePipe):
|
|||
overwrite=BACKWARD_OVERWRITE,
|
||||
scorer=tagger_score,
|
||||
neg_prefix="!",
|
||||
store_activations=False,
|
||||
):
|
||||
"""Initialize a part-of-speech tagger.
|
||||
|
||||
|
@ -108,6 +116,7 @@ class Tagger(TrainablePipe):
|
|||
cfg = {"labels": [], "overwrite": overwrite, "neg_prefix": neg_prefix}
|
||||
self.cfg = dict(sorted(cfg.items()))
|
||||
self.scorer = scorer
|
||||
self.store_activations = store_activations
|
||||
|
||||
@property
|
||||
def labels(self):
|
||||
|
@ -139,12 +148,12 @@ class Tagger(TrainablePipe):
|
|||
n_labels = len(self.labels)
|
||||
guesses = [self.model.ops.alloc((0, n_labels)) for doc in docs]
|
||||
assert len(guesses) == len(docs)
|
||||
return guesses
|
||||
return guesses, guesses
|
||||
scores = self.model.predict(docs)
|
||||
assert len(scores) == len(docs), (len(scores), len(docs))
|
||||
guesses = self._scores2guesses(scores)
|
||||
assert len(guesses) == len(docs)
|
||||
return guesses
|
||||
return scores, guesses
|
||||
|
||||
def _scores2guesses(self, scores):
|
||||
guesses = []
|
||||
|
@ -155,7 +164,7 @@ class Tagger(TrainablePipe):
|
|||
guesses.append(doc_guesses)
|
||||
return guesses
|
||||
|
||||
def set_annotations(self, docs, batch_tag_ids):
|
||||
def set_annotations(self, docs, scores_guesses):
|
||||
"""Modify a batch of documents, using pre-computed scores.
|
||||
|
||||
docs (Iterable[Doc]): The documents to modify.
|
||||
|
@ -163,6 +172,7 @@ class Tagger(TrainablePipe):
|
|||
|
||||
DOCS: https://spacy.io/api/tagger#set_annotations
|
||||
"""
|
||||
_, batch_tag_ids = scores_guesses
|
||||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
cdef Doc doc
|
||||
|
@ -170,6 +180,11 @@ class Tagger(TrainablePipe):
|
|||
cdef bint overwrite = self.cfg["overwrite"]
|
||||
labels = self.labels
|
||||
for i, doc in enumerate(docs):
|
||||
if self.store_activations:
|
||||
doc.activations[self.name] = {
|
||||
"probs": scores_guesses[0][i],
|
||||
"guesses": scores_guesses[1][i],
|
||||
}
|
||||
doc_tag_ids = batch_tag_ids[i]
|
||||
if hasattr(doc_tag_ids, "get"):
|
||||
doc_tag_ids = doc_tag_ids.get()
|
||||
|
|
|
@ -6,3 +6,4 @@ cdef class TrainablePipe(Pipe):
|
|||
cdef public object model
|
||||
cdef public object cfg
|
||||
cdef public object scorer
|
||||
cdef public bint store_activations
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from typing import cast
|
||||
import pytest
|
||||
from numpy.testing import assert_equal
|
||||
from spacy.attrs import SENT_START
|
||||
|
@ -6,6 +7,7 @@ from spacy import util
|
|||
from spacy.training import Example
|
||||
from spacy.lang.en import English
|
||||
from spacy.language import Language
|
||||
from spacy.pipeline import TrainablePipe
|
||||
from spacy.tests.util import make_tempdir
|
||||
|
||||
|
||||
|
@ -101,3 +103,23 @@ def test_overfitting_IO():
|
|||
# test internal pipe labels vs. Language.pipe_labels with hidden labels
|
||||
assert nlp.get_pipe("senter").labels == ("I", "S")
|
||||
assert "senter" not in nlp.pipe_labels
|
||||
|
||||
|
||||
def test_store_activations():
|
||||
# Simple test to try and quickly overfit the senter - ensuring the ML models work correctly
|
||||
nlp = English()
|
||||
senter = cast(TrainablePipe, nlp.add_pipe("senter"))
|
||||
|
||||
train_examples = []
|
||||
for t in TRAIN_DATA:
|
||||
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
|
||||
|
||||
nlp.initialize(get_examples=lambda: train_examples)
|
||||
|
||||
senter.store_activations = True
|
||||
|
||||
doc = nlp("This is a test.")
|
||||
assert "senter" in doc.activations
|
||||
assert set(doc.activations["senter"].keys()) == {"guesses", "probs"}
|
||||
assert doc.activations["senter"]["probs"].shape == (5, 2)
|
||||
assert doc.activations["senter"]["guesses"].shape == (5,)
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from typing import cast
|
||||
import pytest
|
||||
from numpy.testing import assert_equal
|
||||
from spacy.attrs import TAG
|
||||
|
@ -6,6 +7,7 @@ from spacy import util
|
|||
from spacy.training import Example
|
||||
from spacy.lang.en import English
|
||||
from spacy.language import Language
|
||||
from spacy.pipeline import TrainablePipe
|
||||
from thinc.api import compounding
|
||||
|
||||
from ..util import make_tempdir
|
||||
|
@ -211,6 +213,25 @@ def test_overfitting_IO():
|
|||
assert doc3[0].tag_ != "N"
|
||||
|
||||
|
||||
def test_store_activations():
|
||||
# Simple test to try and quickly overfit the tagger - ensuring the ML models work correctly
|
||||
nlp = English()
|
||||
tagger = cast(TrainablePipe, nlp.add_pipe("tagger"))
|
||||
train_examples = []
|
||||
for t in TRAIN_DATA:
|
||||
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
|
||||
nlp.initialize(get_examples=lambda: train_examples)
|
||||
|
||||
tagger.store_activations = True
|
||||
|
||||
doc = nlp("This is a test.")
|
||||
|
||||
assert "tagger" in doc.activations
|
||||
assert set(doc.activations["tagger"].keys()) == {"guesses", "probs"}
|
||||
assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS))
|
||||
assert doc.activations["tagger"]["guesses"].shape == (5,)
|
||||
|
||||
|
||||
def test_tagger_requires_labels():
|
||||
nlp = English()
|
||||
nlp.add_pipe("tagger")
|
||||
|
|
|
@ -50,6 +50,8 @@ cdef class Doc:
|
|||
|
||||
cdef public float sentiment
|
||||
|
||||
cdef public dict activations
|
||||
|
||||
cdef public dict user_hooks
|
||||
cdef public dict user_token_hooks
|
||||
cdef public dict user_span_hooks
|
||||
|
|
|
@ -245,6 +245,7 @@ cdef class Doc:
|
|||
self.length = 0
|
||||
self.sentiment = 0.0
|
||||
self.cats = {}
|
||||
self.activations = {}
|
||||
self.user_hooks = {}
|
||||
self.user_token_hooks = {}
|
||||
self.user_span_hooks = {}
|
||||
|
|
Loading…
Reference in New Issue
Block a user