diff --git a/spacy/kb/kb.pyx b/spacy/kb/kb.pyx index ce4bc0138..e374cf94d 100644 --- a/spacy/kb/kb.pyx +++ b/spacy/kb/kb.pyx @@ -5,7 +5,7 @@ from typing import Iterable, Tuple, Union from cymem.cymem cimport Pool from .candidate import Candidate -from ..tokens import Span +from ..tokens import Span, SpanGroup from ..util import SimpleFrozenList from ..errors import Errors @@ -30,12 +30,12 @@ cdef class KnowledgeBase: self.entity_vector_length = entity_vector_length self.mem = Pool() - def get_candidates_batch(self, mentions: Iterable[Span]) -> Iterable[Iterable[Candidate]]: + def get_candidates_batch(self, mentions: SpanGroup) -> Iterable[Iterable[Candidate]]: """ Return candidate entities for specified texts. Each candidate defines the entity, the original alias, and the prior probability of that alias resolving to that entity. If no candidate is found for a given text, an empty list is returned. - mentions (Iterable[Span]): Mentions for which to get candidates. + mentions (SpanGroup): Mentions for which to get candidates. RETURNS (Iterable[Iterable[Candidate]]): Identified candidates. """ return [self.get_candidates(span) for span in mentions] diff --git a/spacy/ml/models/entity_linker.py b/spacy/ml/models/entity_linker.py index 165302c3b..44d3592c6 100644 --- a/spacy/ml/models/entity_linker.py +++ b/spacy/ml/models/entity_linker.py @@ -8,7 +8,7 @@ from ...util import registry from ...kb import KnowledgeBase, InMemoryLookupKB from ...kb import Candidate from ...vocab import Vocab -from ...tokens import Span, Doc +from ...tokens import Doc, Span, SpanGroup from ..extract_spans import extract_spans from ...errors import Errors @@ -106,7 +106,7 @@ def create_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]: @registry.misc("spacy.CandidateBatchGenerator.v1") def create_candidates_batch() -> Callable[ - [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] + [KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]] ]: return get_candidates_batch @@ -122,12 +122,12 @@ def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]: def get_candidates_batch( - kb: KnowledgeBase, mentions: Iterable[Span] + kb: KnowledgeBase, mentions: SpanGroup ) -> Iterable[Iterable[Candidate]]: """ Return candidate entities for the given mentions and fetching appropriate entries from the index. kb (KnowledgeBase): Knowledge base to query. - mention (Iterable[Span]): Entity mentions for which to identify candidates. + mention (SpanGroup): Entity mentions for which to identify candidates. RETURNS (Iterable[Iterable[Candidate]]): Identified candidates. """ return kb.get_candidates_batch(mentions) diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index 907307056..a38f6b95b 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -9,9 +9,10 @@ import random from thinc.api import CosineDistance, Model, Optimizer, Config from thinc.api import set_dropout_rate +from ..tokens import SpanGroup from ..kb import KnowledgeBase, Candidate from ..ml import empty_kb -from ..tokens import Doc, Span +from ..tokens import Doc, Span, SpanGroup from .pipe import deserialize_config from .trainable_pipe import TrainablePipe from ..language import Language @@ -82,7 +83,7 @@ def make_entity_linker( entity_vector_length: int, get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], get_candidates_batch: Callable[ - [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] + [KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]] ], overwrite: bool, scorer: Optional[Callable], @@ -104,7 +105,7 @@ def make_entity_linker( get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that produces a list of candidates, given a certain knowledge base and a textual mention. get_candidates_batch ( - Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]], Iterable[Candidate]] + Callable[[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]], Iterable[Candidate]] ): Function that produces a list of candidates, given a certain knowledge base and several textual mentions. scorer (Optional[Callable]): The scoring method. use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another @@ -186,7 +187,7 @@ class EntityLinker(TrainablePipe): entity_vector_length: int, get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], get_candidates_batch: Callable[ - [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] + [KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]] ], overwrite: bool = False, scorer: Optional[Callable] = entity_linker_score, @@ -209,9 +210,9 @@ class EntityLinker(TrainablePipe): get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that produces a list of candidates, given a certain knowledge base and a textual mention. get_candidates_batch ( - Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]], + Callable[[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]], Iterable[Candidate]] - ): Function that produces a list of candidates, given a certain knowledge base and several textual mentions. + ): Function that produces a list of candidates, given a certain knowledge base and several textual mentions. overwrite (bool): Whether to overwrite existing non-empty annotations. scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links. use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another @@ -485,7 +486,8 @@ class EntityLinker(TrainablePipe): batch_candidates = list( self.get_candidates_batch( - self.kb, [ent_batch[idx] for idx in valid_ent_idx] + self.kb, + SpanGroup(doc, spans=[ent_batch[idx] for idx in valid_ent_idx]), ) if self.candidates_batch_size > 1 else [