Change type for mentions to look up entity candidates for to SpanGroup from Iterable[Span].

This commit is contained in:
Raphael Mitsch 2023-02-28 15:28:05 +01:00
parent a97ef65b33
commit 8596fb8b88
3 changed files with 16 additions and 14 deletions

View File

@ -5,7 +5,7 @@ from typing import Iterable, Tuple, Union
from cymem.cymem cimport Pool
from .candidate import Candidate
from ..tokens import Span
from ..tokens import Span, SpanGroup
from ..util import SimpleFrozenList
from ..errors import Errors
@ -30,12 +30,12 @@ cdef class KnowledgeBase:
self.entity_vector_length = entity_vector_length
self.mem = Pool()
def get_candidates_batch(self, mentions: Iterable[Span]) -> Iterable[Iterable[Candidate]]:
def get_candidates_batch(self, mentions: SpanGroup) -> Iterable[Iterable[Candidate]]:
"""
Return candidate entities for specified texts. Each candidate defines the entity, the original alias,
and the prior probability of that alias resolving to that entity.
If no candidate is found for a given text, an empty list is returned.
mentions (Iterable[Span]): Mentions for which to get candidates.
mentions (SpanGroup): Mentions for which to get candidates.
RETURNS (Iterable[Iterable[Candidate]]): Identified candidates.
"""
return [self.get_candidates(span) for span in mentions]

View File

@ -8,7 +8,7 @@ from ...util import registry
from ...kb import KnowledgeBase, InMemoryLookupKB
from ...kb import Candidate
from ...vocab import Vocab
from ...tokens import Span, Doc
from ...tokens import Doc, Span, SpanGroup
from ..extract_spans import extract_spans
from ...errors import Errors
@ -106,7 +106,7 @@ def create_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]:
@registry.misc("spacy.CandidateBatchGenerator.v1")
def create_candidates_batch() -> Callable[
[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]
[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]
]:
return get_candidates_batch
@ -122,12 +122,12 @@ def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]:
def get_candidates_batch(
kb: KnowledgeBase, mentions: Iterable[Span]
kb: KnowledgeBase, mentions: SpanGroup
) -> Iterable[Iterable[Candidate]]:
"""
Return candidate entities for the given mentions and fetching appropriate entries from the index.
kb (KnowledgeBase): Knowledge base to query.
mention (Iterable[Span]): Entity mentions for which to identify candidates.
mention (SpanGroup): Entity mentions for which to identify candidates.
RETURNS (Iterable[Iterable[Candidate]]): Identified candidates.
"""
return kb.get_candidates_batch(mentions)

View File

@ -9,9 +9,10 @@ import random
from thinc.api import CosineDistance, Model, Optimizer, Config
from thinc.api import set_dropout_rate
from ..tokens import SpanGroup
from ..kb import KnowledgeBase, Candidate
from ..ml import empty_kb
from ..tokens import Doc, Span
from ..tokens import Doc, Span, SpanGroup
from .pipe import deserialize_config
from .trainable_pipe import TrainablePipe
from ..language import Language
@ -82,7 +83,7 @@ def make_entity_linker(
entity_vector_length: int,
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
get_candidates_batch: Callable[
[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]
[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]
],
overwrite: bool,
scorer: Optional[Callable],
@ -104,7 +105,7 @@ def make_entity_linker(
get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that
produces a list of candidates, given a certain knowledge base and a textual mention.
get_candidates_batch (
Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]], Iterable[Candidate]]
Callable[[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]], Iterable[Candidate]]
): Function that produces a list of candidates, given a certain knowledge base and several textual mentions.
scorer (Optional[Callable]): The scoring method.
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
@ -186,7 +187,7 @@ class EntityLinker(TrainablePipe):
entity_vector_length: int,
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
get_candidates_batch: Callable[
[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]
[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]
],
overwrite: bool = False,
scorer: Optional[Callable] = entity_linker_score,
@ -209,9 +210,9 @@ class EntityLinker(TrainablePipe):
get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that
produces a list of candidates, given a certain knowledge base and a textual mention.
get_candidates_batch (
Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]],
Callable[[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]],
Iterable[Candidate]]
): Function that produces a list of candidates, given a certain knowledge base and several textual mentions.
): Function that produces a list of candidates, given a certain knowledge base and several textual mentions.
overwrite (bool): Whether to overwrite existing non-empty annotations.
scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links.
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
@ -485,7 +486,8 @@ class EntityLinker(TrainablePipe):
batch_candidates = list(
self.get_candidates_batch(
self.kb, [ent_batch[idx] for idx in valid_ent_idx]
self.kb,
SpanGroup(doc, spans=[ent_batch[idx] for idx in valid_ent_idx]),
)
if self.candidates_batch_size > 1
else [