diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index 71458fe3f..62a4053e0 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -433,7 +433,6 @@ class EntityLinker(TrainablePipe): for doc in docs: doc_ents: List[Ints1d] = [] doc_scores: List[Floats1d] = [] - doc_scores_lens: List[int] = [] if len(doc) == 0: docs_scores.append(Ragged(ops.alloc1f(0), ops.alloc1i(0))) docs_ents.append(Ragged(xp.zeros(0, dtype="uint64"), ops.alloc1i(0))) @@ -460,7 +459,10 @@ class EntityLinker(TrainablePipe): # ignoring this entity - setting to NIL final_kb_ids.append(self.NIL) self._add_activations( - doc_scores, doc_scores_lens, doc_ents, [0.0], [0] + doc_scores=doc_scores, + doc_ents=doc_ents, + scores=[0.0], + ents=[0], ) else: candidates = list(self.get_candidates(self.kb, ent)) @@ -468,17 +470,19 @@ class EntityLinker(TrainablePipe): # no prediction possible for this entity - setting to NIL final_kb_ids.append(self.NIL) self._add_activations( - doc_scores, doc_scores_lens, doc_ents, [0.0], [0] + doc_scores=doc_scores, + doc_ents=doc_ents, + scores=[0.0], + ents=[0], ) elif len(candidates) == 1 and self.threshold is None: # shortcut for efficiency reasons: take the 1 candidate final_kb_ids.append(candidates[0].entity_) self._add_activations( - doc_scores, - doc_scores_lens, - doc_ents, - [1.0], - [candidates[0].entity_], + doc_scores=doc_scores, + doc_ents=doc_ents, + scores=[1.0], + ents=[candidates[0].entity_], ) else: random.shuffle(candidates) @@ -513,14 +517,16 @@ class EntityLinker(TrainablePipe): else EntityLinker.NIL ) self._add_activations( - doc_scores, - doc_scores_lens, - doc_ents, - scores, - [c.entity for c in candidates], + doc_scores=doc_scores, + doc_ents=doc_ents, + scores=scores, + ents=[c.entity for c in candidates], ) self._add_doc_activations( - docs_scores, docs_ents, doc_scores, doc_scores_lens, doc_ents + docs_scores=docs_scores, + docs_ents=docs_ents, + doc_scores=doc_scores, + doc_ents=doc_ents, ) if not (len(final_kb_ids) == entity_count): err = Errors.E147.format( @@ -656,28 +662,23 @@ class EntityLinker(TrainablePipe): def _add_doc_activations( self, + *, docs_scores: List[Ragged], docs_ents: List[Ragged], doc_scores: List[Floats1d], - doc_scores_lens: List[int], doc_ents: List[Ints1d], ): if len(self.store_activations) == 0: return ops = self.model.ops - docs_scores.append( - Ragged(ops.flatten(doc_scores), ops.asarray1i(doc_scores_lens)) - ) - docs_ents.append( - Ragged( - ops.flatten(doc_ents, dtype="uint64"), ops.asarray1i(doc_scores_lens) - ) - ) + lengths = ops.asarray1i([s.shape[0] for s in doc_scores]) + docs_scores.append(Ragged(ops.flatten(doc_scores), lengths)) + docs_ents.append(Ragged(ops.flatten(doc_ents), lengths)) def _add_activations( self, + *, doc_scores: List[Floats1d], - doc_scores_lens: List[int], doc_ents: List[Ints1d], scores: Sequence[float], ents: Sequence[int], @@ -686,5 +687,4 @@ class EntityLinker(TrainablePipe): return ops = self.model.ops doc_scores.append(ops.asarray1f(scores)) - doc_scores_lens.append(doc_scores[-1].shape[0]) doc_ents.append(ops.asarray1i(ents, dtype="uint64"))