Add type annotations for activations in predict/set_annotations

This commit is contained in:
Daniël de Kok 2022-08-30 10:07:33 +02:00
parent 8c2652d788
commit 8f84e6ea8a
5 changed files with 23 additions and 18 deletions

View File

@ -14,7 +14,7 @@ from ..symbols import POS
from ..language import Language from ..language import Language
from ..errors import Errors from ..errors import Errors
from .pipe import deserialize_config from .pipe import deserialize_config
from .tagger import Tagger from .tagger import ActivationsT, Tagger
from .. import util from .. import util
from ..scorer import Scorer from ..scorer import Scorer
from ..training import validate_examples, validate_get_examples from ..training import validate_examples, validate_get_examples
@ -229,11 +229,11 @@ class Morphologizer(Tagger):
assert len(label_sample) > 0, Errors.E923.format(name=self.name) assert len(label_sample) > 0, Errors.E923.format(name=self.name)
self.model.initialize(X=doc_sample, Y=label_sample) self.model.initialize(X=doc_sample, Y=label_sample)
def set_annotations(self, docs: Iterable[Doc], activations: Dict[str, Union[Floats2d, Ints1d]]): def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT):
"""Modify a batch of documents, using pre-computed scores. """Modify a batch of documents, using pre-computed scores.
docs (Iterable[Doc]): The documents to modify. docs (Iterable[Doc]): The documents to modify.
activations (Dict): The activations used for setting annotations, produced by Morphologizer.predict. activations (ActivationsT): The activations used for setting annotations, produced by Morphologizer.predict.
DOCS: https://spacy.io/api/morphologizer#set_annotations DOCS: https://spacy.io/api/morphologizer#set_annotations
""" """

View File

@ -8,7 +8,7 @@ from thinc.types import Floats2d, Ints1d
from ..tokens.doc cimport Doc from ..tokens.doc cimport Doc
from .tagger import Tagger from .tagger import ActivationsT, Tagger
from ..language import Language from ..language import Language
from ..errors import Errors from ..errors import Errors
from ..scorer import Scorer from ..scorer import Scorer
@ -121,11 +121,11 @@ class SentenceRecognizer(Tagger):
def label_data(self): def label_data(self):
return None return None
def set_annotations(self, docs: Iterable[Doc], activations: Dict[str, Union[List[Floats2d], List[Ints1d]]]): def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT):
"""Modify a batch of documents, using pre-computed scores. """Modify a batch of documents, using pre-computed scores.
docs (Iterable[Doc]): The documents to modify. docs (Iterable[Doc]): The documents to modify.
activations (Dict): The activations used for setting annotations, produced by SentenceRecognizer.predict. activations (ActivationsT): The activations used for setting annotations, produced by SentenceRecognizer.predict.
DOCS: https://spacy.io/api/sentencerecognizer#set_annotations DOCS: https://spacy.io/api/sentencerecognizer#set_annotations
""" """

View File

@ -17,6 +17,9 @@ from ..errors import Errors
from ..util import registry from ..util import registry
ActivationsT = Dict[str, Union[Floats2d, Ragged]]
spancat_default_config = """ spancat_default_config = """
[model] [model]
@architectures = "spacy.SpanCategorizer.v1" @architectures = "spacy.SpanCategorizer.v1"
@ -267,7 +270,7 @@ class SpanCategorizer(TrainablePipe):
""" """
return list(self.labels) return list(self.labels)
def predict(self, docs: Iterable[Doc]): def predict(self, docs: Iterable[Doc]) -> ActivationsT:
"""Apply the pipeline's model to a batch of docs, without modifying them. """Apply the pipeline's model to a batch of docs, without modifying them.
docs (Iterable[Doc]): The documents to predict. docs (Iterable[Doc]): The documents to predict.
@ -297,13 +300,11 @@ class SpanCategorizer(TrainablePipe):
for index in candidates.dataXd: for index in candidates.dataXd:
doc.spans[candidates_key].append(doc[index[0] : index[1]]) doc.spans[candidates_key].append(doc[index[0] : index[1]])
def set_annotations( def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT) -> None:
self, docs: Iterable[Doc], activations: Dict[str, Union[Floats2d, Ragged]]
) -> None:
"""Modify a batch of Doc objects, using pre-computed scores. """Modify a batch of Doc objects, using pre-computed scores.
docs (Iterable[Doc]): The documents to modify. docs (Iterable[Doc]): The documents to modify.
scores: The scores to set, produced by SpanCategorizer.predict. activations (ActivationsT): The activations, produced by SpanCategorizer.predict.
DOCS: https://spacy.io/api/spancategorizer#set_annotations DOCS: https://spacy.io/api/spancategorizer#set_annotations
""" """

