This commit is contained in:
Adriane Boyd 2023-01-27 08:29:46 +01:00
parent 8548d4d16e
commit fd911fe2af
3 changed files with 16 additions and 5 deletions

View File

@ -168,7 +168,7 @@ class EditTreeLemmatizer(TrainablePipe):
student_scores: Scores representing the student model's predictions.
RETURNS (Tuple[float, float]): The loss and the gradient.
DOCS: https://spacy.io/api/edittreelemmatizer#get_teacher_student_loss
"""
loss_func = LegacySequenceCategoricalCrossentropy(normalize=False)

View File

@ -453,7 +453,11 @@ class EntityLinker(TrainablePipe):
docs_ents: List[Ragged] = []
docs_scores: List[Ragged] = []
if not docs:
return {KNOWLEDGE_BASE_IDS: final_kb_ids, "ents": docs_ents, "scores": docs_scores}
return {
KNOWLEDGE_BASE_IDS: final_kb_ids,
"ents": docs_ents,
"scores": docs_scores,
}
if isinstance(docs, Doc):
docs = [docs]
for doc in docs:
@ -585,7 +589,11 @@ class EntityLinker(TrainablePipe):
method="predict", msg="result variables not of equal length"
)
raise RuntimeError(err)
return {KNOWLEDGE_BASE_IDS: final_kb_ids, "ents": docs_ents, "scores": docs_scores}
return {
KNOWLEDGE_BASE_IDS: final_kb_ids,
"ents": docs_ents,
"scores": docs_scores,
}
def set_annotations(self, docs: Iterable[Doc], activations: ActivationsT) -> None:
"""Modify a batch of documents, using pre-computed scores.

View File

@ -252,8 +252,11 @@ class EntityRecognizer(Parser):
def labels(self):
# Get the labels from the model by looking at the available moves, e.g.
# B-PERSON, I-PERSON, L-PERSON, U-PERSON
labels = set(remove_bilu_prefix(move) for move in self.move_names
if move[0] in ("B", "I", "L", "U"))
labels = set(
remove_bilu_prefix(move)
for move in self.move_names
if move[0] in ("B", "I", "L", "U")
)
return tuple(sorted(labels))
def scored_ents(self, beams):