EntityLinker: add type annotations to _add_activations

This commit is contained in:
Daniël de Kok 2022-08-05 11:48:04 +02:00
parent 230264daa0
commit 57caae8393

View File

@ -1,7 +1,7 @@
from typing import Optional, Iterable, Callable, Dict, Union, List, Any from typing import Optional, Iterable, Callable, Dict, Sequence, Union, List, Any
from typing import cast from typing import cast
from numpy import dtype from numpy import dtype
from thinc.types import Floats2d, Ragged from thinc.types import Floats1d, Floats2d, Ints1d, Ragged
from pathlib import Path from pathlib import Path
from itertools import islice from itertools import islice
import srsly import srsly
@ -431,12 +431,12 @@ class EntityLinker(TrainablePipe):
if isinstance(docs, Doc): if isinstance(docs, Doc):
docs = [docs] docs = [docs]
for doc in docs: for doc in docs:
doc_ents = [] doc_ents: List[Ints1d] = []
doc_scores = [] doc_scores: List[Floats1d] = []
doc_scores_lens: List[int] = [] doc_scores_lens: List[int] = []
if len(doc) == 0: if len(doc) == 0:
doc_scores.append(Ragged(ops.alloc1f(0), ops.alloc1i(0))) docs_scores.append(Ragged(ops.alloc1f(0), ops.alloc1i(0)))
doc_ents.append(Ragged(xp.zeros(0, dtype="uint64"), ops.alloc1i(0))) docs_ents.append(Ragged(xp.zeros(0, dtype="uint64"), ops.alloc1i(0)))
continue continue
sentences = [s for s in doc.sents] sentences = [s for s in doc.sents]
# Looping through each entity (TODO: rewrite) # Looping through each entity (TODO: rewrite)
@ -474,7 +474,11 @@ class EntityLinker(TrainablePipe):
# shortcut for efficiency reasons: take the 1 candidate # shortcut for efficiency reasons: take the 1 candidate
final_kb_ids.append(candidates[0].entity_) final_kb_ids.append(candidates[0].entity_)
self._add_activations( self._add_activations(
doc_scores, doc_scores_lens, doc_ents, [1.0], [candidates[0].entity_] doc_scores,
doc_scores_lens,
doc_ents,
[1.0],
[candidates[0].entity_],
) )
else: else:
random.shuffle(candidates) random.shuffle(candidates)
@ -651,7 +655,12 @@ class EntityLinker(TrainablePipe):
return ["ents", "scores"] return ["ents", "scores"]
def _add_doc_activations( def _add_doc_activations(
self, docs_scores, docs_ents, doc_scores, doc_scores_lens, doc_ents self,
docs_scores: List[Ragged],
docs_ents: List[Ragged],
doc_scores: List[Floats1d],
doc_scores_lens: List[int],
doc_ents: List[Ints1d],
): ):
if len(self.store_activations) == 0: if len(self.store_activations) == 0:
return return
@ -665,10 +674,17 @@ class EntityLinker(TrainablePipe):
) )
) )
def _add_activations(self, doc_scores, doc_scores_lens, doc_ents, scores, ents): def _add_activations(
self,
doc_scores: List[Floats1d],
doc_scores_lens: List[int],
doc_ents: List[Ints1d],
scores: Sequence[float],
ents: Sequence[int],
):
if len(self.store_activations) == 0: if len(self.store_activations) == 0:
return return
ops = self.model.ops ops = self.model.ops
doc_scores.append(ops.asarray1f(scores)) doc_scores.append(ops.asarray1f(scores))
doc_scores_lens.append(doc_scores[-1].shape[0]) doc_scores_lens.append(doc_scores[-1].shape[0])
doc_ents.append(ops.xp.array(ents, dtype="uint64")) doc_ents.append(ops.asarray1i(ents, dtype="uint64"))