View File

@ -22,6 +22,9 @@ from ..training import validate_examples, validate_get_examples
from ..util import registry from ..util import registry
from .. import util from .. import util
ActivationsT = Dict[str, Union[List[Floats2d], List[Ints1d]]]
# See #9050 # See #9050
BACKWARD_OVERWRITE = False BACKWARD_OVERWRITE = False
@ -137,7 +140,7 @@ class Tagger(TrainablePipe):
"""Data about the labels currently added to the component.""" """Data about the labels currently added to the component."""
return tuple(self.cfg["labels"]) return tuple(self.cfg["labels"])
def predict(self, docs): def predict(self, docs) -> ActivationsT:
"""Apply the pipeline's model to a batch of docs, without modifying them. """Apply the pipeline's model to a batch of docs, without modifying them.
docs (Iterable[Doc]): The documents to predict. docs (Iterable[Doc]): The documents to predict.
@ -166,11 +169,11 @@ class Tagger(TrainablePipe):
guesses.append(doc_guesses) guesses.append(doc_guesses)
return guesses return guesses
def set_annotations(self, docs: Iterable[Doc], activations: Dict[str, Union[List[Floats2d], List[Ints1d]]]): def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT):
"""Modify a batch of documents, using pre-computed scores. """Modify a batch of documents, using pre-computed scores.
docs (Iterable[Doc]): The documents to modify. docs (Iterable[Doc]): The documents to modify.
activations (Dict): The activations used for setting annotations, produced by Tagger.predict. activations (ActivationsT): The activations used for setting annotations, produced by Tagger.predict.
DOCS: https://spacy.io/api/tagger#set_annotations DOCS: https://spacy.io/api/tagger#set_annotations
""" """

View File

@ -14,6 +14,9 @@ from ..util import registry
from ..vocab import Vocab from ..vocab import Vocab
ActivationsT = Dict[str, Floats2d]
single_label_default_config = """ single_label_default_config = """
[model] [model]
@architectures = "spacy.TextCatEnsemble.v2" @architectures = "spacy.TextCatEnsemble.v2"
@ -193,7 +196,7 @@ class TextCategorizer(TrainablePipe):
""" """
return self.labels # type: ignore[return-value] return self.labels # type: ignore[return-value]
def predict(self, docs: Iterable[Doc]): def predict(self, docs: Iterable[Doc]) -> ActivationsT:
"""Apply the pipeline's model to a batch of docs, without modifying them. """Apply the pipeline's model to a batch of docs, without modifying them.
docs (Iterable[Doc]): The documents to predict. docs (Iterable[Doc]): The documents to predict.
@ -211,9 +214,7 @@ class TextCategorizer(TrainablePipe):
scores = self.model.ops.asarray(scores) scores = self.model.ops.asarray(scores)
return {"probs": scores} return {"probs": scores}
def set_annotations( def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT) -> None:
self, docs: Iterable[Doc], activations: Dict[str, Floats2d]
) -> None:
"""Modify a batch of Doc objects, using pre-computed scores. """Modify a batch of Doc objects, using pre-computed scores.
docs (Iterable[Doc]): The documents to modify. docs (Iterable[Doc]): The documents to modify.