mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-04 06:16:33 +03:00
Store activations in Doc when `store_activations` is enabled
This change adds the new `activations` attribute to `Doc`. This attribute can be used by trainable pipes to store their activations, probabilities, and guesses for downstream users. As an example, this change modifies the `tagger` and `senter` pipes to add a `store_activations` option. When this option is enabled, `set_annotations` stores the probabilities and guesses for each document in `Doc.activations`, keyed by the pipe's name.
This commit is contained in:
parent
0271306f16
commit
b71c6043bc
|
@ -217,7 +217,7 @@ class Morphologizer(Tagger):
|
||||||
assert len(label_sample) > 0, Errors.E923.format(name=self.name)
|
assert len(label_sample) > 0, Errors.E923.format(name=self.name)
|
||||||
self.model.initialize(X=doc_sample, Y=label_sample)
|
self.model.initialize(X=doc_sample, Y=label_sample)
|
||||||
|
|
||||||
def set_annotations(self, docs, batch_tag_ids):
|
def set_annotations(self, docs, scores_guesses):
|
||||||
"""Modify a batch of documents, using pre-computed scores.
|
"""Modify a batch of documents, using pre-computed scores.
|
||||||
|
|
||||||
docs (Iterable[Doc]): The documents to modify.
|
docs (Iterable[Doc]): The documents to modify.
|
||||||
|
@ -225,6 +225,7 @@ class Morphologizer(Tagger):
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/morphologizer#set_annotations
|
DOCS: https://spacy.io/api/morphologizer#set_annotations
|
||||||
"""
|
"""
|
||||||
|
_, batch_tag_ids = scores_guesses
|
||||||
if isinstance(docs, Doc):
|
if isinstance(docs, Doc):
|
||||||
docs = [docs]
|
docs = [docs]
|
||||||
cdef Doc doc
|
cdef Doc doc
|
||||||
|
|
|
@ -38,11 +38,16 @@ DEFAULT_SENTER_MODEL = Config().from_str(default_model_config)["model"]
|
||||||
@Language.factory(
|
@Language.factory(
|
||||||
"senter",
|
"senter",
|
||||||
assigns=["token.is_sent_start"],
|
assigns=["token.is_sent_start"],
|
||||||
default_config={"model": DEFAULT_SENTER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.senter_scorer.v1"}},
|
default_config={
|
||||||
|
"model": DEFAULT_SENTER_MODEL,
|
||||||
|
"overwrite": False,
|
||||||
|
"scorer": {"@scorers": "spacy.senter_scorer.v1"},
|
||||||
|
"store_activations": False
|
||||||
|
},
|
||||||
default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
|
default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
|
||||||
)
|
)
|
||||||
def make_senter(nlp: Language, name: str, model: Model, overwrite: bool, scorer: Optional[Callable]):
|
def make_senter(nlp: Language, name: str, model: Model, overwrite: bool, scorer: Optional[Callable], store_activations: bool):
|
||||||
return SentenceRecognizer(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer)
|
return SentenceRecognizer(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, store_activations=store_activations)
|
||||||
|
|
||||||
|
|
||||||
def senter_score(examples, **kwargs):
|
def senter_score(examples, **kwargs):
|
||||||
|
@ -72,6 +77,7 @@ class SentenceRecognizer(Tagger):
|
||||||
*,
|
*,
|
||||||
overwrite=BACKWARD_OVERWRITE,
|
overwrite=BACKWARD_OVERWRITE,
|
||||||
scorer=senter_score,
|
scorer=senter_score,
|
||||||
|
store_activations=False,
|
||||||
):
|
):
|
||||||
"""Initialize a sentence recognizer.
|
"""Initialize a sentence recognizer.
|
||||||
|
|
||||||
|
@ -90,6 +96,7 @@ class SentenceRecognizer(Tagger):
|
||||||
self._rehearsal_model = None
|
self._rehearsal_model = None
|
||||||
self.cfg = {"overwrite": overwrite}
|
self.cfg = {"overwrite": overwrite}
|
||||||
self.scorer = scorer
|
self.scorer = scorer
|
||||||
|
self.store_activations = store_activations
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def labels(self):
|
def labels(self):
|
||||||
|
@ -107,7 +114,7 @@ class SentenceRecognizer(Tagger):
|
||||||
def label_data(self):
|
def label_data(self):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def set_annotations(self, docs, batch_tag_ids):
|
def set_annotations(self, docs, scores_guesses):
|
||||||
"""Modify a batch of documents, using pre-computed scores.
|
"""Modify a batch of documents, using pre-computed scores.
|
||||||
|
|
||||||
docs (Iterable[Doc]): The documents to modify.
|
docs (Iterable[Doc]): The documents to modify.
|
||||||
|
@ -115,11 +122,17 @@ class SentenceRecognizer(Tagger):
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/sentencerecognizer#set_annotations
|
DOCS: https://spacy.io/api/sentencerecognizer#set_annotations
|
||||||
"""
|
"""
|
||||||
|
_, batch_tag_ids = scores_guesses
|
||||||
if isinstance(docs, Doc):
|
if isinstance(docs, Doc):
|
||||||
docs = [docs]
|
docs = [docs]
|
||||||
cdef Doc doc
|
cdef Doc doc
|
||||||
cdef bint overwrite = self.cfg["overwrite"]
|
cdef bint overwrite = self.cfg["overwrite"]
|
||||||
for i, doc in enumerate(docs):
|
for i, doc in enumerate(docs):
|
||||||
|
if self.store_activations:
|
||||||
|
doc.activations[self.name] = {
|
||||||
|
"probs": scores_guesses[0][i],
|
||||||
|
"guesses": scores_guesses[1][i],
|
||||||
|
}
|
||||||
doc_tag_ids = batch_tag_ids[i]
|
doc_tag_ids = batch_tag_ids[i]
|
||||||
if hasattr(doc_tag_ids, "get"):
|
if hasattr(doc_tag_ids, "get"):
|
||||||
doc_tag_ids = doc_tag_ids.get()
|
doc_tag_ids = doc_tag_ids.get()
|
||||||
|
|
|
@ -45,7 +45,13 @@ DEFAULT_TAGGER_MODEL = Config().from_str(default_model_config)["model"]
|
||||||
@Language.factory(
|
@Language.factory(
|
||||||
"tagger",
|
"tagger",
|
||||||
assigns=["token.tag"],
|
assigns=["token.tag"],
|
||||||
default_config={"model": DEFAULT_TAGGER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.tagger_scorer.v1"}, "neg_prefix": "!"},
|
default_config={
|
||||||
|
"model": DEFAULT_TAGGER_MODEL,
|
||||||
|
"overwrite": False,
|
||||||
|
"scorer": {"@scorers": "spacy.tagger_scorer.v1"},
|
||||||
|
"neg_prefix": "!",
|
||||||
|
"store_activations": False
|
||||||
|
},
|
||||||
default_score_weights={"tag_acc": 1.0},
|
default_score_weights={"tag_acc": 1.0},
|
||||||
)
|
)
|
||||||
def make_tagger(
|
def make_tagger(
|
||||||
|
@ -55,6 +61,7 @@ def make_tagger(
|
||||||
overwrite: bool,
|
overwrite: bool,
|
||||||
scorer: Optional[Callable],
|
scorer: Optional[Callable],
|
||||||
neg_prefix: str,
|
neg_prefix: str,
|
||||||
|
store_activations: bool,
|
||||||
):
|
):
|
||||||
"""Construct a part-of-speech tagger component.
|
"""Construct a part-of-speech tagger component.
|
||||||
|
|
||||||
|
@ -63,7 +70,7 @@ def make_tagger(
|
||||||
in size, and be normalized as probabilities (all scores between 0 and 1,
|
in size, and be normalized as probabilities (all scores between 0 and 1,
|
||||||
with the rows summing to 1).
|
with the rows summing to 1).
|
||||||
"""
|
"""
|
||||||
return Tagger(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, neg_prefix=neg_prefix)
|
return Tagger(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, neg_prefix=neg_prefix, store_activations=store_activations)
|
||||||
|
|
||||||
|
|
||||||
def tagger_score(examples, **kwargs):
|
def tagger_score(examples, **kwargs):
|
||||||
|
@ -89,6 +96,7 @@ class Tagger(TrainablePipe):
|
||||||
overwrite=BACKWARD_OVERWRITE,
|
overwrite=BACKWARD_OVERWRITE,
|
||||||
scorer=tagger_score,
|
scorer=tagger_score,
|
||||||
neg_prefix="!",
|
neg_prefix="!",
|
||||||
|
store_activations=False,
|
||||||
):
|
):
|
||||||
"""Initialize a part-of-speech tagger.
|
"""Initialize a part-of-speech tagger.
|
||||||
|
|
||||||
|
@ -108,6 +116,7 @@ class Tagger(TrainablePipe):
|
||||||
cfg = {"labels": [], "overwrite": overwrite, "neg_prefix": neg_prefix}
|
cfg = {"labels": [], "overwrite": overwrite, "neg_prefix": neg_prefix}
|
||||||
self.cfg = dict(sorted(cfg.items()))
|
self.cfg = dict(sorted(cfg.items()))
|
||||||
self.scorer = scorer
|
self.scorer = scorer
|
||||||
|
self.store_activations = store_activations
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def labels(self):
|
def labels(self):
|
||||||
|
@ -139,12 +148,12 @@ class Tagger(TrainablePipe):
|
||||||
n_labels = len(self.labels)
|
n_labels = len(self.labels)
|
||||||
guesses = [self.model.ops.alloc((0, n_labels)) for doc in docs]
|
guesses = [self.model.ops.alloc((0, n_labels)) for doc in docs]
|
||||||
assert len(guesses) == len(docs)
|
assert len(guesses) == len(docs)
|
||||||
return guesses
|
return guesses, guesses
|
||||||
scores = self.model.predict(docs)
|
scores = self.model.predict(docs)
|
||||||
assert len(scores) == len(docs), (len(scores), len(docs))
|
assert len(scores) == len(docs), (len(scores), len(docs))
|
||||||
guesses = self._scores2guesses(scores)
|
guesses = self._scores2guesses(scores)
|
||||||
assert len(guesses) == len(docs)
|
assert len(guesses) == len(docs)
|
||||||
return guesses
|
return scores, guesses
|
||||||
|
|
||||||
def _scores2guesses(self, scores):
|
def _scores2guesses(self, scores):
|
||||||
guesses = []
|
guesses = []
|
||||||
|
@ -155,7 +164,7 @@ class Tagger(TrainablePipe):
|
||||||
guesses.append(doc_guesses)
|
guesses.append(doc_guesses)
|
||||||
return guesses
|
return guesses
|
||||||
|
|
||||||
def set_annotations(self, docs, batch_tag_ids):
|
def set_annotations(self, docs, scores_guesses):
|
||||||
"""Modify a batch of documents, using pre-computed scores.
|
"""Modify a batch of documents, using pre-computed scores.
|
||||||
|
|
||||||
docs (Iterable[Doc]): The documents to modify.
|
docs (Iterable[Doc]): The documents to modify.
|
||||||
|
@ -163,6 +172,7 @@ class Tagger(TrainablePipe):
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/tagger#set_annotations
|
DOCS: https://spacy.io/api/tagger#set_annotations
|
||||||
"""
|
"""
|
||||||
|
_, batch_tag_ids = scores_guesses
|
||||||
if isinstance(docs, Doc):
|
if isinstance(docs, Doc):
|
||||||
docs = [docs]
|
docs = [docs]
|
||||||
cdef Doc doc
|
cdef Doc doc
|
||||||
|
@ -170,6 +180,11 @@ class Tagger(TrainablePipe):
|
||||||
cdef bint overwrite = self.cfg["overwrite"]
|
cdef bint overwrite = self.cfg["overwrite"]
|
||||||
labels = self.labels
|
labels = self.labels
|
||||||
for i, doc in enumerate(docs):
|
for i, doc in enumerate(docs):
|
||||||
|
if self.store_activations:
|
||||||
|
doc.activations[self.name] = {
|
||||||
|
"probs": scores_guesses[0][i],
|
||||||
|
"guesses": scores_guesses[1][i],
|
||||||
|
}
|
||||||
doc_tag_ids = batch_tag_ids[i]
|
doc_tag_ids = batch_tag_ids[i]
|
||||||
if hasattr(doc_tag_ids, "get"):
|
if hasattr(doc_tag_ids, "get"):
|
||||||
doc_tag_ids = doc_tag_ids.get()
|
doc_tag_ids = doc_tag_ids.get()
|
||||||
|
|
|
@ -6,3 +6,4 @@ cdef class TrainablePipe(Pipe):
|
||||||
cdef public object model
|
cdef public object model
|
||||||
cdef public object cfg
|
cdef public object cfg
|
||||||
cdef public object scorer
|
cdef public object scorer
|
||||||
|
cdef public bint store_activations
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from typing import cast
|
||||||
import pytest
|
import pytest
|
||||||
from numpy.testing import assert_equal
|
from numpy.testing import assert_equal
|
||||||
from spacy.attrs import SENT_START
|
from spacy.attrs import SENT_START
|
||||||
|
@ -6,6 +7,7 @@ from spacy import util
|
||||||
from spacy.training import Example
|
from spacy.training import Example
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
from spacy.language import Language
|
from spacy.language import Language
|
||||||
|
from spacy.pipeline import TrainablePipe
|
||||||
from spacy.tests.util import make_tempdir
|
from spacy.tests.util import make_tempdir
|
||||||
|
|
||||||
|
|
||||||
|
@ -101,3 +103,23 @@ def test_overfitting_IO():
|
||||||
# test internal pipe labels vs. Language.pipe_labels with hidden labels
|
# test internal pipe labels vs. Language.pipe_labels with hidden labels
|
||||||
assert nlp.get_pipe("senter").labels == ("I", "S")
|
assert nlp.get_pipe("senter").labels == ("I", "S")
|
||||||
assert "senter" not in nlp.pipe_labels
|
assert "senter" not in nlp.pipe_labels
|
||||||
|
|
||||||
|
|
||||||
|
def test_store_activations():
|
||||||
|
# Check that the senter stores its probabilities and guesses in doc.activations when store_activations is enabled
|
||||||
|
nlp = English()
|
||||||
|
senter = cast(TrainablePipe, nlp.add_pipe("senter"))
|
||||||
|
|
||||||
|
train_examples = []
|
||||||
|
for t in TRAIN_DATA:
|
||||||
|
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
|
||||||
|
|
||||||
|
nlp.initialize(get_examples=lambda: train_examples)
|
||||||
|
|
||||||
|
senter.store_activations = True
|
||||||
|
|
||||||
|
doc = nlp("This is a test.")
|
||||||
|
assert "senter" in doc.activations
|
||||||
|
assert set(doc.activations["senter"].keys()) == {"guesses", "probs"}
|
||||||
|
assert doc.activations["senter"]["probs"].shape == (5, 2)
|
||||||
|
assert doc.activations["senter"]["guesses"].shape == (5,)
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from typing import cast
|
||||||
import pytest
|
import pytest
|
||||||
from numpy.testing import assert_equal
|
from numpy.testing import assert_equal
|
||||||
from spacy.attrs import TAG
|
from spacy.attrs import TAG
|
||||||
|
@ -6,6 +7,7 @@ from spacy import util
|
||||||
from spacy.training import Example
|
from spacy.training import Example
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
from spacy.language import Language
|
from spacy.language import Language
|
||||||
|
from spacy.pipeline import TrainablePipe
|
||||||
from thinc.api import compounding
|
from thinc.api import compounding
|
||||||
|
|
||||||
from ..util import make_tempdir
|
from ..util import make_tempdir
|
||||||
|
@ -211,6 +213,25 @@ def test_overfitting_IO():
|
||||||
assert doc3[0].tag_ != "N"
|
assert doc3[0].tag_ != "N"
|
||||||
|
|
||||||
|
|
||||||
|
def test_store_activations():
|
||||||
|
# Check that the tagger stores its probabilities and guesses in doc.activations when store_activations is enabled
|
||||||
|
nlp = English()
|
||||||
|
tagger = cast(TrainablePipe, nlp.add_pipe("tagger"))
|
||||||
|
train_examples = []
|
||||||
|
for t in TRAIN_DATA:
|
||||||
|
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
|
||||||
|
nlp.initialize(get_examples=lambda: train_examples)
|
||||||
|
|
||||||
|
tagger.store_activations = True
|
||||||
|
|
||||||
|
doc = nlp("This is a test.")
|
||||||
|
|
||||||
|
assert "tagger" in doc.activations
|
||||||
|
assert set(doc.activations["tagger"].keys()) == {"guesses", "probs"}
|
||||||
|
assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS))
|
||||||
|
assert doc.activations["tagger"]["guesses"].shape == (5,)
|
||||||
|
|
||||||
|
|
||||||
def test_tagger_requires_labels():
|
def test_tagger_requires_labels():
|
||||||
nlp = English()
|
nlp = English()
|
||||||
nlp.add_pipe("tagger")
|
nlp.add_pipe("tagger")
|
||||||
|
|
|
@ -50,6 +50,8 @@ cdef class Doc:
|
||||||
|
|
||||||
cdef public float sentiment
|
cdef public float sentiment
|
||||||
|
|
||||||
|
cdef public dict activations
|
||||||
|
|
||||||
cdef public dict user_hooks
|
cdef public dict user_hooks
|
||||||
cdef public dict user_token_hooks
|
cdef public dict user_token_hooks
|
||||||
cdef public dict user_span_hooks
|
cdef public dict user_span_hooks
|
||||||
|
|
|
@ -245,6 +245,7 @@ cdef class Doc:
|
||||||
self.length = 0
|
self.length = 0
|
||||||
self.sentiment = 0.0
|
self.sentiment = 0.0
|
||||||
self.cats = {}
|
self.cats = {}
|
||||||
|
self.activations = {}
|
||||||
self.user_hooks = {}
|
self.user_hooks = {}
|
||||||
self.user_token_hooks = {}
|
self.user_token_hooks = {}
|
||||||
self.user_span_hooks = {}
|
self.user_span_hooks = {}
|
||||||
|
|
Loading…
Reference in New Issue
Block a user