_store_activations: make kwarg-only, remove doc_scores_lens arg

This commit is contained in:
Daniël de Kok 2022-08-05 12:00:01 +02:00
parent 57caae8393
commit ce36f345db

View File

@@ -433,7 +433,6 @@ class EntityLinker(TrainablePipe):
for doc in docs: for doc in docs:
doc_ents: List[Ints1d] = [] doc_ents: List[Ints1d] = []
doc_scores: List[Floats1d] = [] doc_scores: List[Floats1d] = []
doc_scores_lens: List[int] = []
if len(doc) == 0: if len(doc) == 0:
docs_scores.append(Ragged(ops.alloc1f(0), ops.alloc1i(0))) docs_scores.append(Ragged(ops.alloc1f(0), ops.alloc1i(0)))
docs_ents.append(Ragged(xp.zeros(0, dtype="uint64"), ops.alloc1i(0))) docs_ents.append(Ragged(xp.zeros(0, dtype="uint64"), ops.alloc1i(0)))
@@ -460,7 +459,10 @@ class EntityLinker(TrainablePipe):
# ignoring this entity - setting to NIL # ignoring this entity - setting to NIL
final_kb_ids.append(self.NIL) final_kb_ids.append(self.NIL)
self._add_activations( self._add_activations(
doc_scores, doc_scores_lens, doc_ents, [0.0], [0] doc_scores=doc_scores,
doc_ents=doc_ents,
scores=[0.0],
ents=[0],
) )
else: else:
candidates = list(self.get_candidates(self.kb, ent)) candidates = list(self.get_candidates(self.kb, ent))
@@ -468,17 +470,19 @@ class EntityLinker(TrainablePipe):
# no prediction possible for this entity - setting to NIL # no prediction possible for this entity - setting to NIL
final_kb_ids.append(self.NIL) final_kb_ids.append(self.NIL)
self._add_activations( self._add_activations(
doc_scores, doc_scores_lens, doc_ents, [0.0], [0] doc_scores=doc_scores,
doc_ents=doc_ents,
scores=[0.0],
ents=[0],
) )
elif len(candidates) == 1 and self.threshold is None: elif len(candidates) == 1 and self.threshold is None:
# shortcut for efficiency reasons: take the 1 candidate # shortcut for efficiency reasons: take the 1 candidate
final_kb_ids.append(candidates[0].entity_) final_kb_ids.append(candidates[0].entity_)
self._add_activations( self._add_activations(
doc_scores, doc_scores=doc_scores,
doc_scores_lens, doc_ents=doc_ents,
doc_ents, scores=[1.0],
[1.0], ents=[candidates[0].entity_],
[candidates[0].entity_],
) )
else: else:
random.shuffle(candidates) random.shuffle(candidates)
@@ -513,14 +517,16 @@ class EntityLinker(TrainablePipe):
else EntityLinker.NIL else EntityLinker.NIL
) )
self._add_activations( self._add_activations(
doc_scores, doc_scores=doc_scores,
doc_scores_lens, doc_ents=doc_ents,
doc_ents, scores=scores,
scores, ents=[c.entity for c in candidates],
[c.entity for c in candidates],
) )
self._add_doc_activations( self._add_doc_activations(
docs_scores, docs_ents, doc_scores, doc_scores_lens, doc_ents docs_scores=docs_scores,
docs_ents=docs_ents,
doc_scores=doc_scores,
doc_ents=doc_ents,
) )
if not (len(final_kb_ids) == entity_count): if not (len(final_kb_ids) == entity_count):
err = Errors.E147.format( err = Errors.E147.format(
@@ -656,28 +662,23 @@ class EntityLinker(TrainablePipe):
def _add_doc_activations( def _add_doc_activations(
self, self,
*,
docs_scores: List[Ragged], docs_scores: List[Ragged],
docs_ents: List[Ragged], docs_ents: List[Ragged],
doc_scores: List[Floats1d], doc_scores: List[Floats1d],
doc_scores_lens: List[int],
doc_ents: List[Ints1d], doc_ents: List[Ints1d],
): ):
if len(self.store_activations) == 0: if len(self.store_activations) == 0:
return return
ops = self.model.ops ops = self.model.ops
docs_scores.append( lengths = ops.asarray1i([s.shape[0] for s in doc_scores])
Ragged(ops.flatten(doc_scores), ops.asarray1i(doc_scores_lens)) docs_scores.append(Ragged(ops.flatten(doc_scores), lengths))
) docs_ents.append(Ragged(ops.flatten(doc_ents), lengths))
docs_ents.append(
Ragged(
ops.flatten(doc_ents, dtype="uint64"), ops.asarray1i(doc_scores_lens)
)
)
def _add_activations( def _add_activations(
self, self,
*,
doc_scores: List[Floats1d], doc_scores: List[Floats1d],
doc_scores_lens: List[int],
doc_ents: List[Ints1d], doc_ents: List[Ints1d],
scores: Sequence[float], scores: Sequence[float],
ents: Sequence[int], ents: Sequence[int],
@@ -686,5 +687,4 @@ class EntityLinker(TrainablePipe):
return return
ops = self.model.ops ops = self.model.ops
doc_scores.append(ops.asarray1f(scores)) doc_scores.append(ops.asarray1f(scores))
doc_scores_lens.append(doc_scores[-1].shape[0])
doc_ents.append(ops.asarray1i(ents, dtype="uint64")) doc_ents.append(ops.asarray1i(ents, dtype="uint64"))