_store_activations: make kwarg-only, remove doc_scores_lens arg

This commit is contained in:
Daniël de Kok 2022-08-05 12:00:01 +02:00
parent 57caae8393
commit ce36f345db

View File

@ -433,7 +433,6 @@ class EntityLinker(TrainablePipe):
for doc in docs:
doc_ents: List[Ints1d] = []
doc_scores: List[Floats1d] = []
doc_scores_lens: List[int] = []
if len(doc) == 0:
docs_scores.append(Ragged(ops.alloc1f(0), ops.alloc1i(0)))
docs_ents.append(Ragged(xp.zeros(0, dtype="uint64"), ops.alloc1i(0)))
@ -460,7 +459,10 @@ class EntityLinker(TrainablePipe):
# ignoring this entity - setting to NIL
final_kb_ids.append(self.NIL)
self._add_activations(
doc_scores, doc_scores_lens, doc_ents, [0.0], [0]
doc_scores=doc_scores,
doc_ents=doc_ents,
scores=[0.0],
ents=[0],
)
else:
candidates = list(self.get_candidates(self.kb, ent))
@ -468,17 +470,19 @@ class EntityLinker(TrainablePipe):
# no prediction possible for this entity - setting to NIL
final_kb_ids.append(self.NIL)
self._add_activations(
doc_scores, doc_scores_lens, doc_ents, [0.0], [0]
doc_scores=doc_scores,
doc_ents=doc_ents,
scores=[0.0],
ents=[0],
)
elif len(candidates) == 1 and self.threshold is None:
# shortcut for efficiency reasons: take the 1 candidate
final_kb_ids.append(candidates[0].entity_)
self._add_activations(
doc_scores,
doc_scores_lens,
doc_ents,
[1.0],
[candidates[0].entity_],
doc_scores=doc_scores,
doc_ents=doc_ents,
scores=[1.0],
ents=[candidates[0].entity_],
)
else:
random.shuffle(candidates)
@ -513,14 +517,16 @@ class EntityLinker(TrainablePipe):
else EntityLinker.NIL
)
self._add_activations(
doc_scores,
doc_scores_lens,
doc_ents,
scores,
[c.entity for c in candidates],
doc_scores=doc_scores,
doc_ents=doc_ents,
scores=scores,
ents=[c.entity for c in candidates],
)
self._add_doc_activations(
docs_scores, docs_ents, doc_scores, doc_scores_lens, doc_ents
docs_scores=docs_scores,
docs_ents=docs_ents,
doc_scores=doc_scores,
doc_ents=doc_ents,
)
if not (len(final_kb_ids) == entity_count):
err = Errors.E147.format(
@ -656,28 +662,23 @@ class EntityLinker(TrainablePipe):
def _add_doc_activations(
self,
*,
docs_scores: List[Ragged],
docs_ents: List[Ragged],
doc_scores: List[Floats1d],
doc_scores_lens: List[int],
doc_ents: List[Ints1d],
):
if len(self.store_activations) == 0:
return
ops = self.model.ops
docs_scores.append(
Ragged(ops.flatten(doc_scores), ops.asarray1i(doc_scores_lens))
)
docs_ents.append(
Ragged(
ops.flatten(doc_ents, dtype="uint64"), ops.asarray1i(doc_scores_lens)
)
)
lengths = ops.asarray1i([s.shape[0] for s in doc_scores])
docs_scores.append(Ragged(ops.flatten(doc_scores), lengths))
docs_ents.append(Ragged(ops.flatten(doc_ents), lengths))
def _add_activations(
self,
*,
doc_scores: List[Floats1d],
doc_scores_lens: List[int],
doc_ents: List[Ints1d],
scores: Sequence[float],
ents: Sequence[int],
@ -686,5 +687,4 @@ class EntityLinker(TrainablePipe):
return
ops = self.model.ops
doc_scores.append(ops.asarray1f(scores))
doc_scores_lens.append(doc_scores[-1].shape[0])
doc_ents.append(ops.asarray1i(ents, dtype="uint64"))