Store activations in Docs when save_activations is enabled (#11002)

* Store activations in Doc when `store_activations` is enabled

This change adds the new `activations` attribute to `Doc`. This
attribute can be used by trainable pipes to store their activations,
probabilities, and guesses for downstream users.

As an example, this change modifies the `tagger` and `senter` pipes to
add a `store_activations` option. When this option is enabled, the
probabilities and guesses are stored on the `Doc` in `set_annotations`.
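
As a usage sketch (assuming a trained pipeline with a tagger; note that a
later commit in this PR renames the option to `save_activations`, and the
activation names below follow the final rename):

    import spacy

    nlp = spacy.load("en_core_web_sm")  # any pipeline with a tagger
    nlp.get_pipe("tagger").save_activations = True
    doc = nlp("This is a test.")
    # Doc.activations is keyed by component name, then by activation name.
    probs = doc.activations["tagger"]["probabilities"]  # per-token scores
    label_ids = doc.activations["tagger"]["label_ids"]  # predicted tag IDs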

* Change type of `store_activations` to `Union[bool, List[str]]`

When the value is:

- A `bool`: all activations are stored when set to `True`.
- A `List[str]`: only the activations named in the list are stored.
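
For example, at this intermediate state of the PR (a later commit reverts
this to a plain bool), a pipe could be configured to keep only a subset:

    # Sketch of the intermediate API; "probs" and "guesses" were the
    # activation names before the rename later in this PR.
    tagger.store_activations = True        # store all activations
    tagger.store_activations = ["probs"]   # store only the probabilities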

* Formatting fixes in Tagger

* Support store_activations in spancat and morphologizer

* Make Doc.activations type visible to MyPy

* textcat/textcat_multilabel: add store_activations option

* trainable_lemmatizer/entity_linker: add store_activations option

* parser/ner: do not currently support returning activations

* Extend tagger and senter tests

So that they, like the other tests, also check that we get no
activations if no activations were requested.

* Document `Doc.activations` and `store_activations` in the relevant pipes

* Start errors/warnings at higher numbers to avoid merge conflicts

Between the master and v4 branches.

* Add `store_activations` to docstrings.

* Replace the `store_activations` setter with a `set_store_activations` method

Setters that take a different type than what the getter returns are still
problematic for MyPy. Replace the setter with a method, so that type
inference works everywhere.
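
A minimal sketch of the problem (hypothetical class, not from this diff):

    class Pipe:
        @property
        def store_activations(self) -> List[str]: ...

        @store_activations.setter
        def store_activations(self, value: Union[bool, List[str]]) -> None: ...

    pipe.store_activations = True  # rejected: MyPy infers the attribute
                                   # type from the getter (List[str]), so
                                   # assigning a bool fails; a method like
                                   # set_store_activations() sidesteps this.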

* Use dict comprehension suggested by @svlandeg

* Revert "Use dict comprehension suggested by @svlandeg"

This reverts commit 6e7b958f70.

* EntityLinker: add type annotations to _add_activations

* _store_activations: make kwarg-only, remove doc_scores_lens arg

* set_annotations: add type annotations

* Apply suggestions from code review

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* TextCat.predict: return dict

* Make the `TrainablePipe.store_activations` property a bool

This means that we can also bring back the `store_activations` setter.

* Remove `TrainablePipe.activations`

We do not need to enumerate the activations anymore, since
`store_activations` is a `bool`.

* Add type annotations for activations in predict/set_annotations

* Rename `TrainablePipe.store_activations` to `save_activations`

* Error E1400 is not used anymore

This error was used when activations were still `Union[bool, List[str]]`.

* Change wording in API docs after store -> save change

* docs: tag (save_)activations as new in spaCy 4.0

* Fix copied line in morphologizer activations test

* Don't train in any test_save_activations test

* Rename activations

- "probs" -> "probabilities"
- "guesses" -> "label_ids", except in the edit tree lemmatizer, where
  "guesses" -> "tree_ids".

* Remove unused W400 warning.

This warning was used when we still allowed the user to specify
which activations to save.

* Formatting fixes

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Replace "kb_ids" by a constant

* spancat: replace a cast by an assertion

* Fix EOF spacing

* Fix comments in test_save_activations tests

* Do not set RNG seed in activation saving tests

* Revert "spancat: replace a cast by an assertion"

This reverts commit 0bd5730d16.

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
Daniël de Kok, 2022-09-13 09:51:12 +02:00 (committed by GitHub)
parent 60c050e82b
commit efdbb722c5
28 changed files with 580 additions and 130 deletions

spacy/pipeline/edit_tree_lemmatizer.py

@@ -7,7 +7,7 @@ import numpy as np
 import srsly
 from thinc.api import Config, Model, SequenceCategoricalCrossentropy
-from thinc.types import Floats2d, Ints1d, Ints2d
+from thinc.types import ArrayXd, Floats2d, Ints1d

 from ._edit_tree_internals.edit_trees import EditTrees
 from ._edit_tree_internals.schemas import validate_edit_tree
@@ -21,6 +21,9 @@ from ..vocab import Vocab
 from .. import util

+ActivationsT = Dict[str, Union[List[Floats2d], List[Ints1d]]]
+
 default_model_config = """
 [model]
 @architectures = "spacy.Tagger.v2"
@@ -49,6 +52,7 @@ DEFAULT_EDIT_TREE_LEMMATIZER_MODEL = Config().from_str(default_model_config)["model"]
         "overwrite": False,
         "top_k": 1,
         "scorer": {"@scorers": "spacy.lemmatizer_scorer.v1"},
+        "save_activations": False,
     },
     default_score_weights={"lemma_acc": 1.0},
 )
@@ -61,6 +65,7 @@ def make_edit_tree_lemmatizer(
     overwrite: bool,
     top_k: int,
     scorer: Optional[Callable],
+    save_activations: bool,
 ):
     """Construct an EditTreeLemmatizer component."""
     return EditTreeLemmatizer(
@@ -72,6 +77,7 @@ def make_edit_tree_lemmatizer(
         overwrite=overwrite,
         top_k=top_k,
         scorer=scorer,
+        save_activations=save_activations,
     )
@@ -91,6 +97,7 @@ class EditTreeLemmatizer(TrainablePipe):
         overwrite: bool = False,
         top_k: int = 1,
         scorer: Optional[Callable] = lemmatizer_score,
+        save_activations: bool = False,
     ):
         """
         Construct an edit tree lemmatizer.
@@ -102,6 +109,7 @@ class EditTreeLemmatizer(TrainablePipe):
             frequency in the training data.
         overwrite (bool): overwrite existing lemma annotations.
         top_k (int): try to apply at most the k most probable edit trees.
+        save_activations (bool): save model activations in Doc when annotating.
         """
         self.vocab = vocab
         self.model = model
@@ -116,6 +124,7 @@ class EditTreeLemmatizer(TrainablePipe):
         self.cfg: Dict[str, Any] = {"labels": []}
         self.scorer = scorer
+        self.save_activations = save_activations

     def get_loss(
         self, examples: Iterable[Example], scores: List[Floats2d]
@@ -144,21 +153,24 @@ class EditTreeLemmatizer(TrainablePipe):
         return float(loss), d_scores

-    def predict(self, docs: Iterable[Doc]) -> List[Ints2d]:
+    def predict(self, docs: Iterable[Doc]) -> ActivationsT:
         n_docs = len(list(docs))
         if not any(len(doc) for doc in docs):
             # Handle cases where there are no tokens in any docs.
             n_labels = len(self.cfg["labels"])
-            guesses: List[Ints2d] = [
+            guesses: List[Ints1d] = [
+                self.model.ops.alloc((0,), dtype="i") for doc in docs
+            ]
+            scores: List[Floats2d] = [
                 self.model.ops.alloc((0, n_labels), dtype="i") for doc in docs
             ]
             assert len(guesses) == n_docs
-            return guesses
+            return {"probabilities": scores, "tree_ids": guesses}
         scores = self.model.predict(docs)
         assert len(scores) == n_docs
         guesses = self._scores2guesses(docs, scores)
         assert len(guesses) == n_docs
-        return guesses
+        return {"probabilities": scores, "tree_ids": guesses}

     def _scores2guesses(self, docs, scores):
         guesses = []
@@ -186,8 +198,13 @@ class EditTreeLemmatizer(TrainablePipe):
         return guesses

-    def set_annotations(self, docs: Iterable[Doc], batch_tree_ids):
+    def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT):
+        batch_tree_ids = activations["tree_ids"]
         for i, doc in enumerate(docs):
+            if self.save_activations:
+                doc.activations[self.name] = {}
+                for act_name, acts in activations.items():
+                    doc.activations[self.name][act_name] = acts[i]
             doc_tree_ids = batch_tree_ids[i]
             if hasattr(doc_tree_ids, "get"):
                 doc_tree_ids = doc_tree_ids.get()

spacy/pipeline/entity_linker.py

@@ -1,5 +1,7 @@
-from typing import Optional, Iterable, Callable, Dict, Union, List, Any
-from thinc.types import Floats2d
+from typing import Optional, Iterable, Callable, Dict, Sequence, Union, List, Any
+from typing import cast
+from numpy import dtype
+from thinc.types import Floats1d, Floats2d, Ints1d, Ragged
 from pathlib import Path
 from itertools import islice
 import srsly
@@ -21,6 +23,11 @@ from ..util import SimpleFrozenList, registry
 from .. import util
 from ..scorer import Scorer

+ActivationsT = Dict[str, Union[List[Ragged], List[str]]]
+
+KNOWLEDGE_BASE_IDS = "kb_ids"
+
 # See #9050
 BACKWARD_OVERWRITE = True
@@ -57,6 +64,7 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
         "scorer": {"@scorers": "spacy.entity_linker_scorer.v1"},
         "use_gold_ents": True,
         "threshold": None,
+        "save_activations": False,
     },
     default_score_weights={
         "nel_micro_f": 1.0,
@@ -79,6 +87,7 @@ def make_entity_linker(
     scorer: Optional[Callable],
     use_gold_ents: bool,
     threshold: Optional[float] = None,
+    save_activations: bool,
 ):
     """Construct an EntityLinker component.
@@ -97,6 +106,7 @@ def make_entity_linker(
        component must provide entity annotations.
    threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the threshold,
        prediction is discarded. If None, predictions are not filtered by any threshold.
+    save_activations (bool): save model activations in Doc when annotating.
    """
    if not model.attrs.get("include_span_maker", False):
@@ -128,6 +138,7 @@ def make_entity_linker(
        scorer=scorer,
        use_gold_ents=use_gold_ents,
        threshold=threshold,
+       save_activations=save_activations,
    )
@@ -164,6 +175,7 @@ class EntityLinker(TrainablePipe):
         scorer: Optional[Callable] = entity_linker_score,
         use_gold_ents: bool,
         threshold: Optional[float] = None,
+        save_activations: bool = False,
     ) -> None:
         """Initialize an entity linker.
@@ -212,6 +224,7 @@ class EntityLinker(TrainablePipe):
         self.scorer = scorer
         self.use_gold_ents = use_gold_ents
         self.threshold = threshold
+        self.save_activations = save_activations

     def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
         """Define the KB of this pipe by providing a function that will
@@ -397,7 +410,7 @@
             loss = loss / len(entity_encodings)
         return float(loss), out

-    def predict(self, docs: Iterable[Doc]) -> List[str]:
+    def predict(self, docs: Iterable[Doc]) -> ActivationsT:
         """Apply the pipeline's model to a batch of docs, without modifying them.
         Returns the KB IDs for each entity in each doc, including NIL if there is
         no prediction.
@@ -410,13 +423,20 @@
         self.validate_kb()
         entity_count = 0
         final_kb_ids: List[str] = []
-        xp = self.model.ops.xp
+        ops = self.model.ops
+        xp = ops.xp
+        docs_ents: List[Ragged] = []
+        docs_scores: List[Ragged] = []
         if not docs:
-            return final_kb_ids
+            return {KNOWLEDGE_BASE_IDS: final_kb_ids, "ents": docs_ents, "scores": docs_scores}
         if isinstance(docs, Doc):
             docs = [docs]
-        for i, doc in enumerate(docs):
+        for doc in docs:
+            doc_ents: List[Ints1d] = []
+            doc_scores: List[Floats1d] = []
             if len(doc) == 0:
+                docs_scores.append(Ragged(ops.alloc1f(0), ops.alloc1i(0)))
+                docs_ents.append(Ragged(xp.zeros(0, dtype="uint64"), ops.alloc1i(0)))
                 continue
             sentences = [s for s in doc.sents]
             # Looping through each entity (TODO: rewrite)
@@ -439,14 +459,32 @@
                 if ent.label_ in self.labels_discard:
                     # ignoring this entity - setting to NIL
                     final_kb_ids.append(self.NIL)
+                    self._add_activations(
+                        doc_scores=doc_scores,
+                        doc_ents=doc_ents,
+                        scores=[0.0],
+                        ents=[0],
+                    )
                 else:
                     candidates = list(self.get_candidates(self.kb, ent))
                     if not candidates:
                         # no prediction possible for this entity - setting to NIL
                         final_kb_ids.append(self.NIL)
+                        self._add_activations(
+                            doc_scores=doc_scores,
+                            doc_ents=doc_ents,
+                            scores=[0.0],
+                            ents=[0],
+                        )
                     elif len(candidates) == 1 and self.threshold is None:
                         # shortcut for efficiency reasons: take the 1 candidate
                         final_kb_ids.append(candidates[0].entity_)
+                        self._add_activations(
+                            doc_scores=doc_scores,
+                            doc_ents=doc_ents,
+                            scores=[1.0],
+                            ents=[candidates[0].entity_],
+                        )
                     else:
                         random.shuffle(candidates)
                         # set all prior probabilities to 0 if incl_prior=False
@@ -479,27 +517,48 @@
                             if self.threshold is None or scores.max() >= self.threshold
                             else EntityLinker.NIL
                         )
+                        self._add_activations(
+                            doc_scores=doc_scores,
+                            doc_ents=doc_ents,
+                            scores=scores,
+                            ents=[c.entity for c in candidates],
+                        )
+            self._add_doc_activations(
+                docs_scores=docs_scores,
+                docs_ents=docs_ents,
+                doc_scores=doc_scores,
+                doc_ents=doc_ents,
+            )
         if not (len(final_kb_ids) == entity_count):
             err = Errors.E147.format(
                 method="predict", msg="result variables not of equal length"
             )
             raise RuntimeError(err)
-        return final_kb_ids
+        return {KNOWLEDGE_BASE_IDS: final_kb_ids, "ents": docs_ents, "scores": docs_scores}

-    def set_annotations(self, docs: Iterable[Doc], kb_ids: List[str]) -> None:
+    def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT) -> None:
         """Modify a batch of documents, using pre-computed scores.
         docs (Iterable[Doc]): The documents to modify.
-        kb_ids (List[str]): The IDs to set, produced by EntityLinker.predict.
+        activations (ActivationsT): The activations used for setting annotations, produced
+            by EntityLinker.predict.
         DOCS: https://spacy.io/api/entitylinker#set_annotations
         """
+        kb_ids = cast(List[str], activations[KNOWLEDGE_BASE_IDS])
         count_ents = len([ent for doc in docs for ent in doc.ents])
         if count_ents != len(kb_ids):
             raise ValueError(Errors.E148.format(ents=count_ents, ids=len(kb_ids)))
         i = 0
         overwrite = self.cfg["overwrite"]
-        for doc in docs:
+        for j, doc in enumerate(docs):
+            if self.save_activations:
+                doc.activations[self.name] = {}
+                for act_name, acts in activations.items():
+                    if act_name != KNOWLEDGE_BASE_IDS:
+                        # We only copy activations that are Ragged.
+                        doc.activations[self.name][act_name] = cast(Ragged, acts[j])
             for ent in doc.ents:
                 kb_id = kb_ids[i]
                 i += 1
@@ -598,3 +657,32 @@

     def add_label(self, label):
         raise NotImplementedError
+
+    def _add_doc_activations(
+        self,
+        *,
+        docs_scores: List[Ragged],
+        docs_ents: List[Ragged],
+        doc_scores: List[Floats1d],
+        doc_ents: List[Ints1d],
+    ):
+        if not self.save_activations:
+            return
+        ops = self.model.ops
+        lengths = ops.asarray1i([s.shape[0] for s in doc_scores])
+        docs_scores.append(Ragged(ops.flatten(doc_scores), lengths))
+        docs_ents.append(Ragged(ops.flatten(doc_ents), lengths))
+
+    def _add_activations(
+        self,
+        *,
+        doc_scores: List[Floats1d],
+        doc_ents: List[Ints1d],
+        scores: Sequence[float],
+        ents: Sequence[int],
+    ):
+        if not self.save_activations:
+            return
+        ops = self.model.ops
+        doc_scores.append(ops.asarray1f(scores))
+        doc_ents.append(ops.asarray1i(ents, dtype="uint64"))

spacy/pipeline/morphologizer.pyx

@@ -1,7 +1,8 @@
 # cython: infer_types=True, profile=True, binding=True
-from typing import Optional, Union, Dict, Callable
+from typing import Callable, Dict, Iterable, List, Optional, Union
 import srsly
 from thinc.api import SequenceCategoricalCrossentropy, Model, Config
+from thinc.types import Floats2d, Ints1d
 from itertools import islice

 from ..tokens.doc cimport Doc
@@ -13,7 +14,7 @@ from ..symbols import POS
 from ..language import Language
 from ..errors import Errors
 from .pipe import deserialize_config
-from .tagger import Tagger
+from .tagger import ActivationsT, Tagger
 from .. import util
 from ..scorer import Scorer
 from ..training import validate_examples, validate_get_examples
@@ -52,7 +53,13 @@ DEFAULT_MORPH_MODEL = Config().from_str(default_model_config)["model"]
 @Language.factory(
     "morphologizer",
     assigns=["token.morph", "token.pos"],
-    default_config={"model": DEFAULT_MORPH_MODEL, "overwrite": True, "extend": False, "scorer": {"@scorers": "spacy.morphologizer_scorer.v1"}},
+    default_config={
+        "model": DEFAULT_MORPH_MODEL,
+        "overwrite": True,
+        "extend": False,
+        "scorer": {"@scorers": "spacy.morphologizer_scorer.v1"},
+        "save_activations": False,
+    },
     default_score_weights={"pos_acc": 0.5, "morph_acc": 0.5, "morph_per_feat": None},
 )
 def make_morphologizer(
@@ -62,8 +69,10 @@ def make_morphologizer(
     overwrite: bool,
     extend: bool,
     scorer: Optional[Callable],
+    save_activations: bool,
 ):
-    return Morphologizer(nlp.vocab, model, name, overwrite=overwrite, extend=extend, scorer=scorer)
+    return Morphologizer(nlp.vocab, model, name, overwrite=overwrite, extend=extend, scorer=scorer,
+                         save_activations=save_activations)

 def morphologizer_score(examples, **kwargs):
@@ -95,6 +104,7 @@
         overwrite: bool = BACKWARD_OVERWRITE,
         extend: bool = BACKWARD_EXTEND,
         scorer: Optional[Callable] = morphologizer_score,
+        save_activations: bool = False,
     ):
         """Initialize a morphologizer.
@@ -105,6 +115,7 @@
         scorer (Optional[Callable]): The scoring method. Defaults to
             Scorer.score_token_attr for the attributes "pos" and "morph" and
             Scorer.score_token_attr_per_feat for the attribute "morph".
+        save_activations (bool): save model activations in Doc when annotating.

         DOCS: https://spacy.io/api/morphologizer#init
         """
@@ -124,6 +135,7 @@
         }
         self.cfg = dict(sorted(cfg.items()))
         self.scorer = scorer
+        self.save_activations = save_activations

     @property
     def labels(self):
@@ -217,14 +229,15 @@
             assert len(label_sample) > 0, Errors.E923.format(name=self.name)
         self.model.initialize(X=doc_sample, Y=label_sample)

-    def set_annotations(self, docs, batch_tag_ids):
+    def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT):
         """Modify a batch of documents, using pre-computed scores.

         docs (Iterable[Doc]): The documents to modify.
-        batch_tag_ids: The IDs to set, produced by Morphologizer.predict.
+        activations (ActivationsT): The activations used for setting annotations, produced by Morphologizer.predict.

         DOCS: https://spacy.io/api/morphologizer#set_annotations
         """
+        batch_tag_ids = activations["label_ids"]
         if isinstance(docs, Doc):
             docs = [docs]
         cdef Doc doc
@@ -236,6 +249,10 @@
         # to allocate a compatible container out of the iterable.
         labels = tuple(self.labels)
         for i, doc in enumerate(docs):
+            if self.save_activations:
+                doc.activations[self.name] = {}
+                for act_name, acts in activations.items():
+                    doc.activations[self.name][act_name] = acts[i]
             doc_tag_ids = batch_tag_ids[i]
             if hasattr(doc_tag_ids, "get"):
                 doc_tag_ids = doc_tag_ids.get()

spacy/pipeline/senter.pyx

@@ -1,13 +1,14 @@
 # cython: infer_types=True, profile=True, binding=True
-from typing import Optional, Callable
+from typing import Dict, Iterable, Optional, Callable, List, Union
 from itertools import islice
 import srsly
 from thinc.api import Model, SequenceCategoricalCrossentropy, Config
+from thinc.types import Floats2d, Ints1d

 from ..tokens.doc cimport Doc

-from .tagger import Tagger
+from .tagger import ActivationsT, Tagger
 from ..language import Language
 from ..errors import Errors
 from ..scorer import Scorer
@@ -38,11 +39,21 @@ DEFAULT_SENTER_MODEL = Config().from_str(default_model_config)["model"]
 @Language.factory(
     "senter",
     assigns=["token.is_sent_start"],
-    default_config={"model": DEFAULT_SENTER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.senter_scorer.v1"}},
+    default_config={
+        "model": DEFAULT_SENTER_MODEL,
+        "overwrite": False,
+        "scorer": {"@scorers": "spacy.senter_scorer.v1"},
+        "save_activations": False,
+    },
     default_score_weights={"sents_f": 1.0, "sents_p": 0.0, "sents_r": 0.0},
 )
-def make_senter(nlp: Language, name: str, model: Model, overwrite: bool, scorer: Optional[Callable]):
-    return SentenceRecognizer(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer)
+def make_senter(nlp: Language,
+                name: str,
+                model: Model,
+                overwrite: bool,
+                scorer: Optional[Callable],
+                save_activations: bool):
+    return SentenceRecognizer(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, save_activations=save_activations)

 def senter_score(examples, **kwargs):
@@ -72,6 +83,7 @@ class SentenceRecognizer(Tagger):
         *,
         overwrite=BACKWARD_OVERWRITE,
         scorer=senter_score,
+        save_activations: bool = False,
     ):
         """Initialize a sentence recognizer.
@@ -81,6 +93,7 @@ class SentenceRecognizer(Tagger):
             losses during training.
         scorer (Optional[Callable]): The scoring method. Defaults to
             Scorer.score_spans for the attribute "sents".
+        save_activations (bool): save model activations in Doc when annotating.

         DOCS: https://spacy.io/api/sentencerecognizer#init
         """
@@ -90,6 +103,7 @@ class SentenceRecognizer(Tagger):
         self._rehearsal_model = None
         self.cfg = {"overwrite": overwrite}
         self.scorer = scorer
+        self.save_activations = save_activations

     @property
     def labels(self):
@@ -107,19 +121,24 @@ class SentenceRecognizer(Tagger):
     def label_data(self):
         return None

-    def set_annotations(self, docs, batch_tag_ids):
+    def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT):
         """Modify a batch of documents, using pre-computed scores.

         docs (Iterable[Doc]): The documents to modify.
-        batch_tag_ids: The IDs to set, produced by SentenceRecognizer.predict.
+        activations (ActivationsT): The activations used for setting annotations, produced by SentenceRecognizer.predict.

         DOCS: https://spacy.io/api/sentencerecognizer#set_annotations
         """
+        batch_tag_ids = activations["label_ids"]
         if isinstance(docs, Doc):
             docs = [docs]
         cdef Doc doc
         cdef bint overwrite = self.cfg["overwrite"]
         for i, doc in enumerate(docs):
+            if self.save_activations:
+                doc.activations[self.name] = {}
+                for act_name, acts in activations.items():
+                    doc.activations[self.name][act_name] = acts[i]
             doc_tag_ids = batch_tag_ids[i]
             if hasattr(doc_tag_ids, "get"):
                 doc_tag_ids = doc_tag_ids.get()

spacy/pipeline/spancat.py

@@ -1,4 +1,5 @@
 from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast
+from typing import Union
 from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
 from thinc.api import Optimizer
 from thinc.types import Ragged, Ints2d, Floats2d, Ints1d
@@ -16,6 +17,9 @@ from ..errors import Errors
 from ..util import registry

+ActivationsT = Dict[str, Union[Floats2d, Ragged]]
+
 spancat_default_config = """
 [model]
 @architectures = "spacy.SpanCategorizer.v1"
@@ -106,6 +110,7 @@ def build_ngram_range_suggester(min_size: int, max_size: int) -> Suggester:
         "model": DEFAULT_SPANCAT_MODEL,
         "suggester": {"@misc": "spacy.ngram_suggester.v1", "sizes": [1, 2, 3]},
         "scorer": {"@scorers": "spacy.spancat_scorer.v1"},
+        "save_activations": False,
     },
     default_score_weights={"spans_sc_f": 1.0, "spans_sc_p": 0.0, "spans_sc_r": 0.0},
 )
@@ -118,6 +123,7 @@ def make_spancat(
     scorer: Optional[Callable],
     threshold: float,
     max_positive: Optional[int],
+    save_activations: bool,
 ) -> "SpanCategorizer":
     """Create a SpanCategorizer component. The span categorizer consists of two
     parts: a suggester function that proposes candidate spans, and a labeller
@@ -138,6 +144,7 @@ def make_spancat(
         0.5.
     max_positive (Optional[int]): Maximum number of labels to consider positive
         per span. Defaults to None, indicating no limit.
+    save_activations (bool): save model activations in Doc when annotating.
     """
     return SpanCategorizer(
         nlp.vocab,
@@ -148,6 +155,7 @@
         max_positive=max_positive,
         name=name,
         scorer=scorer,
+        save_activations=save_activations,
     )
@@ -186,6 +194,7 @@ class SpanCategorizer(TrainablePipe):
         threshold: float = 0.5,
         max_positive: Optional[int] = None,
         scorer: Optional[Callable] = spancat_score,
+        save_activations: bool = False,
     ) -> None:
         """Initialize the span categorizer.
         vocab (Vocab): The shared vocabulary.
@@ -218,6 +227,7 @@
         self.model = model
         self.name = name
         self.scorer = scorer
+        self.save_activations = save_activations

     @property
     def key(self) -> str:
@@ -260,7 +270,7 @@
         """
         return list(self.labels)

-    def predict(self, docs: Iterable[Doc]):
+    def predict(self, docs: Iterable[Doc]) -> ActivationsT:
         """Apply the pipeline's model to a batch of docs, without modifying them.

         docs (Iterable[Doc]): The documents to predict.
@@ -270,7 +280,7 @@
         """
         indices = self.suggester(docs, ops=self.model.ops)
         scores = self.model.predict((docs, indices))  # type: ignore
-        return indices, scores
+        return {"indices": indices, "scores": scores}

     def set_candidates(
         self, docs: Iterable[Doc], *, candidates_key: str = "candidates"
@@ -290,19 +300,29 @@
             for index in candidates.dataXd:
                 doc.spans[candidates_key].append(doc[index[0] : index[1]])

-    def set_annotations(self, docs: Iterable[Doc], indices_scores) -> None:
+    def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT) -> None:
         """Modify a batch of Doc objects, using pre-computed scores.

         docs (Iterable[Doc]): The documents to modify.
-        scores: The scores to set, produced by SpanCategorizer.predict.
+        activations: ActivationsT: The activations, produced by SpanCategorizer.predict.

         DOCS: https://spacy.io/api/spancategorizer#set_annotations
         """
         labels = self.labels
-        indices, scores = indices_scores
+
+        indices = activations["indices"]
+        assert isinstance(indices, Ragged)
+        scores = cast(Floats2d, activations["scores"])
+
         offset = 0
         for i, doc in enumerate(docs):
             indices_i = indices[i].dataXd
+            if self.save_activations:
+                doc.activations[self.name] = {}
+                doc.activations[self.name]["indices"] = indices_i
+                doc.activations[self.name]["scores"] = scores[
+                    offset : offset + indices.lengths[i]
+                ]
             doc.spans[self.key] = self._make_span_group(
                 doc, indices_i, scores[offset : offset + indices.lengths[i]], labels  # type: ignore[arg-type]
             )

spacy/pipeline/tagger.pyx

@@ -1,9 +1,9 @@
 # cython: infer_types=True, profile=True, binding=True
-from typing import Callable, Optional
+from typing import Callable, Dict, Iterable, List, Optional, Union
 import numpy
 import srsly
 from thinc.api import Model, set_dropout_rate, SequenceCategoricalCrossentropy, Config
-from thinc.types import Floats2d
+from thinc.types import Floats2d, Ints1d
 import warnings
 from itertools import islice
@@ -22,6 +22,9 @@ from ..training import validate_examples, validate_get_examples
 from ..util import registry
 from .. import util

+ActivationsT = Dict[str, Union[List[Floats2d], List[Ints1d]]]
+
 # See #9050
 BACKWARD_OVERWRITE = False
@@ -45,7 +48,13 @@ DEFAULT_TAGGER_MODEL = Config().from_str(default_model_config)["model"]
 @Language.factory(
     "tagger",
     assigns=["token.tag"],
-    default_config={"model": DEFAULT_TAGGER_MODEL, "overwrite": False, "scorer": {"@scorers": "spacy.tagger_scorer.v1"}, "neg_prefix": "!"},
+    default_config={
+        "model": DEFAULT_TAGGER_MODEL,
+        "overwrite": False,
+        "scorer": {"@scorers": "spacy.tagger_scorer.v1"},
+        "neg_prefix": "!",
+        "save_activations": False,
+    },
     default_score_weights={"tag_acc": 1.0},
 )
 def make_tagger(
@@ -55,6 +64,7 @@ def make_tagger(
     overwrite: bool,
     scorer: Optional[Callable],
     neg_prefix: str,
+    save_activations: bool,
 ):
     """Construct a part-of-speech tagger component.
@@ -63,7 +73,8 @@ def make_tagger(
     in size, and be normalized as probabilities (all scores between 0 and 1,
     with the rows summing to 1).
     """
-    return Tagger(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, neg_prefix=neg_prefix)
+    return Tagger(nlp.vocab, model, name, overwrite=overwrite, scorer=scorer, neg_prefix=neg_prefix,
+                  save_activations=save_activations)

 def tagger_score(examples, **kwargs):
@@ -89,6 +100,7 @@ class Tagger(TrainablePipe):
         overwrite=BACKWARD_OVERWRITE,
         scorer=tagger_score,
         neg_prefix="!",
+        save_activations: bool = False,
     ):
         """Initialize a part-of-speech tagger.
@@ -98,6 +110,7 @@ class Tagger(TrainablePipe):
             losses during training.
         scorer (Optional[Callable]): The scoring method. Defaults to
             Scorer.score_token_attr for the attribute "tag".
+        save_activations (bool): save model activations in Doc when annotating.

         DOCS: https://spacy.io/api/tagger#init
         """
@@ -108,6 +121,7 @@ class Tagger(TrainablePipe):
         cfg = {"labels": [], "overwrite": overwrite, "neg_prefix": neg_prefix}
         self.cfg = dict(sorted(cfg.items()))
         self.scorer = scorer
+        self.save_activations = save_activations

     @property
     def labels(self):
@@ -126,7 +140,7 @@ class Tagger(TrainablePipe):
         """Data about the labels currently added to the component."""
         return tuple(self.cfg["labels"])

-    def predict(self, docs):
+    def predict(self, docs) -> ActivationsT:
         """Apply the pipeline's model to a batch of docs, without modifying them.

         docs (Iterable[Doc]): The documents to predict.
@@ -139,12 +153,12 @@
             n_labels = len(self.labels)
             guesses = [self.model.ops.alloc((0, n_labels)) for doc in docs]
             assert len(guesses) == len(docs)
-            return guesses
+            return {"probabilities": guesses, "label_ids": guesses}
         scores = self.model.predict(docs)
         assert len(scores) == len(docs), (len(scores), len(docs))
         guesses = self._scores2guesses(scores)
         assert len(guesses) == len(docs)
-        return guesses
+        return {"probabilities": scores, "label_ids": guesses}

     def _scores2guesses(self, scores):
         guesses = []
@@ -155,14 +169,15 @@
             guesses.append(doc_guesses)
         return guesses

-    def set_annotations(self, docs, batch_tag_ids):
+    def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT):
         """Modify a batch of documents, using pre-computed scores.

         docs (Iterable[Doc]): The documents to modify.
-        batch_tag_ids: The IDs to set, produced by Tagger.predict.
+        activations (ActivationsT): The activations used for setting annotations, produced by Tagger.predict.

         DOCS: https://spacy.io/api/tagger#set_annotations
         """
+        batch_tag_ids = activations["label_ids"]
         if isinstance(docs, Doc):
             docs = [docs]
         cdef Doc doc
@@ -170,6 +185,10 @@
         cdef bint overwrite = self.cfg["overwrite"]
         labels = self.labels
         for i, doc in enumerate(docs):
+            if self.save_activations:
+                doc.activations[self.name] = {}
+                for act_name, acts in activations.items():
+                    doc.activations[self.name][act_name] = acts[i]
             doc_tag_ids = batch_tag_ids[i]
             if hasattr(doc_tag_ids, "get"):
                 doc_tag_ids = doc_tag_ids.get()

spacy/pipeline/textcat.py

@@ -1,4 +1,4 @@
-from typing import Iterable, Tuple, Optional, Dict, List, Callable, Any
+from typing import Iterable, Tuple, Optional, Dict, List, Callable, Any, Union
 from thinc.api import get_array_module, Model, Optimizer, set_dropout_rate, Config
 from thinc.types import Floats2d
 import numpy
@@ -14,6 +14,9 @@ from ..util import registry
 from ..vocab import Vocab

+ActivationsT = Dict[str, Floats2d]
+
 single_label_default_config = """
 [model]
 @architectures = "spacy.TextCatEnsemble.v2"
@@ -75,6 +78,7 @@ subword_features = true
         "threshold": 0.5,
         "model": DEFAULT_SINGLE_TEXTCAT_MODEL,
         "scorer": {"@scorers": "spacy.textcat_scorer.v1"},
+        "save_activations": False,
     },
     default_score_weights={
         "cats_score": 1.0,
@@ -96,6 +100,7 @@ def make_textcat(
     model: Model[List[Doc], List[Floats2d]],
     threshold: float,
     scorer: Optional[Callable],
+    save_activations: bool,
 ) -> "TextCategorizer":
     """Create a TextCategorizer component. The text categorizer predicts categories
     over a whole document. It can learn one or more labels, and the labels are considered
@@ -105,8 +110,16 @@ def make_textcat(
         scores for each category.
     threshold (float): Cutoff to consider a prediction "positive".
     scorer (Optional[Callable]): The scoring method.
+    save_activations (bool): save model activations in Doc when annotating.
     """
-    return TextCategorizer(nlp.vocab, model, name, threshold=threshold, scorer=scorer)
+    return TextCategorizer(
+        nlp.vocab,
+        model,
+        name,
+        threshold=threshold,
+        scorer=scorer,
+        save_activations=save_activations,
+    )

 def textcat_score(examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
@@ -137,6 +150,7 @@ class TextCategorizer(TrainablePipe):
         *,
         threshold: float,
         scorer: Optional[Callable] = textcat_score,
+        save_activations: bool = False,
     ) -> None:
         """Initialize a text categorizer for single-label classification.
@@ -157,6 +171,7 @@ class TextCategorizer(TrainablePipe):
         cfg = {"labels": [], "threshold": threshold, "positive_label": None}
         self.cfg = dict(cfg)
         self.scorer = scorer
+        self.save_activations = save_activations

     @property
     def support_missing_values(self):
@@ -181,7 +196,7 @@
         """
         return self.labels  # type: ignore[return-value]

-    def predict(self, docs: Iterable[Doc]):
+    def predict(self, docs: Iterable[Doc]) -> ActivationsT:
         """Apply the pipeline's model to a batch of docs, without modifying them.

         docs (Iterable[Doc]): The documents to predict.
@@ -194,12 +209,12 @@
             tensors = [doc.tensor for doc in docs]
             xp = self.model.ops.xp
             scores = xp.zeros((len(list(docs)), len(self.labels)))
-            return scores
+            return {"probabilities": scores}
         scores = self.model.predict(docs)
         scores = self.model.ops.asarray(scores)
-        return scores
+        return {"probabilities": scores}

-    def set_annotations(self, docs: Iterable[Doc], scores) -> None:
+    def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT) -> None:
         """Modify a batch of Doc objects, using pre-computed scores.

         docs (Iterable[Doc]): The documents to modify.
@@ -207,9 +222,13 @@

         DOCS: https://spacy.io/api/textcategorizer#set_annotations
         """
+        probs = activations["probabilities"]
         for i, doc in enumerate(docs):
+            if self.save_activations:
+                doc.activations[self.name] = {}
+                doc.activations[self.name]["probabilities"] = probs[i]
             for j, label in enumerate(self.labels):
-                doc.cats[label] = float(scores[i, j])
+                doc.cats[label] = float(probs[i, j])

     def update(
         self,

spacy/pipeline/textcat_multilabel.py

@@ -1,4 +1,4 @@
-from typing import Iterable, Optional, Dict, List, Callable, Any
+from typing import Iterable, Optional, Dict, List, Callable, Any, Union
 from thinc.types import Floats2d
 from thinc.api import Model, Config
@@ -75,6 +75,7 @@ subword_features = true
         "threshold": 0.5,
         "model": DEFAULT_MULTI_TEXTCAT_MODEL,
         "scorer": {"@scorers": "spacy.textcat_multilabel_scorer.v1"},
+        "save_activations": False,
     },
     default_score_weights={
         "cats_score": 1.0,
@@ -96,6 +97,7 @@ def make_multilabel_textcat(
     model: Model[List[Doc], List[Floats2d]],
     threshold: float,
     scorer: Optional[Callable],
+    save_activations: bool,
 ) -> "TextCategorizer":
     """Create a TextCategorizer component. The text categorizer predicts categories
     over a whole document. It can learn one or more labels, and the labels are considered
@@ -107,7 +109,12 @@ def make_multilabel_textcat(
     threshold (float): Cutoff to consider a prediction "positive".
     """
     return MultiLabel_TextCategorizer(
-        nlp.vocab, model, name, threshold=threshold, scorer=scorer
+        nlp.vocab,
+        model,
+        name,
+        threshold=threshold,
+        scorer=scorer,
+        save_activations=save_activations,
     )
@@ -139,6 +146,7 @@ class MultiLabel_TextCategorizer(TextCategorizer):
         *,
         threshold: float,
         scorer: Optional[Callable] = textcat_multilabel_score,
+        save_activations: bool = False,
     ) -> None:
         """Initialize a text categorizer for multi-label classification.
@@ -147,6 +155,7 @@ class MultiLabel_TextCategorizer(TextCategorizer):
         name (str): The component instance name, used to add entries to the
             losses during training.
         threshold (float): Cutoff to consider a prediction "positive".
+        save_activations (bool): save model activations in Doc when annotating.

         DOCS: https://spacy.io/api/textcategorizer#init
         """
@@ -157,6 +166,7 @@ class MultiLabel_TextCategorizer(TextCategorizer):
         cfg = {"labels": [], "threshold": threshold}
         self.cfg = dict(cfg)
         self.scorer = scorer
+        self.save_activations = save_activations

     @property
     def support_missing_values(self):

spacy/pipeline/trainable_pipe.pxd

@@ -6,3 +6,4 @@ cdef class TrainablePipe(Pipe):
     cdef public object model
     cdef public object cfg
     cdef public object scorer
+    cdef bint _save_activations

spacy/pipeline/trainable_pipe.pyx

@@ -2,11 +2,12 @@
 from typing import Iterable, Iterator, Optional, Dict, Tuple, Callable
 import srsly
 from thinc.api import set_dropout_rate, Model, Optimizer
+import warnings

 from ..tokens.doc cimport Doc

 from ..training import validate_examples
-from ..errors import Errors
+from ..errors import Errors, Warnings
 from .pipe import Pipe, deserialize_config
 from .. import util
 from ..vocab import Vocab
@@ -342,3 +343,11 @@ cdef class TrainablePipe(Pipe):
         deserialize["model"] = load_model
         util.from_disk(path, deserialize, exclude)
         return self
+
+    @property
+    def save_activations(self):
+        return self._save_activations
+
+    @save_activations.setter
+    def save_activations(self, save_activations: bool):
+        self._save_activations = save_activations

spacy/tests/pipeline/test_edit_tree_lemmatizer.py

@@ -1,3 +1,4 @@
+from typing import cast
 import pickle
 import pytest
 from hypothesis import given
@@ -6,6 +7,7 @@ from spacy import util
 from spacy.lang.en import English
 from spacy.language import Language
 from spacy.pipeline._edit_tree_internals.edit_trees import EditTrees
+from spacy.pipeline.trainable_pipe import TrainablePipe
 from spacy.training import Example
 from spacy.strings import StringStore
 from spacy.util import make_tempdir
@@ -278,3 +280,26 @@
     no_change = trees.add("xyz", "xyz")
     empty = trees.add("", "")
     assert no_change == empty
+
+
+def test_save_activations():
+    nlp = English()
+    lemmatizer = cast(TrainablePipe, nlp.add_pipe("trainable_lemmatizer"))
+    lemmatizer.min_tree_freq = 1
+    train_examples = []
+    for t in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
+    nlp.initialize(get_examples=lambda: train_examples)
+    nO = lemmatizer.model.get_dim("nO")
+
+    doc = nlp("This is a test.")
+    assert "trainable_lemmatizer" not in doc.activations
+
+    lemmatizer.save_activations = True
+    doc = nlp("This is a test.")
+    assert list(doc.activations["trainable_lemmatizer"].keys()) == [
+        "probabilities",
+        "tree_ids",
+    ]
+    assert doc.activations["trainable_lemmatizer"]["probabilities"].shape == (5, nO)
+    assert doc.activations["trainable_lemmatizer"]["tree_ids"].shape == (5,)

spacy/tests/pipeline/test_entity_linker.py

@@ -1,7 +1,8 @@
-from typing import Callable, Iterable, Dict, Any
+from typing import Callable, Iterable, Dict, Any, cast

 import pytest
 from numpy.testing import assert_equal
+from thinc.types import Ragged

 from spacy import registry, util
 from spacy.attrs import ENT_KB_ID
@@ -9,7 +10,7 @@
 from spacy.kb import Candidate, KnowledgeBase, get_candidates
 from spacy.lang.en import English
 from spacy.ml import load_kb
-from spacy.pipeline import EntityLinker
+from spacy.pipeline import EntityLinker, TrainablePipe
 from spacy.pipeline.legacy import EntityLinker_v1
 from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
 from spacy.scorer import Scorer
@@ -1176,3 +1177,66 @@ def test_threshold(meet_threshold: bool, config: Dict[str, Any]):
     assert len(doc.ents) == 1
     assert doc.ents[0].kb_id_ == entity_id if meet_threshold else EntityLinker.NIL
+
+
+def test_save_activations():
+    nlp = English()
+    vector_length = 3
+    assert "Q2146908" not in nlp.vocab.strings
+
+    # Convert the texts to docs to make sure we have doc.ents set for the training examples
+    train_examples = []
+    for text, annotation in TRAIN_DATA:
+        doc = nlp(text)
+        train_examples.append(Example.from_dict(doc, annotation))
+
+    def create_kb(vocab):
+        # create artificial KB - assign same prior weight to the two russ cochran's
+        # Q2146908 (Russ Cochran): American golfer
+        # Q7381115 (Russ Cochran): publisher
+        mykb = KnowledgeBase(vocab, entity_vector_length=vector_length)
+        mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
+        mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7])
+        mykb.add_alias(
+            alias="Russ Cochran",
+            entities=["Q2146908", "Q7381115"],
+            probabilities=[0.5, 0.5],
+        )
+        return mykb
+
+    # Create the Entity Linker component and add it to the pipeline
+    entity_linker = cast(TrainablePipe, nlp.add_pipe("entity_linker", last=True))
+    assert isinstance(entity_linker, EntityLinker)
+    entity_linker.set_kb(create_kb)
+    assert "Q2146908" in entity_linker.vocab.strings
+    assert "Q2146908" in entity_linker.kb.vocab.strings
+
+    # initialize the NEL pipe
+    nlp.initialize(get_examples=lambda: train_examples)
+    nO = entity_linker.model.get_dim("nO")
+
+    nlp.add_pipe("sentencizer", first=True)
+    patterns = [
+        {"label": "PERSON", "pattern": [{"LOWER": "russ"}, {"LOWER": "cochran"}]},
+        {"label": "ORG", "pattern": [{"LOWER": "ec"}, {"LOWER": "comics"}]},
+    ]
+    ruler = nlp.add_pipe("entity_ruler", before="entity_linker")
+    ruler.add_patterns(patterns)
+
+    doc = nlp("Russ Cochran was a publisher")
+    assert "entity_linker" not in doc.activations
+
+    entity_linker.save_activations = True
+    doc = nlp("Russ Cochran was a publisher")
+    assert set(doc.activations["entity_linker"].keys()) == {"ents", "scores"}
+    ents = doc.activations["entity_linker"]["ents"]
+    assert isinstance(ents, Ragged)
+    assert ents.data.shape == (2, 1)
+    assert ents.data.dtype == "uint64"
+    assert ents.lengths.shape == (1,)
+    scores = doc.activations["entity_linker"]["scores"]
+    assert isinstance(scores, Ragged)
+    assert scores.data.shape == (2, 1)
+    assert scores.data.dtype == "float32"
+    assert scores.lengths.shape == (1,)

spacy/tests/pipeline/test_morphologizer.py

@@ -1,3 +1,4 @@
+from typing import cast
 import pytest
 from numpy.testing import assert_equal
@@ -7,6 +8,7 @@ from spacy.lang.en import English
 from spacy.language import Language
 from spacy.tests.util import make_tempdir
 from spacy.morphology import Morphology
+from spacy.pipeline import TrainablePipe
 from spacy.attrs import MORPH
 from spacy.tokens import Doc
@@ -197,3 +199,25 @@
     gold_pos_tags = ["NOUN", "NOUN", "NOUN", "NOUN"]
     assert [str(t.morph) for t in doc] == gold_morphs
     assert [t.pos_ for t in doc] == gold_pos_tags
+
+
+def test_save_activations():
+    nlp = English()
+    morphologizer = cast(TrainablePipe, nlp.add_pipe("morphologizer"))
+    train_examples = []
+    for inst in TRAIN_DATA:
+        train_examples.append(Example.from_dict(nlp.make_doc(inst[0]), inst[1]))
+    nlp.initialize(get_examples=lambda: train_examples)
+
+    doc = nlp("This is a test.")
+    assert "morphologizer" not in doc.activations
+
+    morphologizer.save_activations = True
+    doc = nlp("This is a test.")
+    assert "morphologizer" in doc.activations
+    assert set(doc.activations["morphologizer"].keys()) == {
+        "label_ids",
+        "probabilities",
+    }
+    assert doc.activations["morphologizer"]["probabilities"].shape == (5, 6)
+    assert doc.activations["morphologizer"]["label_ids"].shape == (5,)

@@ -1,3 +1,4 @@
from typing import cast
import pytest
from numpy.testing import assert_equal
from spacy.attrs import SENT_START
@@ -6,6 +7,7 @@ from spacy import util
from spacy.training import Example
from spacy.lang.en import English
from spacy.language import Language
from spacy.pipeline import TrainablePipe
from spacy.tests.util import make_tempdir
@@ -101,3 +103,26 @@ def test_overfitting_IO():
    # test internal pipe labels vs. Language.pipe_labels with hidden labels
    assert nlp.get_pipe("senter").labels == ("I", "S")
    assert "senter" not in nlp.pipe_labels


def test_save_activations():
    # Test if activations are correctly added to Doc when requested.
    nlp = English()
    senter = cast(TrainablePipe, nlp.add_pipe("senter"))

    train_examples = []
    for t in TRAIN_DATA:
        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
    nlp.initialize(get_examples=lambda: train_examples)
    nO = senter.model.get_dim("nO")

    doc = nlp("This is a test.")
    assert "senter" not in doc.activations

    senter.save_activations = True
    doc = nlp("This is a test.")
    assert "senter" in doc.activations
    assert set(doc.activations["senter"].keys()) == {"label_ids", "probabilities"}
    assert doc.activations["senter"]["probabilities"].shape == (5, nO)
    assert doc.activations["senter"]["label_ids"].shape == (5,)

@@ -419,3 +419,23 @@ def test_set_candidates():
    assert len(docs[0].spans["candidates"]) == 9
    assert docs[0].spans["candidates"][0].text == "Just"
    assert docs[0].spans["candidates"][4].text == "Just a"


def test_save_activations():
    # Test if activations are correctly added to Doc when requested.
    nlp = English()
    spancat = nlp.add_pipe("spancat", config={"spans_key": SPAN_KEY})
    train_examples = make_examples(nlp)
    nlp.initialize(get_examples=lambda: train_examples)
    nO = spancat.model.get_dim("nO")
    assert nO == 2
    assert set(spancat.labels) == {"LOC", "PERSON"}

    doc = nlp("This is a test.")
    assert "spancat" not in doc.activations

    spancat.save_activations = True
    doc = nlp("This is a test.")
    assert set(doc.activations["spancat"].keys()) == {"indices", "scores"}
    assert doc.activations["spancat"]["indices"].shape == (12, 2)
    assert doc.activations["spancat"]["scores"].shape == (12, nO)

@@ -1,3 +1,4 @@
from typing import cast
import pytest
from numpy.testing import assert_equal
from spacy.attrs import TAG
@@ -6,6 +7,7 @@ from spacy import util
from spacy.training import Example
from spacy.lang.en import English
from spacy.language import Language
from spacy.pipeline import TrainablePipe
from thinc.api import compounding
from ..util import make_tempdir
@@ -211,6 +213,26 @@ def test_overfitting_IO():
    assert doc3[0].tag_ != "N"


def test_save_activations():
    # Test if activations are correctly added to Doc when requested.
    nlp = English()
    tagger = cast(TrainablePipe, nlp.add_pipe("tagger"))

    train_examples = []
    for t in TRAIN_DATA:
        train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
    nlp.initialize(get_examples=lambda: train_examples)

    doc = nlp("This is a test.")
    assert "tagger" not in doc.activations

    tagger.save_activations = True
    doc = nlp("This is a test.")
    assert "tagger" in doc.activations
    assert set(doc.activations["tagger"].keys()) == {"label_ids", "probabilities"}
    assert doc.activations["tagger"]["probabilities"].shape == (5, len(TAGS))
    assert doc.activations["tagger"]["label_ids"].shape == (5,)


def test_tagger_requires_labels():
    nlp = English()
    nlp.add_pipe("tagger")

@@ -1,3 +1,4 @@
from typing import cast
import random
import numpy.random
@@ -11,7 +12,7 @@ from spacy import util
from spacy.cli.evaluate import print_prf_per_type, print_textcats_auc_per_cat
from spacy.lang.en import English
from spacy.language import Language
from spacy.pipeline import TextCategorizer, TrainablePipe
from spacy.pipeline.textcat import single_label_bow_config
from spacy.pipeline.textcat import single_label_cnn_config
from spacy.pipeline.textcat import single_label_default_config
@@ -285,7 +286,7 @@ def test_issue9904():
    nlp.initialize(get_examples)
    examples = get_examples()
    scores = textcat.predict([eg.predicted for eg in examples])["probabilities"]
    loss = textcat.get_loss(examples, scores)[0]
    loss_double_bs = textcat.get_loss(examples * 2, scores.repeat(2, axis=0))[0]
@@ -871,3 +872,41 @@ def test_textcat_multi_threshold():
    scores = nlp.evaluate(train_examples, scorer_cfg={"threshold": 0})
    assert scores["cats_f_per_type"]["POSITIVE"]["r"] == 1.0


def test_save_activations():
    nlp = English()
    textcat = cast(TrainablePipe, nlp.add_pipe("textcat"))

    train_examples = []
    for text, annotations in TRAIN_DATA_SINGLE_LABEL:
        train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
    nlp.initialize(get_examples=lambda: train_examples)
    nO = textcat.model.get_dim("nO")

    doc = nlp("This is a test.")
    assert "textcat" not in doc.activations

    textcat.save_activations = True
    doc = nlp("This is a test.")
    assert list(doc.activations["textcat"].keys()) == ["probabilities"]
    assert doc.activations["textcat"]["probabilities"].shape == (nO,)


def test_save_activations_multi():
    nlp = English()
    textcat = cast(TrainablePipe, nlp.add_pipe("textcat_multilabel"))

    train_examples = []
    for text, annotations in TRAIN_DATA_MULTI_LABEL:
        train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
    nlp.initialize(get_examples=lambda: train_examples)
    nO = textcat.model.get_dim("nO")

    doc = nlp("This is a test.")
    assert "textcat_multilabel" not in doc.activations

    textcat.save_activations = True
    doc = nlp("This is a test.")
    assert list(doc.activations["textcat_multilabel"].keys()) == ["probabilities"]
    assert doc.activations["textcat_multilabel"]["probabilities"].shape == (nO,)

@@ -50,6 +50,8 @@ cdef class Doc:
    cdef public float sentiment
    cdef public dict activations
    cdef public dict user_hooks
    cdef public dict user_token_hooks
    cdef public dict user_span_hooks

@@ -1,7 +1,7 @@
from typing import Callable, Protocol, Iterable, Iterator, Optional
from typing import Union, Tuple, List, Dict, Any, overload
from cymem.cymem import Pool
from thinc.types import ArrayXd, Floats1d, Floats2d, Ints2d, Ragged
from .span import Span
from .token import Token
from .span_groups import SpanGroups
@@ -22,6 +22,7 @@ class Doc:
    max_length: int
    length: int
    sentiment: float
    activations: Dict[str, Dict[str, Union[ArrayXd, Ragged]]]
    cats: Dict[str, float]
    user_hooks: Dict[str, Callable[..., Any]]
    user_token_hooks: Dict[str, Callable[..., Any]]

@@ -245,6 +245,7 @@ cdef class Doc:
        self.length = 0
        self.sentiment = 0.0
        self.cats = {}
        self.activations = {}
        self.user_hooks = {}
        self.user_token_hooks = {}
        self.user_span_hooks = {}

@@ -751,22 +751,23 @@ The L2 norm of the document's vector representation.

## Attributes {#attributes}

| Name | Description |
| ------------------------------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------ |
| `text` | A string representation of the document text. ~~str~~ |
| `text_with_ws` | An alias of `Doc.text`, provided for duck-type compatibility with `Span` and `Token`. ~~str~~ |
| `mem` | The document's local memory heap, for all C data it owns. ~~cymem.Pool~~ |
| `vocab` | The store of lexical types. ~~Vocab~~ |
| `tensor` <Tag variant="new">2</Tag> | Container for dense vector representations. ~~numpy.ndarray~~ |
| `user_data` | A generic storage area, for user custom data. ~~Dict[str, Any]~~ |
| `lang` <Tag variant="new">2.1</Tag> | Language of the document's vocabulary. ~~int~~ |
| `lang_` <Tag variant="new">2.1</Tag> | Language of the document's vocabulary. ~~str~~ |
| `sentiment` | The document's positivity/negativity score, if available. ~~float~~ |
| `user_hooks` | A dictionary that allows customization of the `Doc`'s properties. ~~Dict[str, Callable]~~ |
| `user_token_hooks` | A dictionary that allows customization of properties of `Token` children. ~~Dict[str, Callable]~~ |
| `user_span_hooks` | A dictionary that allows customization of properties of `Span` children. ~~Dict[str, Callable]~~ |
| `has_unknown_spaces` | Whether the document was constructed without known spacing between tokens (typically when created from gold tokenization). ~~bool~~ |
| `_` | User space for adding custom [attribute extensions](/usage/processing-pipelines#custom-components-attributes). ~~Underscore~~ |
| `activations` <Tag variant="new">4.0</Tag> | A dictionary of activations per trainable pipe (available when the `save_activations` option of a pipe is enabled). ~~Dict[str, Optional[Any]]~~ |
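
For example, downstream code can iterate over whatever each pipe chose to save. A minimal sketch, assuming a hypothetical trained pipeline `my_pipeline` in which at least one component was configured with `save_activations` enabled:

```python
import spacy

nlp = spacy.load("my_pipeline")  # hypothetical pipeline; the name is an assumption
doc = nlp("This is a test.")
# Doc.activations maps component name -> activation name -> array-like value
for pipe_name, acts in doc.activations.items():
    for act_name, value in acts.items():
        print(pipe_name, act_name, getattr(value, "shape", None))
```
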
## Serialization fields {#serialization-fields}

@@ -44,14 +44,15 @@ architectures and their arguments and hyperparameters.
> nlp.add_pipe("trainable_lemmatizer", config=config, name="lemmatizer")
> ```

| Setting | Description |
| ----------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `model` | A model instance that predicts the edit tree probabilities. The output vectors should match the number of edit trees in size, and be normalized as probabilities (all scores between 0 and 1, with the rows summing to `1`). Defaults to [Tagger](/api/architectures#Tagger). ~~Model[List[Doc], List[Floats2d]]~~ |
| `backoff` | ~~Token~~ attribute to use when no applicable edit tree is found. Defaults to `orth`. ~~str~~ |
| `min_tree_freq` | Minimum frequency of an edit tree in the training set to be used. Defaults to `3`. ~~int~~ |
| `overwrite` | Whether existing annotation is overwritten. Defaults to `False`. ~~bool~~ |
| `top_k` | The number of most probable edit trees to try before resorting to `backoff`. Defaults to `1`. ~~int~~ |
| `scorer` | The scoring method. Defaults to [`Scorer.score_token_attr`](/api/scorer#score_token_attr) for the attribute `"lemma"`. ~~Optional[Callable]~~ |
| `save_activations` <Tag variant="new">4.0</Tag> | Save activations in `Doc` when annotating. Saved activations are `"probabilities"` and `"tree_ids"`. ~~bool~~ |
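
The option can also be toggled on a loaded pipe at runtime; the saved arrays then appear in `Doc.activations` under the component's name. A sketch, assuming a hypothetical trained pipeline `my_pipeline` that contains a trainable lemmatizer named `lemmatizer`:

```python
import spacy

nlp = spacy.load("my_pipeline")  # hypothetical pipeline with a trainable lemmatizer
nlp.get_pipe("lemmatizer").save_activations = True
doc = nlp("She was reading the paper.")
acts = doc.activations["lemmatizer"]
tree_ids = acts["tree_ids"]            # one predicted edit tree id per token
probabilities = acts["probabilities"]  # per-token scores over the edit trees
```
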
```python
%%GITHUB_SPACY/spacy/pipeline/edit_tree_lemmatizer.py %%GITHUB_SPACY/spacy/pipeline/edit_tree_lemmatizer.py

@@ -52,19 +52,20 @@ architectures and their arguments and hyperparameters.
> nlp.add_pipe("entity_linker", config=config)
> ```

| Setting | Description |
| ----------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `labels_discard` | NER labels that will automatically get a "NIL" prediction. Defaults to `[]`. ~~Iterable[str]~~ |
| `n_sents` | The number of neighbouring sentences to take into account. Defaults to 0. ~~int~~ |
| `incl_prior` | Whether or not to include prior probabilities from the KB in the model. Defaults to `True`. ~~bool~~ |
| `incl_context` | Whether or not to include the local context in the model. Defaults to `True`. ~~bool~~ |
| `model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. Defaults to [EntityLinker](/api/architectures#EntityLinker). ~~Model~~ |
| `entity_vector_length` | Size of encoding vectors in the KB. Defaults to `64`. ~~int~~ |
| `use_gold_ents` | Whether to copy entities from the gold docs or not. Defaults to `True`. If `False`, entities must be set in the training data or by an annotating component in the pipeline. ~~bool~~ |
| `get_candidates` | Function that generates plausible candidates for a given `Span` object. Defaults to [CandidateGenerator](/api/architectures#CandidateGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Span], Iterable[Candidate]]~~ |
| `overwrite` <Tag variant="new">3.2</Tag> | Whether existing annotation is overwritten. Defaults to `True`. ~~bool~~ |
| `scorer` <Tag variant="new">3.2</Tag> | The scoring method. Defaults to [`Scorer.score_links`](/api/scorer#score_links). ~~Optional[Callable]~~ |
| `save_activations` <Tag variant="new">4.0</Tag> | Save activations in `Doc` when annotating. Saved activations are `"ents"` and `"scores"`. ~~bool~~ |
| `threshold` <Tag variant="new">3.4</Tag> | Confidence threshold for entity predictions. The default of `None` implies that all predictions are accepted, otherwise those with a score beneath the threshold are discarded. If there are no predictions with scores above the threshold, the linked entity is `NIL`. ~~Optional[float]~~ |
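
Unlike the token-level pipes, the entity linker saves `Ragged` values, since candidate scores are grouped per mention rather than per token. A sketch under the same assumptions as above (a hypothetical trained pipeline `my_pipeline` with an `entity_linker` and a populated KB):

```python
import spacy
from thinc.types import Ragged

nlp = spacy.load("my_pipeline")  # hypothetical pipeline with a trained entity_linker
nlp.get_pipe("entity_linker").save_activations = True
doc = nlp("Russ Cochran was a publisher")
scores = doc.activations["entity_linker"]["scores"]
assert isinstance(scores, Ragged)
print(scores.data.shape)  # flat array of candidate scores
print(scores.lengths)     # grouping of the flat array
```
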
```python
%%GITHUB_SPACY/spacy/pipeline/entity_linker.py

@@ -42,12 +42,13 @@ architectures and their arguments and hyperparameters.
> nlp.add_pipe("morphologizer", config=config)
> ```

| Setting | Description |
| ----------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `model` | The model to use. Defaults to [Tagger](/api/architectures#Tagger). ~~Model[List[Doc], List[Floats2d]]~~ |
| `overwrite` <Tag variant="new">3.2</Tag> | Whether the values of existing features are overwritten. Defaults to `True`. ~~bool~~ |
| `extend` <Tag variant="new">3.2</Tag> | Whether existing feature types (whose values may or may not be overwritten depending on `overwrite`) are preserved. Defaults to `False`. ~~bool~~ |
| `scorer` <Tag variant="new">3.2</Tag> | The scoring method. Defaults to [`Scorer.score_token_attr`](/api/scorer#score_token_attr) for the attributes `"pos"` and `"morph"` and [`Scorer.score_token_attr_per_feat`](/api/scorer#score_token_attr_per_feat) for the attribute `"morph"`. ~~Optional[Callable]~~ |
| `save_activations` <Tag variant="new">4.0</Tag> | Save activations in `Doc` when annotating. Saved activations are `"probabilities"` and `"label_ids"`. ~~bool~~ |
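
A sketch of reading these back (same hypothetical `my_pipeline` assumption; the shapes mirror what the test suite checks):

```python
import spacy

nlp = spacy.load("my_pipeline")  # hypothetical pipeline with a trained morphologizer
morphologizer = nlp.get_pipe("morphologizer")
morphologizer.save_activations = True
doc = nlp("This is a test.")
acts = doc.activations["morphologizer"]
print(acts["probabilities"].shape)  # (n_tokens, n_labels)
print(acts["label_ids"].shape)      # (n_tokens,)
```
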
```python
%%GITHUB_SPACY/spacy/pipeline/morphologizer.pyx

@@ -399,8 +400,8 @@ coarse-grained POS as the feature `POS`.
> assert "Mood=Ind|POS=VERB|Tense=Past|VerbForm=Fin" in morphologizer.labels
> ```

| Name | Description |
| ----------- | --------------------------------------------------------- |
| **RETURNS** | The labels added to the component. ~~Iterable[str, ...]~~ |

## Morphologizer.label_data {#label_data tag="property" new="3"}

@@ -39,11 +39,12 @@ architectures and their arguments and hyperparameters.
> nlp.add_pipe("senter", config=config)
> ```

| Setting | Description |
| ----------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. Defaults to [Tagger](/api/architectures#Tagger). ~~Model[List[Doc], List[Floats2d]]~~ |
| `overwrite` <Tag variant="new">3.2</Tag> | Whether existing annotation is overwritten. Defaults to `False`. ~~bool~~ |
| `scorer` <Tag variant="new">3.2</Tag> | The scoring method. Defaults to [`Scorer.score_spans`](/api/scorer#score_spans) for the attribute `"sents"`. ~~Optional[Callable]~~ |
| `save_activations` <Tag variant="new">4.0</Tag> | Save activations in `Doc` when annotating. Saved activations are `"probabilities"` and `"label_ids"`. ~~bool~~ |
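
The senter makes a two-way sentence-start decision per token (its labels are `"I"` and `"S"`), so the saved probabilities have one row per token. A sketch (same hypothetical pipeline assumption as above):

```python
import spacy

nlp = spacy.load("my_pipeline")  # hypothetical pipeline with a trained senter
nlp.get_pipe("senter").save_activations = True
doc = nlp("This is a test. This is another test.")
probs = doc.activations["senter"]["probabilities"]
print(probs.shape)  # (n_tokens, 2): per-token sentence-start probabilities
```
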
```python
%%GITHUB_SPACY/spacy/pipeline/senter.pyx

@@ -52,14 +52,15 @@ architectures and their arguments and hyperparameters.
> nlp.add_pipe("spancat", config=config)
> ```

| Setting | Description |
| ----------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `suggester` | A function that [suggests spans](#suggesters). Spans are returned as a ragged array with two integer columns, for the start and end positions. Defaults to [`ngram_suggester`](#ngram_suggester). ~~Callable[[Iterable[Doc], Optional[Ops]], Ragged]~~ |
| `model` | A model instance that is given a list of documents and `(start, end)` indices representing candidate span offsets. The model predicts a probability for each category for each span. Defaults to [SpanCategorizer](/api/architectures#SpanCategorizer). ~~Model[Tuple[List[Doc], Ragged], Floats2d]~~ |
| `spans_key` | Key of the [`Doc.spans`](/api/doc#spans) dict to save the spans under. During initialization and training, the component will look for spans on the reference document under the same key. Defaults to `"sc"`. ~~str~~ |
| `threshold` | Minimum probability to consider a prediction positive. Spans with a positive prediction will be saved on the Doc. Defaults to `0.5`. ~~float~~ |
| `max_positive` | Maximum number of labels to consider positive per span. Defaults to `None`, indicating no limit. ~~Optional[int]~~ |
| `scorer` | The scoring method. Defaults to [`Scorer.score_spans`](/api/scorer#score_spans) for `Doc.spans[spans_key]` with overlapping spans allowed. ~~Optional[Callable]~~ |
| `save_activations` <Tag variant="new">4.0</Tag> | Save activations in `Doc` when annotating. Saved activations are `"indices"` and `"scores"`. ~~bool~~ |
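
Here the saved `"indices"` are the candidate span offsets produced by the suggester, and `"scores"` holds one row of label probabilities per candidate. A sketch (hypothetical trained pipeline, as above):

```python
import spacy

nlp = spacy.load("my_pipeline")  # hypothetical pipeline with a trained spancat
nlp.get_pipe("spancat").save_activations = True
doc = nlp("This is a test.")
acts = doc.activations["spancat"]
print(acts["indices"].shape)  # (n_candidate_spans, 2): start/end token offsets
print(acts["scores"].shape)   # (n_candidate_spans, n_labels)
```
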
```python
%%GITHUB_SPACY/spacy/pipeline/spancat.py %%GITHUB_SPACY/spacy/pipeline/spancat.py

@@ -40,12 +40,13 @@ architectures and their arguments and hyperparameters.
> nlp.add_pipe("tagger", config=config)
> ```

| Setting | Description |
| ----------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `model` | A model instance that predicts the tag probabilities. The output vectors should match the number of tags in size, and be normalized as probabilities (all scores between 0 and 1, with the rows summing to `1`). Defaults to [Tagger](/api/architectures#Tagger). ~~Model[List[Doc], List[Floats2d]]~~ |
| `overwrite` <Tag variant="new">3.2</Tag> | Whether existing annotation is overwritten. Defaults to `False`. ~~bool~~ |
| `scorer` <Tag variant="new">3.2</Tag> | The scoring method. Defaults to [`Scorer.score_token_attr`](/api/scorer#score_token_attr) for the attribute `"tag"`. ~~Optional[Callable]~~ |
| `neg_prefix` <Tag variant="new">3.2.1</Tag> | The prefix used to specify incorrect tags while training. The tagger will learn not to predict exactly this tag. Defaults to `!`. ~~str~~ |
| `save_activations` <Tag variant="new">4.0</Tag> | Save activations in `Doc` when annotating. Saved activations are `"probabilities"` and `"label_ids"`. ~~bool~~ |
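
The saved `"label_ids"` should correspond to the per-token argmax of `"probabilities"`; treating them as indices into the tag inventory is an assumption in this sketch, as is the pipeline path:

```python
import spacy

nlp = spacy.load("my_pipeline")  # hypothetical pipeline with a trained tagger
tagger = nlp.get_pipe("tagger")
tagger.save_activations = True
doc = nlp("This is a test.")
probs = doc.activations["tagger"]["probabilities"]  # (n_tokens, n_tags)
ids = doc.activations["tagger"]["label_ids"]        # (n_tokens,)
print([tagger.labels[int(i)] for i in ids])         # should match [t.tag_ for t in doc]
```
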
```python
%%GITHUB_SPACY/spacy/pipeline/tagger.pyx %%GITHUB_SPACY/spacy/pipeline/tagger.pyx

@@ -117,14 +117,15 @@ Create a new pipeline instance. In your application, you would normally use a
shortcut for this and instantiate the component using its string name and
[`nlp.add_pipe`](/api/language#create_pipe).

| Name | Description |
| ----------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------- |
| `vocab` | The shared vocabulary. ~~Vocab~~ |
| `model` | The Thinc [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. ~~Model[List[Doc], List[Floats2d]]~~ |
| `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ |
| _keyword-only_ | |
| `threshold` | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~ |
| `scorer` | The scoring method. Defaults to [`Scorer.score_cats`](/api/scorer#score_cats) for the attribute `"cats"`. ~~Optional[Callable]~~ |
| `save_activations` <Tag variant="new">4.0</Tag> | Save activations in `Doc` when annotating. The only supported activation is `"probabilities"`. ~~bool~~ |
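
Since the text categorizer predicts one score vector per document, the saved activation is a single vector over the component's labels. A sketch (hypothetical trained pipeline, as above):

```python
import spacy

nlp = spacy.load("my_pipeline")  # hypothetical pipeline with a trained textcat
textcat = nlp.get_pipe("textcat")
textcat.save_activations = True
doc = nlp("This is a test.")
probs = doc.activations["textcat"]["probabilities"]  # shape (n_labels,)
print(dict(zip(textcat.labels, probs)))              # roughly mirrors doc.cats
```
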
## TextCategorizer.\_\_call\_\_ {#call tag="method"}