Add type annotations for activations in predict/set_annotations

This commit is contained in:
Daniël de Kok 2022-08-30 10:07:33 +02:00
parent 8c2652d788
commit 8f84e6ea8a
5 changed files with 23 additions and 18 deletions

View File

@ -14,7 +14,7 @@ from ..symbols import POS
from ..language import Language
from ..errors import Errors
from .pipe import deserialize_config
from .tagger import Tagger
from .tagger import ActivationsT, Tagger
from .. import util
from ..scorer import Scorer
from ..training import validate_examples, validate_get_examples
@ -229,11 +229,11 @@ class Morphologizer(Tagger):
assert len(label_sample) > 0, Errors.E923.format(name=self.name)
self.model.initialize(X=doc_sample, Y=label_sample)
def set_annotations(self, docs: Iterable[Doc], activations: Dict[str, Union[Floats2d, Ints1d]]):
def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT):
"""Modify a batch of documents, using pre-computed scores.
docs (Iterable[Doc]): The documents to modify.
activations (Dict): The activations used for setting annotations, produced by Morphologizer.predict.
activations (ActivationsT): The activations used for setting annotations, produced by Morphologizer.predict.
DOCS: https://spacy.io/api/morphologizer#set_annotations
"""

View File

@ -8,7 +8,7 @@ from thinc.types import Floats2d, Ints1d
from ..tokens.doc cimport Doc
from .tagger import Tagger
from .tagger import ActivationsT, Tagger
from ..language import Language
from ..errors import Errors
from ..scorer import Scorer
@ -121,11 +121,11 @@ class SentenceRecognizer(Tagger):
def label_data(self):
return None
def set_annotations(self, docs: Iterable[Doc], activations: Dict[str, Union[List[Floats2d], List[Ints1d]]]):
def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT):
"""Modify a batch of documents, using pre-computed scores.
docs (Iterable[Doc]): The documents to modify.
activations (Dict): The activations used for setting annotations, produced by SentenceRecognizer.predict.
activations (ActivationsT): The activations used for setting annotations, produced by SentenceRecognizer.predict.
DOCS: https://spacy.io/api/sentencerecognizer#set_annotations
"""

View File

@ -17,6 +17,9 @@ from ..errors import Errors
from ..util import registry
ActivationsT = Dict[str, Union[Floats2d, Ragged]]
spancat_default_config = """
[model]
@architectures = "spacy.SpanCategorizer.v1"
@ -267,7 +270,7 @@ class SpanCategorizer(TrainablePipe):
"""
return list(self.labels)
def predict(self, docs: Iterable[Doc]):
def predict(self, docs: Iterable[Doc]) -> ActivationsT:
"""Apply the pipeline's model to a batch of docs, without modifying them.
docs (Iterable[Doc]): The documents to predict.
@ -297,13 +300,11 @@ class SpanCategorizer(TrainablePipe):
for index in candidates.dataXd:
doc.spans[candidates_key].append(doc[index[0] : index[1]])
def set_annotations(
self, docs: Iterable[Doc], activations: Dict[str, Union[Floats2d, Ragged]]
) -> None:
def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT) -> None:
"""Modify a batch of Doc objects, using pre-computed scores.
docs (Iterable[Doc]): The documents to modify.
scores: The scores to set, produced by SpanCategorizer.predict.
activations (ActivationsT): The activations, produced by SpanCategorizer.predict.
DOCS: https://spacy.io/api/spancategorizer#set_annotations
"""

View File

@ -22,6 +22,9 @@ from ..training import validate_examples, validate_get_examples
from ..util import registry
from .. import util
ActivationsT = Dict[str, Union[List[Floats2d], List[Ints1d]]]
# See #9050
BACKWARD_OVERWRITE = False
@ -137,7 +140,7 @@ class Tagger(TrainablePipe):
"""Data about the labels currently added to the component."""
return tuple(self.cfg["labels"])
def predict(self, docs):
def predict(self, docs) -> ActivationsT:
"""Apply the pipeline's model to a batch of docs, without modifying them.
docs (Iterable[Doc]): The documents to predict.
@ -166,11 +169,11 @@ class Tagger(TrainablePipe):
guesses.append(doc_guesses)
return guesses
def set_annotations(self, docs: Iterable[Doc], activations: Dict[str, Union[List[Floats2d], List[Ints1d]]]):
def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT):
"""Modify a batch of documents, using pre-computed scores.
docs (Iterable[Doc]): The documents to modify.
activations (Dict): The activations used for setting annotations, produced by Tagger.predict.
activations (ActivationsT): The activations used for setting annotations, produced by Tagger.predict.
DOCS: https://spacy.io/api/tagger#set_annotations
"""

View File

@ -14,6 +14,9 @@ from ..util import registry
from ..vocab import Vocab
ActivationsT = Dict[str, Floats2d]
single_label_default_config = """
[model]
@architectures = "spacy.TextCatEnsemble.v2"
@ -193,7 +196,7 @@ class TextCategorizer(TrainablePipe):
"""
return self.labels # type: ignore[return-value]
def predict(self, docs: Iterable[Doc]):
def predict(self, docs: Iterable[Doc]) -> ActivationsT:
"""Apply the pipeline's model to a batch of docs, without modifying them.
docs (Iterable[Doc]): The documents to predict.
@ -211,9 +214,7 @@ class TextCategorizer(TrainablePipe):
scores = self.model.ops.asarray(scores)
return {"probs": scores}
def set_annotations(
self, docs: Iterable[Doc], activations: Dict[str, Floats2d]
) -> None:
def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT) -> None:
"""Modify a batch of Doc objects, using pre-computed scores.
docs (Iterable[Doc]): The documents to modify.