Mirror of https://github.com/explosion/spaCy.git

Store activations in Doc when store_activations is enabled
This change adds the new `activations` attribute to `Doc`. This attribute can be used by trainable pipes to store their activations, probabilities, and guesses for downstream users. As an example, this change modifies the `tagger` and `senter` pipes to add a `store_activations` option. When this option is enabled, the probabilities and guesses are stored on the `Doc` in `set_annotations`.
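As a quick orientation before the diff, here is a minimal usage sketch (not part of the commit): the option can be switched on through the factory config, and the stored arrays are then available on the processed `Doc`. The training sentence, the `sent_starts` annotation, and the variable names are made up for illustration.

```python
# Minimal sketch (illustrative only): enable the new "store_activations"
# factory option and read the arrays that set_annotations() stores on the Doc.
import spacy
from spacy.training import Example

nlp = spacy.blank("en")
senter = nlp.add_pipe("senter", config={"store_activations": True})

# Tiny made-up training set: ten tokens, two sentences.
TRAIN_DATA = [
    ("This is a sentence. This is another one.",
     {"sent_starts": [1, 0, 0, 0, 0, 1, 0, 0, 0, 0]}),
]
examples = [Example.from_dict(nlp.make_doc(text), annots) for text, annots in TRAIN_DATA]
nlp.initialize(get_examples=lambda: examples)

doc = nlp("This is a test.")
probs = doc.activations["senter"]["probs"]      # per-token class probabilities
guesses = doc.activations["senter"]["guesses"]  # per-token predicted class ids
```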
This commit is contained in:
parent 0271306f16
commit b71c6043bc
@@ -217,7 +217,7 @@ class Morphologizer(Tagger):
         assert len(label_sample) > 0, Errors.E923.format(name=self.name)
         self.model.initialize(X=doc_sample, Y=label_sample)
 
-    def set_annotations(self, docs, batch_tag_ids):
+    def set_annotations(self, docs, scores_guesses):
         """Modify a batch of documents, using pre-computed scores.
 
         docs (Iterable[Doc]): The documents to modify.
@@ -225,6 +225,7 @@ class Morphologizer(Tagger):
 
         DOCS: https://spacy.io/api/morphologizer#set_annotations
         """
+        _, batch_tag_ids = scores_guesses
         if isinstance(docs, Doc):
             docs = [docs]
         cdef Doc doc
@@ -38,11 +38,16 @@ DEFAULT_SENTER_MODEL = Config().from_str(default_model_config)["model"]
 @Language.factory(
     "senter",
     assigns=["token.is_sent_start"],
-    default_config={"model": DEFAULT_SENTER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.senter_scorer.v1"}},
+    default_config={
+        "model": DEFAULT_SENTER_MODEL,
+        "overwrite": False,
+        "scorer": {"@scorers": "spacy.senter_scorer.v1"},
+        "store_activations": False
+    },
     default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
 )
-def make_senter(nlp: Language, name: str, model: Model, overwrite: bool, scorer: Optional[Callable]):
-    return SentenceRecognizer(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer)
+def make_senter(nlp: Language, name: str, model: Model, overwrite: bool, scorer: Optional[Callable], store_activations: bool):
+    return SentenceRecognizer(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, store_activations=store_activations)
 
 
 def senter_score(examples, **kwargs):
@@ -72,6 +77,7 @@ class SentenceRecognizer(Tagger):
         *,
         overwrite=BACKWARD_OVERWRITE,
         scorer=senter_score,
+        store_activations=False,
     ):
         """Initialize a sentence recognizer.
 
@@ -90,6 +96,7 @@ class SentenceRecognizer(Tagger):
         self._rehearsal_model = None
         self.cfg = {"overwrite": overwrite}
         self.scorer = scorer
+        self.store_activations = store_activations
 
     @property
     def labels(self):
@@ -107,7 +114,7 @@ class SentenceRecognizer(Tagger):
     def label_data(self):
         return None
 
-    def set_annotations(self, docs, batch_tag_ids):
+    def set_annotations(self, docs, scores_guesses):
         """Modify a batch of documents, using pre-computed scores.
 
         docs (Iterable[Doc]): The documents to modify.
@@ -115,11 +122,17 @@ class SentenceRecognizer(Tagger):
 
         DOCS: https://spacy.io/api/sentencerecognizer#set_annotations
         """
+        _, batch_tag_ids = scores_guesses
         if isinstance(docs, Doc):
             docs = [docs]
         cdef Doc doc
         cdef bint overwrite = self.cfg["overwrite"]
         for i, doc in enumerate(docs):
+            if self.store_activations:
+                doc.activations[self.name] = {
+                    "probs": scores_guesses[0][i],
+                    "guesses": scores_guesses[1][i],
+                }
             doc_tag_ids = batch_tag_ids[i]
             if hasattr(doc_tag_ids, "get"):
                 doc_tag_ids = doc_tag_ids.get()
@@ -45,7 +45,13 @@ DEFAULT_TAGGER_MODEL = Config().from_str(default_model_config)["model"]
 @Language.factory(
     "tagger",
     assigns=["token.tag"],
-    default_config={"model": DEFAULT_TAGGER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.tagger_scorer.v1"}, "neg_prefix": "!"},
+    default_config={
+        "model": DEFAULT_TAGGER_MODEL,
+        "overwrite": False,
+        "scorer": {"@scorers": "spacy.tagger_scorer.v1"},
+        "neg_prefix": "!",
+        "store_activations": False
+    },
     default_score_weights={"tag_acc": 1.0},
 )
 def make_tagger(
@@ -55,6 +61,7 @@ def make_tagger(
     overwrite: bool,
     scorer: Optional[Callable],
     neg_prefix: str,
+    store_activations: bool,
 ):
     """Construct a part-of-speech tagger component.
 
@@ -63,7 +70,7 @@ def make_tagger(
         in size, and be normalized as probabilities (all scores between 0 and 1,
         with the rows summing to 1).
     """
-    return Tagger(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, neg_prefix=neg_prefix)
+    return Tagger(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, neg_prefix=neg_prefix, store_activations=store_activations)
 
 
 def tagger_score(examples, **kwargs):
@@ -89,6 +96,7 @@ class Tagger(TrainablePipe):
         overwrite=BACKWARD_OVERWRITE,
         scorer=tagger_score,
         neg_prefix="!",
+        store_activations=False,
     ):
         """Initialize a part-of-speech tagger.
 
@@ -108,6 +116,7 @@ class Tagger(TrainablePipe):
         cfg = {"labels": [], "overwrite": overwrite, "neg_prefix": neg_prefix}
         self.cfg = dict(sorted(cfg.items()))
         self.scorer = scorer
+        self.store_activations = store_activations
 
     @property
     def labels(self):
@@ -139,12 +148,12 @@ class Tagger(TrainablePipe):
             n_labels = len(self.labels)
             guesses = [self.model.ops.alloc((0, n_labels)) for doc in docs]
             assert len(guesses) == len(docs)
-            return guesses
+            return guesses, guesses
         scores = self.model.predict(docs)
         assert len(scores) == len(docs), (len(scores), len(docs))
         guesses = self._scores2guesses(scores)
         assert len(guesses) == len(docs)
-        return guesses
+        return scores, guesses
 
     def _scores2guesses(self, scores):
         guesses = []
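For context on the hunk above (not part of the diff): `predict` now returns a `(scores, guesses)` pair instead of the guesses alone, and `set_annotations` unpacks that pair. A hedged sketch of the updated calling convention, assuming a `tagger` pipe and a list of `docs` already exist:

```python
# Sketch of the updated contract between predict() and set_annotations():
# predict() returns both the raw probabilities and the argmaxed tag ids.
scores, guesses = tagger.predict(docs)
tagger.set_annotations(docs, (scores, guesses))
```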
@@ -155,7 +164,7 @@ class Tagger(TrainablePipe):
             guesses.append(doc_guesses)
         return guesses
 
-    def set_annotations(self, docs, batch_tag_ids):
+    def set_annotations(self, docs, scores_guesses):
         """Modify a batch of documents, using pre-computed scores.
 
         docs (Iterable[Doc]): The documents to modify.
@@ -163,6 +172,7 @@ class Tagger(TrainablePipe):
 
         DOCS: https://spacy.io/api/tagger#set_annotations
         """
+        _, batch_tag_ids = scores_guesses
         if isinstance(docs, Doc):
             docs = [docs]
         cdef Doc doc
@@ -170,6 +180,11 @@ class Tagger(TrainablePipe):
         cdef bint overwrite = self.cfg["overwrite"]
         labels = self.labels
         for i, doc in enumerate(docs):
+            if self.store_activations:
+                doc.activations[self.name] = {
+                    "probs": scores_guesses[0][i],
+                    "guesses": scores_guesses[1][i],
+                }
             doc_tag_ids = batch_tag_ids[i]
             if hasattr(doc_tag_ids, "get"):
                 doc_tag_ids = doc_tag_ids.get()
@@ -6,3 +6,4 @@ cdef class TrainablePipe(Pipe):
     cdef public object model
     cdef public object cfg
     cdef public object scorer
+    cdef public bint store_activations
@@ -1,3 +1,4 @@
+from typing import cast
 import pytest
 from numpy.testing import assert_equal
 from spacy.attrs import SENT_START
@@ -6,6 +7,7 @@ from spacy import util
 from spacy.training import Example
 from spacy.lang.en import English
 from spacy.language import Language
+from spacy.pipeline import TrainablePipe
 from spacy.tests.util import make_tempdir
 
 
@@ -101,3 +103,23 @@ def test_overfitting_IO():
     # test internal pipe labels vs. Language.pipe_labels with hidden labels
     assert nlp.get_pipe("senter").labels == ("I", "S")
     assert "senter" not in nlp.pipe_labels
+
+
+def test_store_activations():
+    # Simple test to try and quickly overfit the senter - ensuring the ML models work correctly
+    nlp = English()
+    senter = cast(TrainablePipe, nlp.add_pipe("senter"))
+
+    train_examples = []
+    for t in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
+
+    nlp.initialize(get_examples=lambda: train_examples)
+
+    senter.store_activations = True
+
+    doc = nlp("This is a test.")
+    assert "senter" in doc.activations
+    assert set(doc.activations["senter"].keys()) == {"guesses", "probs"}
+    assert doc.activations["senter"]["probs"].shape == (5, 2)
+    assert doc.activations["senter"]["guesses"].shape == (5,)
@@ -1,3 +1,4 @@
+from typing import cast
 import pytest
 from numpy.testing import assert_equal
 from spacy.attrs import TAG
@@ -6,6 +7,7 @@ from spacy import util
 from spacy.training import Example
 from spacy.lang.en import English
 from spacy.language import Language
+from spacy.pipeline import TrainablePipe
 from thinc.api import compounding
 
 from ..util import make_tempdir
@@ -211,6 +213,25 @@ def test_overfitting_IO():
     assert doc3[0].tag_ != "N"
 
 
+def test_store_activations():
+    # Simple test to try and quickly overfit the tagger - ensuring the ML models work correctly
+    nlp = English()
+    tagger = cast(TrainablePipe, nlp.add_pipe("tagger"))
+    train_examples = []
+    for t in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
+    nlp.initialize(get_examples=lambda: train_examples)
+
+    tagger.store_activations = True
+
+    doc = nlp("This is a test.")
+
+    assert "tagger" in doc.activations
+    assert set(doc.activations["tagger"].keys()) == {"guesses", "probs"}
+    assert doc.activations["tagger"]["probs"].shape == (5, len(TAGS))
+    assert doc.activations["tagger"]["guesses"].shape == (5,)
+
+
 def test_tagger_requires_labels():
     nlp = English()
     nlp.add_pipe("tagger")
@@ -50,6 +50,8 @@ cdef class Doc:
 
     cdef public float sentiment
 
+    cdef public dict activations
+
     cdef public dict user_hooks
     cdef public dict user_token_hooks
     cdef public dict user_span_hooks
@@ -245,6 +245,7 @@ cdef class Doc:
         self.length = 0
         self.sentiment = 0.0
         self.cats = {}
+        self.activations = {}
         self.user_hooks = {}
         self.user_token_hooks = {}
         self.user_span_hooks = {}
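Taken together, `Doc.activations` is an ordinary dict keyed by component name, and each opted-in pipe stores its own sub-dict of arrays (here `probs` and `guesses`). A hypothetical way to inspect what a processed `Doc` carries:

```python
# Hypothetical: list everything the pipes stored on a processed Doc.
for pipe_name, acts in doc.activations.items():
    shapes = {key: getattr(value, "shape", None) for key, value in acts.items()}
    print(pipe_name, shapes)
```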