rename to _ensure_ents

This commit is contained in:
svlandeg 2024-04-02 09:32:52 +02:00
parent ce7b51ef06
commit 1a2c379e93

View File

@ -247,7 +247,7 @@ class EntityLinker(TrainablePipe):
if not self.use_gold_ents: if not self.use_gold_ents:
return scorer(examples, **kwargs) return scorer(examples, **kwargs)
else: else:
examples = self._augment_examples(examples) examples = self._ensure_ents(examples)
docs = self.pipe( docs = self.pipe(
(eg.predicted for eg in examples), (eg.predicted for eg in examples),
) )
@ -257,7 +257,7 @@ class EntityLinker(TrainablePipe):
self.scorer = _score_augmented self.scorer = _score_augmented
def _augment_examples(self, examples: Iterable[Example]) -> Iterable[Example]: def _ensure_ents(self, examples: Iterable[Example]) -> Iterable[Example]:
"""If use_gold_ents is true, set the gold entities to (a copy of) eg.predicted.""" """If use_gold_ents is true, set the gold entities to (a copy of) eg.predicted."""
if not self.use_gold_ents: if not self.use_gold_ents:
return examples return examples
@ -311,7 +311,7 @@ class EntityLinker(TrainablePipe):
nO = self.kb.entity_vector_length nO = self.kb.entity_vector_length
doc_sample = [] doc_sample = []
vector_sample = [] vector_sample = []
examples = self._augment_examples(islice(get_examples(), 10)) examples = self._ensure_ents(islice(get_examples(), 10))
for eg in examples: for eg in examples:
doc = eg.x doc = eg.x
doc_sample.append(doc) doc_sample.append(doc)
@ -379,7 +379,7 @@ class EntityLinker(TrainablePipe):
losses.setdefault(self.name, 0.0) losses.setdefault(self.name, 0.0)
if not examples: if not examples:
return losses return losses
examples = self._augment_examples(examples) examples = self._ensure_ents(examples)
validate_examples(examples, "EntityLinker.update") validate_examples(examples, "EntityLinker.update")
# make sure we have something to learn from, if not, short-circuit # make sure we have something to learn from, if not, short-circuit