mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-30 18:03:04 +03:00
Change type for mentions to look up entity candidates for to SpanGroup from Iterable[Span].
This commit is contained in:
parent
a97ef65b33
commit
8596fb8b88
|
@ -5,7 +5,7 @@ from typing import Iterable, Tuple, Union
|
|||
from cymem.cymem cimport Pool
|
||||
|
||||
from .candidate import Candidate
|
||||
from ..tokens import Span
|
||||
from ..tokens import Span, SpanGroup
|
||||
from ..util import SimpleFrozenList
|
||||
from ..errors import Errors
|
||||
|
||||
|
@ -30,12 +30,12 @@ cdef class KnowledgeBase:
|
|||
self.entity_vector_length = entity_vector_length
|
||||
self.mem = Pool()
|
||||
|
||||
def get_candidates_batch(self, mentions: Iterable[Span]) -> Iterable[Iterable[Candidate]]:
|
||||
def get_candidates_batch(self, mentions: SpanGroup) -> Iterable[Iterable[Candidate]]:
|
||||
"""
|
||||
Return candidate entities for specified texts. Each candidate defines the entity, the original alias,
|
||||
and the prior probability of that alias resolving to that entity.
|
||||
If no candidate is found for a given text, an empty list is returned.
|
||||
mentions (Iterable[Span]): Mentions for which to get candidates.
|
||||
mentions (SpanGroup): Mentions for which to get candidates.
|
||||
RETURNS (Iterable[Iterable[Candidate]]): Identified candidates.
|
||||
"""
|
||||
return [self.get_candidates(span) for span in mentions]
|
||||
|
|
|
@ -8,7 +8,7 @@ from ...util import registry
|
|||
from ...kb import KnowledgeBase, InMemoryLookupKB
|
||||
from ...kb import Candidate
|
||||
from ...vocab import Vocab
|
||||
from ...tokens import Span, Doc
|
||||
from ...tokens import Doc, Span, SpanGroup
|
||||
from ..extract_spans import extract_spans
|
||||
from ...errors import Errors
|
||||
|
||||
|
@ -106,7 +106,7 @@ def create_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]:
|
|||
|
||||
@registry.misc("spacy.CandidateBatchGenerator.v1")
|
||||
def create_candidates_batch() -> Callable[
|
||||
[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]
|
||||
[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]
|
||||
]:
|
||||
return get_candidates_batch
|
||||
|
||||
|
@ -122,12 +122,12 @@ def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]:
|
|||
|
||||
|
||||
def get_candidates_batch(
|
||||
kb: KnowledgeBase, mentions: Iterable[Span]
|
||||
kb: KnowledgeBase, mentions: SpanGroup
|
||||
) -> Iterable[Iterable[Candidate]]:
|
||||
"""
|
||||
Return candidate entities for the given mentions and fetching appropriate entries from the index.
|
||||
kb (KnowledgeBase): Knowledge base to query.
|
||||
mention (Iterable[Span]): Entity mentions for which to identify candidates.
|
||||
mention (SpanGroup): Entity mentions for which to identify candidates.
|
||||
RETURNS (Iterable[Iterable[Candidate]]): Identified candidates.
|
||||
"""
|
||||
return kb.get_candidates_batch(mentions)
|
||||
|
|
|
@ -9,9 +9,10 @@ import random
|
|||
from thinc.api import CosineDistance, Model, Optimizer, Config
|
||||
from thinc.api import set_dropout_rate
|
||||
|
||||
from ..tokens import SpanGroup
|
||||
from ..kb import KnowledgeBase, Candidate
|
||||
from ..ml import empty_kb
|
||||
from ..tokens import Doc, Span
|
||||
from ..tokens import Doc, Span, SpanGroup
|
||||
from .pipe import deserialize_config
|
||||
from .trainable_pipe import TrainablePipe
|
||||
from ..language import Language
|
||||
|
@ -82,7 +83,7 @@ def make_entity_linker(
|
|||
entity_vector_length: int,
|
||||
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
|
||||
get_candidates_batch: Callable[
|
||||
[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]
|
||||
[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]
|
||||
],
|
||||
overwrite: bool,
|
||||
scorer: Optional[Callable],
|
||||
|
@ -104,7 +105,7 @@ def make_entity_linker(
|
|||
get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that
|
||||
produces a list of candidates, given a certain knowledge base and a textual mention.
|
||||
get_candidates_batch (
|
||||
Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]], Iterable[Candidate]]
|
||||
Callable[[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]], Iterable[Candidate]]
|
||||
): Function that produces a list of candidates, given a certain knowledge base and several textual mentions.
|
||||
scorer (Optional[Callable]): The scoring method.
|
||||
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
|
||||
|
@ -186,7 +187,7 @@ class EntityLinker(TrainablePipe):
|
|||
entity_vector_length: int,
|
||||
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
|
||||
get_candidates_batch: Callable[
|
||||
[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]
|
||||
[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]
|
||||
],
|
||||
overwrite: bool = False,
|
||||
scorer: Optional[Callable] = entity_linker_score,
|
||||
|
@ -209,9 +210,9 @@ class EntityLinker(TrainablePipe):
|
|||
get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that
|
||||
produces a list of candidates, given a certain knowledge base and a textual mention.
|
||||
get_candidates_batch (
|
||||
Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]],
|
||||
Callable[[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]],
|
||||
Iterable[Candidate]]
|
||||
): Function that produces a list of candidates, given a certain knowledge base and several textual mentions.
|
||||
): Function that produces a list of candidates, given a certain knowledge base and several textual mentions.
|
||||
overwrite (bool): Whether to overwrite existing non-empty annotations.
|
||||
scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links.
|
||||
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
|
||||
|
@ -485,7 +486,8 @@ class EntityLinker(TrainablePipe):
|
|||
|
||||
batch_candidates = list(
|
||||
self.get_candidates_batch(
|
||||
self.kb, [ent_batch[idx] for idx in valid_ent_idx]
|
||||
self.kb,
|
||||
SpanGroup(doc, spans=[ent_batch[idx] for idx in valid_ent_idx]),
|
||||
)
|
||||
if self.candidates_batch_size > 1
|
||||
else [
|
||||
|
|
Loading…
Reference in New Issue
Block a user