diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index 266c1f07f..71458fe3f 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -1,7 +1,7 @@ -from typing import Optional, Iterable, Callable, Dict, Union, List, Any +from typing import Optional, Iterable, Callable, Dict, Sequence, Union, List, Any from typing import cast from numpy import dtype -from thinc.types import Floats2d, Ragged +from thinc.types import Floats1d, Floats2d, Ints1d, Ragged from pathlib import Path from itertools import islice import srsly @@ -431,12 +431,12 @@ class EntityLinker(TrainablePipe): if isinstance(docs, Doc): docs = [docs] for doc in docs: - doc_ents = [] - doc_scores = [] + doc_ents: List[Ints1d] = [] + doc_scores: List[Floats1d] = [] doc_scores_lens: List[int] = [] if len(doc) == 0: - doc_scores.append(Ragged(ops.alloc1f(0), ops.alloc1i(0))) - doc_ents.append(Ragged(xp.zeros(0, dtype="uint64"), ops.alloc1i(0))) + docs_scores.append(Ragged(ops.alloc1f(0), ops.alloc1i(0))) + docs_ents.append(Ragged(xp.zeros(0, dtype="uint64"), ops.alloc1i(0))) continue sentences = [s for s in doc.sents] # Looping through each entity (TODO: rewrite) @@ -474,7 +474,11 @@ class EntityLinker(TrainablePipe): # shortcut for efficiency reasons: take the 1 candidate final_kb_ids.append(candidates[0].entity_) self._add_activations( - doc_scores, doc_scores_lens, doc_ents, [1.0], [candidates[0].entity_] + doc_scores, + doc_scores_lens, + doc_ents, + [1.0], + [candidates[0].entity_], ) else: random.shuffle(candidates) @@ -651,7 +655,12 @@ class EntityLinker(TrainablePipe): return ["ents", "scores"] def _add_doc_activations( - self, docs_scores, docs_ents, doc_scores, doc_scores_lens, doc_ents + self, + docs_scores: List[Ragged], + docs_ents: List[Ragged], + doc_scores: List[Floats1d], + doc_scores_lens: List[int], + doc_ents: List[Ints1d], ): if len(self.store_activations) == 0: return @@ -665,10 +674,17 @@ class EntityLinker(TrainablePipe): ) ) - def _add_activations(self, doc_scores, doc_scores_lens, doc_ents, scores, ents): + def _add_activations( + self, + doc_scores: List[Floats1d], + doc_scores_lens: List[int], + doc_ents: List[Ints1d], + scores: Sequence[float], + ents: Sequence[int], + ): if len(self.store_activations) == 0: return ops = self.model.ops doc_scores.append(ops.asarray1f(scores)) doc_scores_lens.append(doc_scores[-1].shape[0]) - doc_ents.append(ops.xp.array(ents, dtype="uint64")) + doc_ents.append(ops.asarray1i(ents, dtype="uint64"))