diff --git a/spacy/pipeline/morphologizer.pyx b/spacy/pipeline/morphologizer.pyx index 0c7eacd12..df0151ffd 100644 --- a/spacy/pipeline/morphologizer.pyx +++ b/spacy/pipeline/morphologizer.pyx @@ -1,7 +1,8 @@ # cython: infer_types=True, profile=True, binding=True -from typing import Callable, Dict, List, Optional, Union +from typing import Callable, Dict, Iterable, List, Optional, Union import srsly from thinc.api import SequenceCategoricalCrossentropy, Model, Config +from thinc.types import Floats2d, Ints1d from itertools import islice from ..tokens.doc cimport Doc @@ -229,7 +230,7 @@ class Morphologizer(Tagger): assert len(label_sample) > 0, Errors.E923.format(name=self.name) self.model.initialize(X=doc_sample, Y=label_sample) - def set_annotations(self, docs, activations): + def set_annotations(self, docs: Iterable[Doc], activations: Dict[str, Union[Floats2d, Ints1d]]): """Modify a batch of documents, using pre-computed scores. docs (Iterable[Doc]): The documents to modify. diff --git a/spacy/pipeline/senter.pyx b/spacy/pipeline/senter.pyx index 1cfd6c4b1..b637ad0ef 100644 --- a/spacy/pipeline/senter.pyx +++ b/spacy/pipeline/senter.pyx @@ -1,9 +1,10 @@ # cython: infer_types=True, profile=True, binding=True -from typing import Optional, Callable, List, Union +from typing import Dict, Iterable, Optional, Callable, List, Union from itertools import islice import srsly from thinc.api import Model, SequenceCategoricalCrossentropy, Config +from thinc.types import Floats2d, Ints1d from ..tokens.doc cimport Doc @@ -121,7 +122,7 @@ class SentenceRecognizer(Tagger): def label_data(self): return None - def set_annotations(self, docs, activations): + def set_annotations(self, docs: Iterable[Doc], activations: Dict[str, Union[List[Floats2d], List[Ints1d]]]): """Modify a batch of documents, using pre-computed scores. docs (Iterable[Doc]): The documents to modify. diff --git a/spacy/pipeline/spancat.py b/spacy/pipeline/spancat.py index ccd49bbac..4484f7577 100644 --- a/spacy/pipeline/spancat.py +++ b/spacy/pipeline/spancat.py @@ -298,7 +298,9 @@ class SpanCategorizer(TrainablePipe): for index in candidates.dataXd: doc.spans[candidates_key].append(doc[index[0] : index[1]]) - def set_annotations(self, docs: Iterable[Doc], activations) -> None: + def set_annotations( + self, docs: Iterable[Doc], activations: Dict[str, Union[Floats2d, Ragged]] + ) -> None: """Modify a batch of Doc objects, using pre-computed scores. docs (Iterable[Doc]): The documents to modify. @@ -309,7 +311,9 @@ class SpanCategorizer(TrainablePipe): labels = self.labels indices = activations["indices"] - scores = activations["scores"] + assert isinstance(indices, Ragged) + scores = cast(Floats2d, activations["scores"]) + offset = 0 for i, doc in enumerate(docs): indices_i = indices[i].dataXd diff --git a/spacy/pipeline/tagger.pyx b/spacy/pipeline/tagger.pyx index 498b3de08..fbaccae40 100644 --- a/spacy/pipeline/tagger.pyx +++ b/spacy/pipeline/tagger.pyx @@ -1,9 +1,9 @@ # cython: infer_types=True, profile=True, binding=True -from typing import Callable, List, Optional, Union +from typing import Callable, Dict, Iterable, List, Optional, Union import numpy import srsly from thinc.api import Model, set_dropout_rate, SequenceCategoricalCrossentropy, Config -from thinc.types import Floats2d +from thinc.types import Floats2d, Ints1d import warnings from itertools import islice @@ -167,7 +167,7 @@ class Tagger(TrainablePipe): guesses.append(doc_guesses) return guesses - def set_annotations(self, docs, activations): + def set_annotations(self, docs: Iterable[Doc], activations: Dict[str, Union[List[Floats2d], List[Ints1d]]]): """Modify a batch of documents, using pre-computed scores. docs (Iterable[Doc]): The documents to modify. diff --git a/spacy/pipeline/textcat.py b/spacy/pipeline/textcat.py index 1ca112060..888cd0178 100644 --- a/spacy/pipeline/textcat.py +++ b/spacy/pipeline/textcat.py @@ -212,7 +212,7 @@ class TextCategorizer(TrainablePipe): scores = self.model.ops.asarray(scores) return scores - def set_annotations(self, docs: Iterable[Doc], scores) -> None: + def set_annotations(self, docs: Iterable[Doc], scores: Floats2d) -> None: """Modify a batch of Doc objects, using pre-computed scores. docs (Iterable[Doc]): The documents to modify.