diff --git a/spacy/pipeline/morphologizer.pyx b/spacy/pipeline/morphologizer.pyx index 73b19dd0d..2d60a841a 100644 --- a/spacy/pipeline/morphologizer.pyx +++ b/spacy/pipeline/morphologizer.pyx @@ -14,7 +14,7 @@ from ..symbols import POS from ..language import Language from ..errors import Errors from .pipe import deserialize_config -from .tagger import Tagger +from .tagger import ActivationsT, Tagger from .. import util from ..scorer import Scorer from ..training import validate_examples, validate_get_examples @@ -229,11 +229,11 @@ class Morphologizer(Tagger): assert len(label_sample) > 0, Errors.E923.format(name=self.name) self.model.initialize(X=doc_sample, Y=label_sample) - def set_annotations(self, docs: Iterable[Doc], activations: Dict[str, Union[Floats2d, Ints1d]]): + def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT): """Modify a batch of documents, using pre-computed scores. docs (Iterable[Doc]): The documents to modify. - activations (Dict): The activations used for setting annotations, produced by Morphologizer.predict. + activations (ActivationsT): The activations used for setting annotations, produced by Morphologizer.predict. DOCS: https://spacy.io/api/morphologizer#set_annotations """ diff --git a/spacy/pipeline/senter.pyx b/spacy/pipeline/senter.pyx index b8916fc69..ff2b9f384 100644 --- a/spacy/pipeline/senter.pyx +++ b/spacy/pipeline/senter.pyx @@ -8,7 +8,7 @@ from thinc.types import Floats2d, Ints1d from ..tokens.doc cimport Doc -from .tagger import Tagger +from .tagger import ActivationsT, Tagger from ..language import Language from ..errors import Errors from ..scorer import Scorer @@ -121,11 +121,11 @@ class SentenceRecognizer(Tagger): def label_data(self): return None - def set_annotations(self, docs: Iterable[Doc], activations: Dict[str, Union[List[Floats2d], List[Ints1d]]]): + def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT): """Modify a batch of documents, using pre-computed scores. 
docs (Iterable[Doc]): The documents to modify. - activations (Dict): The activations used for setting annotations, produced by SentenceRecognizer.predict. + activations (ActivationsT): The activations used for setting annotations, produced by SentenceRecognizer.predict. DOCS: https://spacy.io/api/sentencerecognizer#set_annotations """ diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py index 042c68e36..ae00690df 100644 --- a/spacy/pipeline/spancat.py +++ b/spacy/pipeline/spancat.py @@ -17,6 +17,9 @@ from ..errors import Errors from ..util import registry +ActivationsT = Dict[str, Union[Floats2d, Ragged]] + + spancat_default_config = """ [model] @architectures = "spacy.SpanCategorizer.v1" @@ -267,7 +270,7 @@ class SpanCategorizer(TrainablePipe): """ return list(self.labels) - def predict(self, docs: Iterable[Doc]): + def predict(self, docs: Iterable[Doc]) -> ActivationsT: """Apply the pipeline's model to a batch of docs, without modifying them. docs (Iterable[Doc]): The documents to predict. @@ -297,13 +300,11 @@ class SpanCategorizer(TrainablePipe): for index in candidates.dataXd: doc.spans[candidates_key].append(doc[index[0] : index[1]]) - def set_annotations( - self, docs: Iterable[Doc], activations: Dict[str, Union[Floats2d, Ragged]] - ) -> None: + def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT) -> None: """Modify a batch of Doc objects, using pre-computed scores. docs (Iterable[Doc]): The documents to modify. - scores: The scores to set, produced by SpanCategorizer.predict. + activations (ActivationsT): The activations, produced by SpanCategorizer.predict. DOCS: https://spacy.io/api/spancategorizer#set_annotations """ diff --git a/spacy/pipeline/tagger.pyx b/spacy/pipeline/tagger.pyx index a1fe3be72..12bdf209d 100644 --- a/spacy/pipeline/tagger.pyx +++ b/spacy/pipeline/tagger.pyx @@ -22,6 +22,9 @@ from ..training import validate_examples, validate_get_examples from ..util import registry from ..
import util + +ActivationsT = Dict[str, Union[List[Floats2d], List[Ints1d]]] + # See #9050 BACKWARD_OVERWRITE = False @@ -137,7 +140,7 @@ class Tagger(TrainablePipe): """Data about the labels currently added to the component.""" return tuple(self.cfg["labels"]) - def predict(self, docs): + def predict(self, docs) -> ActivationsT: """Apply the pipeline's model to a batch of docs, without modifying them. docs (Iterable[Doc]): The documents to predict. @@ -166,11 +169,11 @@ class Tagger(TrainablePipe): guesses.append(doc_guesses) return guesses - def set_annotations(self, docs: Iterable[Doc], activations: Dict[str, Union[List[Floats2d], List[Ints1d]]]): + def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT): """Modify a batch of documents, using pre-computed scores. docs (Iterable[Doc]): The documents to modify. - activations (Dict): The activations used for setting annotations, produced by Tagger.predict. + activations (ActivationsT): The activations used for setting annotations, produced by Tagger.predict. DOCS: https://spacy.io/api/tagger#set_annotations """ diff --git a/spacy/pipeline/textcat.py b/spacy/pipeline/textcat.py index acb8be406..761c42f4f 100644 --- a/spacy/pipeline/textcat.py +++ b/spacy/pipeline/textcat.py @@ -14,6 +14,9 @@ from ..util import registry from ..vocab import Vocab +ActivationsT = Dict[str, Floats2d] + + single_label_default_config = """ [model] @architectures = "spacy.TextCatEnsemble.v2" @@ -193,7 +196,7 @@ class TextCategorizer(TrainablePipe): """ return self.labels # type: ignore[return-value] - def predict(self, docs: Iterable[Doc]): + def predict(self, docs: Iterable[Doc]) -> ActivationsT: """Apply the pipeline's model to a batch of docs, without modifying them. docs (Iterable[Doc]): The documents to predict. 
@@ -211,9 +214,7 @@ class TextCategorizer(TrainablePipe): scores = self.model.ops.asarray(scores) return {"probs": scores} - def set_annotations( - self, docs: Iterable[Doc], activations: Dict[str, Floats2d] - ) -> None: + def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT) -> None: """Modify a batch of Doc objects, using pre-computed scores. docs (Iterable[Doc]): The documents to modify.