EntityLinker: add type annotations to _add_activations

This commit is contained in:
Daniël de Kok 2022-08-05 11:48:04 +02:00
parent 230264daa0
commit 57caae8393

View File

@ -1,7 +1,7 @@
from typing import Optional, Iterable, Callable, Dict, Union, List, Any
from typing import Optional, Iterable, Callable, Dict, Sequence, Union, List, Any
from typing import cast
from numpy import dtype
from thinc.types import Floats2d, Ragged
from thinc.types import Floats1d, Floats2d, Ints1d, Ragged
from pathlib import Path
from itertools import islice
import srsly
@ -431,12 +431,12 @@ class EntityLinker(TrainablePipe):
if isinstance(docs, Doc):
docs = [docs]
for doc in docs:
doc_ents = []
doc_scores = []
doc_ents: List[Ints1d] = []
doc_scores: List[Floats1d] = []
doc_scores_lens: List[int] = []
if len(doc) == 0:
doc_scores.append(Ragged(ops.alloc1f(0), ops.alloc1i(0)))
doc_ents.append(Ragged(xp.zeros(0, dtype="uint64"), ops.alloc1i(0)))
docs_scores.append(Ragged(ops.alloc1f(0), ops.alloc1i(0)))
docs_ents.append(Ragged(xp.zeros(0, dtype="uint64"), ops.alloc1i(0)))
continue
sentences = [s for s in doc.sents]
# Looping through each entity (TODO: rewrite)
@ -474,7 +474,11 @@ class EntityLinker(TrainablePipe):
# shortcut for efficiency reasons: take the 1 candidate
final_kb_ids.append(candidates[0].entity_)
self._add_activations(
doc_scores, doc_scores_lens, doc_ents, [1.0], [candidates[0].entity_]
doc_scores,
doc_scores_lens,
doc_ents,
[1.0],
[candidates[0].entity_],
)
else:
random.shuffle(candidates)
@ -651,7 +655,12 @@ class EntityLinker(TrainablePipe):
return ["ents", "scores"]
def _add_doc_activations(
self, docs_scores, docs_ents, doc_scores, doc_scores_lens, doc_ents
self,
docs_scores: List[Ragged],
docs_ents: List[Ragged],
doc_scores: List[Floats1d],
doc_scores_lens: List[int],
doc_ents: List[Ints1d],
):
if len(self.store_activations) == 0:
return
@ -665,10 +674,17 @@ class EntityLinker(TrainablePipe):
)
)
def _add_activations(self, doc_scores, doc_scores_lens, doc_ents, scores, ents):
def _add_activations(
self,
doc_scores: List[Floats1d],
doc_scores_lens: List[int],
doc_ents: List[Ints1d],
scores: Sequence[float],
ents: Sequence[int],
):
if len(self.store_activations) == 0:
return
ops = self.model.ops
doc_scores.append(ops.asarray1f(scores))
doc_scores_lens.append(doc_scores[-1].shape[0])
doc_ents.append(ops.xp.array(ents, dtype="uint64"))
doc_ents.append(ops.asarray1i(ents, dtype="uint64"))