mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-02 12:44:10 +03:00
_store_activations: make kwarg-only, remove doc_scores_lens arg
This commit is contained in:
parent
57caae8393
commit
ce36f345db
|
@ -433,7 +433,6 @@ class EntityLinker(TrainablePipe):
|
|||
for doc in docs:
|
||||
doc_ents: List[Ints1d] = []
|
||||
doc_scores: List[Floats1d] = []
|
||||
doc_scores_lens: List[int] = []
|
||||
if len(doc) == 0:
|
||||
docs_scores.append(Ragged(ops.alloc1f(0), ops.alloc1i(0)))
|
||||
docs_ents.append(Ragged(xp.zeros(0, dtype="uint64"), ops.alloc1i(0)))
|
||||
|
@ -460,7 +459,10 @@ class EntityLinker(TrainablePipe):
|
|||
# ignoring this entity - setting to NIL
|
||||
final_kb_ids.append(self.NIL)
|
||||
self._add_activations(
|
||||
doc_scores, doc_scores_lens, doc_ents, [0.0], [0]
|
||||
doc_scores=doc_scores,
|
||||
doc_ents=doc_ents,
|
||||
scores=[0.0],
|
||||
ents=[0],
|
||||
)
|
||||
else:
|
||||
candidates = list(self.get_candidates(self.kb, ent))
|
||||
|
@ -468,17 +470,19 @@ class EntityLinker(TrainablePipe):
|
|||
# no prediction possible for this entity - setting to NIL
|
||||
final_kb_ids.append(self.NIL)
|
||||
self._add_activations(
|
||||
doc_scores, doc_scores_lens, doc_ents, [0.0], [0]
|
||||
doc_scores=doc_scores,
|
||||
doc_ents=doc_ents,
|
||||
scores=[0.0],
|
||||
ents=[0],
|
||||
)
|
||||
elif len(candidates) == 1 and self.threshold is None:
|
||||
# shortcut for efficiency reasons: take the 1 candidate
|
||||
final_kb_ids.append(candidates[0].entity_)
|
||||
self._add_activations(
|
||||
doc_scores,
|
||||
doc_scores_lens,
|
||||
doc_ents,
|
||||
[1.0],
|
||||
[candidates[0].entity_],
|
||||
doc_scores=doc_scores,
|
||||
doc_ents=doc_ents,
|
||||
scores=[1.0],
|
||||
ents=[candidates[0].entity_],
|
||||
)
|
||||
else:
|
||||
random.shuffle(candidates)
|
||||
|
@ -513,14 +517,16 @@ class EntityLinker(TrainablePipe):
|
|||
else EntityLinker.NIL
|
||||
)
|
||||
self._add_activations(
|
||||
doc_scores,
|
||||
doc_scores_lens,
|
||||
doc_ents,
|
||||
scores,
|
||||
[c.entity for c in candidates],
|
||||
doc_scores=doc_scores,
|
||||
doc_ents=doc_ents,
|
||||
scores=scores,
|
||||
ents=[c.entity for c in candidates],
|
||||
)
|
||||
self._add_doc_activations(
|
||||
docs_scores, docs_ents, doc_scores, doc_scores_lens, doc_ents
|
||||
docs_scores=docs_scores,
|
||||
docs_ents=docs_ents,
|
||||
doc_scores=doc_scores,
|
||||
doc_ents=doc_ents,
|
||||
)
|
||||
if not (len(final_kb_ids) == entity_count):
|
||||
err = Errors.E147.format(
|
||||
|
@ -656,28 +662,23 @@ class EntityLinker(TrainablePipe):
|
|||
|
||||
def _add_doc_activations(
|
||||
self,
|
||||
*,
|
||||
docs_scores: List[Ragged],
|
||||
docs_ents: List[Ragged],
|
||||
doc_scores: List[Floats1d],
|
||||
doc_scores_lens: List[int],
|
||||
doc_ents: List[Ints1d],
|
||||
):
|
||||
if len(self.store_activations) == 0:
|
||||
return
|
||||
ops = self.model.ops
|
||||
docs_scores.append(
|
||||
Ragged(ops.flatten(doc_scores), ops.asarray1i(doc_scores_lens))
|
||||
)
|
||||
docs_ents.append(
|
||||
Ragged(
|
||||
ops.flatten(doc_ents, dtype="uint64"), ops.asarray1i(doc_scores_lens)
|
||||
)
|
||||
)
|
||||
lengths = ops.asarray1i([s.shape[0] for s in doc_scores])
|
||||
docs_scores.append(Ragged(ops.flatten(doc_scores), lengths))
|
||||
docs_ents.append(Ragged(ops.flatten(doc_ents), lengths))
|
||||
|
||||
def _add_activations(
|
||||
self,
|
||||
*,
|
||||
doc_scores: List[Floats1d],
|
||||
doc_scores_lens: List[int],
|
||||
doc_ents: List[Ints1d],
|
||||
scores: Sequence[float],
|
||||
ents: Sequence[int],
|
||||
|
@ -686,5 +687,4 @@ class EntityLinker(TrainablePipe):
|
|||
return
|
||||
ops = self.model.ops
|
||||
doc_scores.append(ops.asarray1f(scores))
|
||||
doc_scores_lens.append(doc_scores[-1].shape[0])
|
||||
doc_ents.append(ops.asarray1i(ents, dtype="uint64"))
|
||||
|
|
Loading…
Reference in New Issue
Block a user