mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-02 12:44:10 +03:00
_store_activations: make kwarg-only, remove doc_scores_lens arg
This commit is contained in:
parent
57caae8393
commit
ce36f345db
|
@ -433,7 +433,6 @@ class EntityLinker(TrainablePipe):
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
doc_ents: List[Ints1d] = []
|
doc_ents: List[Ints1d] = []
|
||||||
doc_scores: List[Floats1d] = []
|
doc_scores: List[Floats1d] = []
|
||||||
doc_scores_lens: List[int] = []
|
|
||||||
if len(doc) == 0:
|
if len(doc) == 0:
|
||||||
docs_scores.append(Ragged(ops.alloc1f(0), ops.alloc1i(0)))
|
docs_scores.append(Ragged(ops.alloc1f(0), ops.alloc1i(0)))
|
||||||
docs_ents.append(Ragged(xp.zeros(0, dtype="uint64"), ops.alloc1i(0)))
|
docs_ents.append(Ragged(xp.zeros(0, dtype="uint64"), ops.alloc1i(0)))
|
||||||
|
@ -460,7 +459,10 @@ class EntityLinker(TrainablePipe):
|
||||||
# ignoring this entity - setting to NIL
|
# ignoring this entity - setting to NIL
|
||||||
final_kb_ids.append(self.NIL)
|
final_kb_ids.append(self.NIL)
|
||||||
self._add_activations(
|
self._add_activations(
|
||||||
doc_scores, doc_scores_lens, doc_ents, [0.0], [0]
|
doc_scores=doc_scores,
|
||||||
|
doc_ents=doc_ents,
|
||||||
|
scores=[0.0],
|
||||||
|
ents=[0],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
candidates = list(self.get_candidates(self.kb, ent))
|
candidates = list(self.get_candidates(self.kb, ent))
|
||||||
|
@ -468,17 +470,19 @@ class EntityLinker(TrainablePipe):
|
||||||
# no prediction possible for this entity - setting to NIL
|
# no prediction possible for this entity - setting to NIL
|
||||||
final_kb_ids.append(self.NIL)
|
final_kb_ids.append(self.NIL)
|
||||||
self._add_activations(
|
self._add_activations(
|
||||||
doc_scores, doc_scores_lens, doc_ents, [0.0], [0]
|
doc_scores=doc_scores,
|
||||||
|
doc_ents=doc_ents,
|
||||||
|
scores=[0.0],
|
||||||
|
ents=[0],
|
||||||
)
|
)
|
||||||
elif len(candidates) == 1 and self.threshold is None:
|
elif len(candidates) == 1 and self.threshold is None:
|
||||||
# shortcut for efficiency reasons: take the 1 candidate
|
# shortcut for efficiency reasons: take the 1 candidate
|
||||||
final_kb_ids.append(candidates[0].entity_)
|
final_kb_ids.append(candidates[0].entity_)
|
||||||
self._add_activations(
|
self._add_activations(
|
||||||
doc_scores,
|
doc_scores=doc_scores,
|
||||||
doc_scores_lens,
|
doc_ents=doc_ents,
|
||||||
doc_ents,
|
scores=[1.0],
|
||||||
[1.0],
|
ents=[candidates[0].entity_],
|
||||||
[candidates[0].entity_],
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
random.shuffle(candidates)
|
random.shuffle(candidates)
|
||||||
|
@ -513,14 +517,16 @@ class EntityLinker(TrainablePipe):
|
||||||
else EntityLinker.NIL
|
else EntityLinker.NIL
|
||||||
)
|
)
|
||||||
self._add_activations(
|
self._add_activations(
|
||||||
doc_scores,
|
doc_scores=doc_scores,
|
||||||
doc_scores_lens,
|
doc_ents=doc_ents,
|
||||||
doc_ents,
|
scores=scores,
|
||||||
scores,
|
ents=[c.entity for c in candidates],
|
||||||
[c.entity for c in candidates],
|
|
||||||
)
|
)
|
||||||
self._add_doc_activations(
|
self._add_doc_activations(
|
||||||
docs_scores, docs_ents, doc_scores, doc_scores_lens, doc_ents
|
docs_scores=docs_scores,
|
||||||
|
docs_ents=docs_ents,
|
||||||
|
doc_scores=doc_scores,
|
||||||
|
doc_ents=doc_ents,
|
||||||
)
|
)
|
||||||
if not (len(final_kb_ids) == entity_count):
|
if not (len(final_kb_ids) == entity_count):
|
||||||
err = Errors.E147.format(
|
err = Errors.E147.format(
|
||||||
|
@ -656,28 +662,23 @@ class EntityLinker(TrainablePipe):
|
||||||
|
|
||||||
def _add_doc_activations(
|
def _add_doc_activations(
|
||||||
self,
|
self,
|
||||||
|
*,
|
||||||
docs_scores: List[Ragged],
|
docs_scores: List[Ragged],
|
||||||
docs_ents: List[Ragged],
|
docs_ents: List[Ragged],
|
||||||
doc_scores: List[Floats1d],
|
doc_scores: List[Floats1d],
|
||||||
doc_scores_lens: List[int],
|
|
||||||
doc_ents: List[Ints1d],
|
doc_ents: List[Ints1d],
|
||||||
):
|
):
|
||||||
if len(self.store_activations) == 0:
|
if len(self.store_activations) == 0:
|
||||||
return
|
return
|
||||||
ops = self.model.ops
|
ops = self.model.ops
|
||||||
docs_scores.append(
|
lengths = ops.asarray1i([s.shape[0] for s in doc_scores])
|
||||||
Ragged(ops.flatten(doc_scores), ops.asarray1i(doc_scores_lens))
|
docs_scores.append(Ragged(ops.flatten(doc_scores), lengths))
|
||||||
)
|
docs_ents.append(Ragged(ops.flatten(doc_ents), lengths))
|
||||||
docs_ents.append(
|
|
||||||
Ragged(
|
|
||||||
ops.flatten(doc_ents, dtype="uint64"), ops.asarray1i(doc_scores_lens)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
def _add_activations(
|
def _add_activations(
|
||||||
self,
|
self,
|
||||||
|
*,
|
||||||
doc_scores: List[Floats1d],
|
doc_scores: List[Floats1d],
|
||||||
doc_scores_lens: List[int],
|
|
||||||
doc_ents: List[Ints1d],
|
doc_ents: List[Ints1d],
|
||||||
scores: Sequence[float],
|
scores: Sequence[float],
|
||||||
ents: Sequence[int],
|
ents: Sequence[int],
|
||||||
|
@ -686,5 +687,4 @@ class EntityLinker(TrainablePipe):
|
||||||
return
|
return
|
||||||
ops = self.model.ops
|
ops = self.model.ops
|
||||||
doc_scores.append(ops.asarray1f(scores))
|
doc_scores.append(ops.asarray1f(scores))
|
||||||
doc_scores_lens.append(doc_scores[-1].shape[0])
|
|
||||||
doc_ents.append(ops.asarray1i(ents, dtype="uint64"))
|
doc_ents.append(ops.asarray1i(ents, dtype="uint64"))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user