mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-01 04:46:38 +03:00
EntityLinker: add type annotations to _add_activations
This commit is contained in:
parent
230264daa0
commit
57caae8393
|
@ -1,7 +1,7 @@
|
||||||
from typing import Optional, Iterable, Callable, Dict, Union, List, Any
|
from typing import Optional, Iterable, Callable, Dict, Sequence, Union, List, Any
|
||||||
from typing import cast
|
from typing import cast
|
||||||
from numpy import dtype
|
from numpy import dtype
|
||||||
from thinc.types import Floats2d, Ragged
|
from thinc.types import Floats1d, Floats2d, Ints1d, Ragged
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
import srsly
|
import srsly
|
||||||
|
@ -431,12 +431,12 @@ class EntityLinker(TrainablePipe):
|
||||||
if isinstance(docs, Doc):
|
if isinstance(docs, Doc):
|
||||||
docs = [docs]
|
docs = [docs]
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
doc_ents = []
|
doc_ents: List[Ints1d] = []
|
||||||
doc_scores = []
|
doc_scores: List[Floats1d] = []
|
||||||
doc_scores_lens: List[int] = []
|
doc_scores_lens: List[int] = []
|
||||||
if len(doc) == 0:
|
if len(doc) == 0:
|
||||||
doc_scores.append(Ragged(ops.alloc1f(0), ops.alloc1i(0)))
|
docs_scores.append(Ragged(ops.alloc1f(0), ops.alloc1i(0)))
|
||||||
doc_ents.append(Ragged(xp.zeros(0, dtype="uint64"), ops.alloc1i(0)))
|
docs_ents.append(Ragged(xp.zeros(0, dtype="uint64"), ops.alloc1i(0)))
|
||||||
continue
|
continue
|
||||||
sentences = [s for s in doc.sents]
|
sentences = [s for s in doc.sents]
|
||||||
# Looping through each entity (TODO: rewrite)
|
# Looping through each entity (TODO: rewrite)
|
||||||
|
@ -474,7 +474,11 @@ class EntityLinker(TrainablePipe):
|
||||||
# shortcut for efficiency reasons: take the 1 candidate
|
# shortcut for efficiency reasons: take the 1 candidate
|
||||||
final_kb_ids.append(candidates[0].entity_)
|
final_kb_ids.append(candidates[0].entity_)
|
||||||
self._add_activations(
|
self._add_activations(
|
||||||
doc_scores, doc_scores_lens, doc_ents, [1.0], [candidates[0].entity_]
|
doc_scores,
|
||||||
|
doc_scores_lens,
|
||||||
|
doc_ents,
|
||||||
|
[1.0],
|
||||||
|
[candidates[0].entity_],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
random.shuffle(candidates)
|
random.shuffle(candidates)
|
||||||
|
@ -651,7 +655,12 @@ class EntityLinker(TrainablePipe):
|
||||||
return ["ents", "scores"]
|
return ["ents", "scores"]
|
||||||
|
|
||||||
def _add_doc_activations(
|
def _add_doc_activations(
|
||||||
self, docs_scores, docs_ents, doc_scores, doc_scores_lens, doc_ents
|
self,
|
||||||
|
docs_scores: List[Ragged],
|
||||||
|
docs_ents: List[Ragged],
|
||||||
|
doc_scores: List[Floats1d],
|
||||||
|
doc_scores_lens: List[int],
|
||||||
|
doc_ents: List[Ints1d],
|
||||||
):
|
):
|
||||||
if len(self.store_activations) == 0:
|
if len(self.store_activations) == 0:
|
||||||
return
|
return
|
||||||
|
@ -665,10 +674,17 @@ class EntityLinker(TrainablePipe):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
def _add_activations(self, doc_scores, doc_scores_lens, doc_ents, scores, ents):
|
def _add_activations(
|
||||||
|
self,
|
||||||
|
doc_scores: List[Floats1d],
|
||||||
|
doc_scores_lens: List[int],
|
||||||
|
doc_ents: List[Ints1d],
|
||||||
|
scores: Sequence[float],
|
||||||
|
ents: Sequence[int],
|
||||||
|
):
|
||||||
if len(self.store_activations) == 0:
|
if len(self.store_activations) == 0:
|
||||||
return
|
return
|
||||||
ops = self.model.ops
|
ops = self.model.ops
|
||||||
doc_scores.append(ops.asarray1f(scores))
|
doc_scores.append(ops.asarray1f(scores))
|
||||||
doc_scores_lens.append(doc_scores[-1].shape[0])
|
doc_scores_lens.append(doc_scores[-1].shape[0])
|
||||||
doc_ents.append(ops.xp.array(ents, dtype="uint64"))
|
doc_ents.append(ops.asarray1i(ents, dtype="uint64"))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user