mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-29 11:26:28 +03:00
EntityLinker: add type annotations to _add_activations
This commit is contained in:
parent
230264daa0
commit
57caae8393
|
@ -1,7 +1,7 @@
|
|||
from typing import Optional, Iterable, Callable, Dict, Union, List, Any
|
||||
from typing import Optional, Iterable, Callable, Dict, Sequence, Union, List, Any
|
||||
from typing import cast
|
||||
from numpy import dtype
|
||||
from thinc.types import Floats2d, Ragged
|
||||
from thinc.types import Floats1d, Floats2d, Ints1d, Ragged
|
||||
from pathlib import Path
|
||||
from itertools import islice
|
||||
import srsly
|
||||
|
@ -431,12 +431,12 @@ class EntityLinker(TrainablePipe):
|
|||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
for doc in docs:
|
||||
doc_ents = []
|
||||
doc_scores = []
|
||||
doc_ents: List[Ints1d] = []
|
||||
doc_scores: List[Floats1d] = []
|
||||
doc_scores_lens: List[int] = []
|
||||
if len(doc) == 0:
|
||||
doc_scores.append(Ragged(ops.alloc1f(0), ops.alloc1i(0)))
|
||||
doc_ents.append(Ragged(xp.zeros(0, dtype="uint64"), ops.alloc1i(0)))
|
||||
docs_scores.append(Ragged(ops.alloc1f(0), ops.alloc1i(0)))
|
||||
docs_ents.append(Ragged(xp.zeros(0, dtype="uint64"), ops.alloc1i(0)))
|
||||
continue
|
||||
sentences = [s for s in doc.sents]
|
||||
# Looping through each entity (TODO: rewrite)
|
||||
|
@ -474,7 +474,11 @@ class EntityLinker(TrainablePipe):
|
|||
# shortcut for efficiency reasons: take the 1 candidate
|
||||
final_kb_ids.append(candidates[0].entity_)
|
||||
self._add_activations(
|
||||
doc_scores, doc_scores_lens, doc_ents, [1.0], [candidates[0].entity_]
|
||||
doc_scores,
|
||||
doc_scores_lens,
|
||||
doc_ents,
|
||||
[1.0],
|
||||
[candidates[0].entity_],
|
||||
)
|
||||
else:
|
||||
random.shuffle(candidates)
|
||||
|
@ -651,7 +655,12 @@ class EntityLinker(TrainablePipe):
|
|||
return ["ents", "scores"]
|
||||
|
||||
def _add_doc_activations(
|
||||
self, docs_scores, docs_ents, doc_scores, doc_scores_lens, doc_ents
|
||||
self,
|
||||
docs_scores: List[Ragged],
|
||||
docs_ents: List[Ragged],
|
||||
doc_scores: List[Floats1d],
|
||||
doc_scores_lens: List[int],
|
||||
doc_ents: List[Ints1d],
|
||||
):
|
||||
if len(self.store_activations) == 0:
|
||||
return
|
||||
|
@ -665,10 +674,17 @@ class EntityLinker(TrainablePipe):
|
|||
)
|
||||
)
|
||||
|
||||
def _add_activations(self, doc_scores, doc_scores_lens, doc_ents, scores, ents):
|
||||
def _add_activations(
|
||||
self,
|
||||
doc_scores: List[Floats1d],
|
||||
doc_scores_lens: List[int],
|
||||
doc_ents: List[Ints1d],
|
||||
scores: Sequence[float],
|
||||
ents: Sequence[int],
|
||||
):
|
||||
if len(self.store_activations) == 0:
|
||||
return
|
||||
ops = self.model.ops
|
||||
doc_scores.append(ops.asarray1f(scores))
|
||||
doc_scores_lens.append(doc_scores[-1].shape[0])
|
||||
doc_ents.append(ops.xp.array(ents, dtype="uint64"))
|
||||
doc_ents.append(ops.asarray1i(ents, dtype="uint64"))
|
||||
|
|
Loading…
Reference in New Issue
Block a user