Modify EL batching to doc-wise streaming approach (#12367)

* Convert Candidate from Cython to Python class.

* Format.

* Fix .entity_ typo in _add_activations() usage.

* Change type for mentions to look up entity candidates for to SpanGroup from Iterable[Span].

* Update docs.

* Update spacy/kb/candidate.py

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Update doc string of BaseCandidate.__init__().

* Update spacy/kb/candidate.py

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Rename Candidate to InMemoryCandidate, BaseCandidate to Candidate.

* Adjust Candidate to support and mandate numerical entity IDs.

* Format.

* Fix docstring and docs.

* Update website/docs/api/kb.mdx

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Rename alias -> mention.

* Refactor Candidate attribute names. Update docs and tests accordingly.

* Refactor Candidate attributes and their usage.

* Format.

* Fix mypy error.

* Update error code in line with v4 convention.

* Modify EL batching system.

* Update leftover get_candidates() mention in docs.

* Format docs.

* Format.

* Update spacy/kb/candidate.py

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Updated error code.

* Simplify interface for int/str representations.

* Update website/docs/api/kb.mdx

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Rename 'alias' to 'mention'.

* Port Candidate and InMemoryCandidate to Cython.

* Remove redundant entry in setup.py.

* Add abstract class check.

* Drop storing mention.

* Update spacy/kb/candidate.pxd

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Fix entity_id refactoring problems in docstrings.

* Drop unused InMemoryCandidate._entity_hash.

* Update docstrings.

* Move attributes out of Candidate.

* Partially fix alias/mention terminology usage. Convert Candidate to interface.

* Remove prior_prob from supported properties in Candidate. Introduce KnowledgeBase.supports_prior_probs().

* Update docstrings related to prior_prob.

* Update alias/mention usage in doc(strings).

* Update spacy/ml/models/entity_linker.py

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Update spacy/ml/models/entity_linker.py

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Mention -> alias renaming. Drop Candidate.mentions(). Drop InMemoryLookupKB.get_alias_candidates() from docs.

* Update docstrings.

* Fix InMemoryCandidate attribute names.

* Update spacy/kb/kb.pyx

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Update spacy/ml/models/entity_linker.py

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Update W401 test.

* Update spacy/errors.py

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Update spacy/kb/kb.pyx

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Use Candidate output type for toy generators in the test suite to mimic best practices

* fix docs

* fix import

* Fix merge leftovers.

* Update spacy/kb/kb.pyx

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Update spacy/kb/kb.pyx

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Update website/docs/api/kb.mdx

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Update website/docs/api/entitylinker.mdx

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Update spacy/kb/kb_in_memory.pyx

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Update website/docs/api/inmemorylookupkb.mdx

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Update get_candidates() docstring.

* Reformat imports in entity_linker.py.

* Drop valid_ent_idx_per_doc.

* Update docs.

* Format.

* Simplify doc loop in predict().

* Remove E1044 comment.

* Fix merge errors.

* Format.

* Format.

* Format.

* Fix merge error & tests.

* Format.

* Apply suggestions from code review

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>

* Use type alias.

* isort.

* isort.

* Lint.

* Add typedefs.pyx.

* Fix typedef import.

* Fix type aliases.

* Format.

* Update docstring and type usage.

* Add info on get_candidates(), get_candidates_batched().

* Readd get_candidates info to v3 changelog.

* Update website/docs/api/entitylinker.mdx

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Update factory functions for backwards compatibility.

* Format.

* Ignore mypy error.

* Fix mypy error.

* Format.

* Add test for multiple docs with multiple entities.

---------

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
Co-authored-by: svlandeg <svlandeg@github.com>
Raphael Mitsch 2024-04-09 11:39:18 +02:00 committed by GitHub
parent afb22ad491
commit 304b9331e6
11 changed files with 344 additions and 296 deletions

View File

@ -238,7 +238,7 @@ grad_factor = 1.0
{% if "entity_linker" in components -%} {% if "entity_linker" in components -%}
[components.entity_linker] [components.entity_linker]
factory = "entity_linker" factory = "entity_linker"
get_candidates = {"@misc":"spacy.CandidateGenerator.v1"} get_candidates = {"@misc":"spacy.CandidateGenerator.v2"}
incl_context = true incl_context = true
incl_prior = true incl_prior = true
@ -517,7 +517,7 @@ width = ${components.tok2vec.model.encode.width}
{% if "entity_linker" in components -%} {% if "entity_linker" in components -%}
[components.entity_linker] [components.entity_linker]
factory = "entity_linker" factory = "entity_linker"
get_candidates = {"@misc":"spacy.CandidateGenerator.v1"} get_candidates = {"@misc":"spacy.CandidateGenerator.v2"}
incl_context = true incl_context = true
incl_prior = true incl_prior = true

View File

@ -950,7 +950,6 @@ class Errors(metaclass=ErrorsWithCodes):
"case pass an empty list for the previously not specified argument to avoid this error.") "case pass an empty list for the previously not specified argument to avoid this error.")
E1043 = ("Expected None or a value in range [{range_start}, {range_end}] for entity linker threshold, but got " E1043 = ("Expected None or a value in range [{range_start}, {range_end}] for entity linker threshold, but got "
"{value}.") "{value}.")
E1044 = ("Expected `candidates_batch_size` to be >= 1, but got: {value}")
E1045 = ("Encountered {parent} subclass without `{parent}.{method}` " E1045 = ("Encountered {parent} subclass without `{parent}.{method}` "
"method in '{name}'. If you want to use this method, make " "method in '{name}'. If you want to use this method, make "
"sure it's overwritten on the subclass.") "sure it's overwritten on the subclass.")

View File

@ -1,14 +1,14 @@
# cython: infer_types=True # cython: infer_types=True
from pathlib import Path from pathlib import Path
from typing import Iterable, Tuple, Union from typing import Iterable, Iterator, Tuple, Union
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
from ..errors import Errors from ..errors import Errors
from ..tokens import Span, SpanGroup from ..tokens import SpanGroup
from ..util import SimpleFrozenList from ..util import SimpleFrozenList
from .candidate import Candidate from .candidate cimport Candidate
cdef class KnowledgeBase: cdef class KnowledgeBase:
@ -19,6 +19,8 @@ cdef class KnowledgeBase:
DOCS: https://spacy.io/api/kb DOCS: https://spacy.io/api/kb
""" """
CandidatesForMentionT = Iterable[Candidate]
CandidatesForDocT = Iterable[CandidatesForMentionT]
def __init__(self, vocab: Vocab, entity_vector_length: int): def __init__(self, vocab: Vocab, entity_vector_length: int):
"""Create a KnowledgeBase.""" """Create a KnowledgeBase."""
@ -32,27 +34,15 @@ cdef class KnowledgeBase:
self.entity_vector_length = entity_vector_length self.entity_vector_length = entity_vector_length
self.mem = Pool() self.mem = Pool()
def get_candidates_batch( def get_candidates(self, mentions: Iterator[SpanGroup]) -> Iterator[CandidatesForDocT]:
self, mentions: SpanGroup
) -> Iterable[Iterable[Candidate]]:
""" """
Return candidate entities for a specified Span mention. Each candidate defines at least the entity and the Return candidate entities for the specified groups of mentions (as SpanGroup) per Doc.
entity's embedding vector. Depending on the KB implementation, further properties - such as the prior Each candidate for a mention defines at least the entity and the entity's embedding vector. Depending on the KB
probability of the specified mention text resolving to that entity - might be included. implementation, further properties - such as the prior probability of the specified mention text resolving to
that entity - might be included.
If no candidates are found for a given mention, an empty list is returned. If no candidates are found for a given mention, an empty list is returned.
mentions (SpanGroup): Mentions for which to get candidates. mentions (Iterator[SpanGroup]): Mentions for which to get candidates.
RETURNS (Iterable[Iterable[Candidate]]): Identified candidates. RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per mention/doc/doc batch.
"""
return [self.get_candidates(span) for span in mentions]
def get_candidates(self, mention: Span) -> Iterable[Candidate]:
"""
Return candidate entities for a specific mention. Each candidate defines at least the entity and the
entity's embedding vector. Depending on the KB implementation, further properties - such as the prior
probability of the specified mention text resolving to that entity - might be included.
If no candidate is found for the given mention, an empty list is returned.
mention (Span): Mention for which to get candidates.
RETURNS (Iterable[Candidate]): Identified candidates.
""" """
raise NotImplementedError( raise NotImplementedError(
Errors.E1045.format( Errors.E1045.format(

View File

@ -1,5 +1,5 @@
# cython: infer_types=True # cython: infer_types=True
from typing import Any, Callable, Dict, Iterable from typing import Any, Callable, Dict, Iterable, Iterator
import srsly import srsly
@ -12,7 +12,7 @@ from preshed.maps cimport PreshMap
import warnings import warnings
from pathlib import Path from pathlib import Path
from ..tokens import Span from ..tokens import SpanGroup
from ..typedefs cimport hash_t from ..typedefs cimport hash_t
@ -255,8 +255,9 @@ cdef class InMemoryLookupKB(KnowledgeBase):
alias_entry.probs = probs alias_entry.probs = probs
self._aliases_table[alias_index] = alias_entry self._aliases_table[alias_index] = alias_entry
def get_candidates(self, mention: Span) -> Iterable[InMemoryCandidate]: def get_candidates(self, mentions: Iterator[SpanGroup]) -> Iterator[Iterable[Iterable[InMemoryCandidate]]]:
return self._get_alias_candidates(mention.text) # type: ignore for mentions_for_doc in mentions:
yield [self._get_alias_candidates(span.text) for span in mentions_for_doc]
def _get_alias_candidates(self, str alias) -> Iterable[InMemoryCandidate]: def _get_alias_candidates(self, str alias) -> Iterable[InMemoryCandidate]:
""" """

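The change above turns `get_candidates()` into a generator that consumes one group of mentions per doc and yields one result per doc. A minimal sketch of that shape in plain Python, assuming hypothetical stand-ins (`ALIAS_TABLE`, `ToyCandidate`, `stream_candidates`) in place of spaCy's `InMemoryLookupKB`, `InMemoryCandidate` and `SpanGroup`:

```python
from typing import Dict, Iterable, Iterator, List, Tuple

# Hypothetical stand-ins: a "mention" is a plain string and a "candidate" is an
# (entity_id, prior_prob) tuple. spaCy's real types are Span and InMemoryCandidate.
ToyCandidate = Tuple[str, float]

ALIAS_TABLE: Dict[str, List[ToyCandidate]] = {
    "douglas": [("Q2", 0.8), ("Q3", 0.1)],
    "adam": [("Q2", 0.9)],
}


def stream_candidates(
    mentions_per_doc: Iterable[List[str]],
) -> Iterator[List[List[ToyCandidate]]]:
    """Yield, for each doc, one candidate list per mention (empty if unknown)."""
    for mentions_for_doc in mentions_per_doc:
        yield [ALIAS_TABLE.get(mention, []) for mention in mentions_for_doc]


# Two toy "docs", each given as the list of its mention texts.
docs = [["douglas", "adam"], ["Adam", "shrubbery"]]
result = list(stream_candidates(iter(docs)))
```

Because the function is a generator, nothing is looked up until the caller pulls the next doc's result, which is what lets the caller process documents one at a time.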
View File

@ -1,5 +1,5 @@
from pathlib import Path from pathlib import Path
from typing import Callable, Iterable, List, Optional, Tuple from typing import Callable, Iterable, Iterator, List, Optional, Tuple
from thinc.api import ( from thinc.api import (
Linear, Linear,
@ -21,6 +21,9 @@ from ...util import registry
from ...vocab import Vocab from ...vocab import Vocab
from ..extract_spans import extract_spans from ..extract_spans import extract_spans
CandidatesForMentionT = Iterable[Candidate]
CandidatesForDocT = Iterable[CandidatesForMentionT]
@registry.architectures("spacy.EntityLinker.v2") @registry.architectures("spacy.EntityLinker.v2")
def build_nel_encoder( def build_nel_encoder(
@ -117,34 +120,38 @@ def empty_kb(
@registry.misc("spacy.CandidateGenerator.v1") @registry.misc("spacy.CandidateGenerator.v1")
def create_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]: def create_get_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]:
return get_candidates return get_candidates
@registry.misc("spacy.CandidateBatchGenerator.v1") @registry.misc("spacy.CandidateGenerator.v2")
def create_candidates_batch() -> Callable[ def create_get_candidates_v2() -> Callable[
[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]] [KnowledgeBase, Iterator[SpanGroup]], Iterator[CandidatesForDocT]
]: ]:
return get_candidates_batch return get_candidates_v2
def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]: def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]:
""" """
Return candidate entities for a given mention and fetching appropriate entries from the index. Return candidate entities for the given mention from the KB.
kb (KnowledgeBase): Knowledge base to query. kb (KnowledgeBase): Knowledge base to query.
mention (Span): Entity mention for which to identify candidates. mention (Span): Entity mention.
RETURNS (Iterable[Candidate]): Identified candidates. RETURNS (Iterable[Candidate]): Identified candidates for specified mention.
""" """
return kb.get_candidates(mention) cands_per_doc = next(
get_candidates_v2(kb, iter([SpanGroup(mention.doc, spans=[mention])]))
)
assert isinstance(cands_per_doc, list)
return next(cands_per_doc[0])
def get_candidates_batch( def get_candidates_v2(
kb: KnowledgeBase, mentions: SpanGroup kb: KnowledgeBase, mentions: Iterator[SpanGroup]
) -> Iterable[Iterable[Candidate]]: ) -> Iterator[Iterable[Iterable[Candidate]]]:
""" """
Return candidate entities for the given mentions and fetching appropriate entries from the index. Return candidate entities for the given mentions from the KB.
kb (KnowledgeBase): Knowledge base to query. kb (KnowledgeBase): Knowledge base to query.
mentions (SpanGroup): Entity mentions for which to identify candidates. mentions (Iterator[SpanGroup]): Mentions per doc.
RETURNS (Iterable[Iterable[Candidate]]): Identified candidates. RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per mentions in document/SpanGroup.
""" """
return kb.get_candidates_batch(mentions) return kb.get_candidates(mentions)
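The v1 generator is kept for backwards compatibility by wrapping a single mention into a one-doc, one-mention stream and unpacking the first result. A rough sketch of that adapter idea in plain Python, assuming hypothetical names (`stream_lookup`, `lookup_single`) and strings in place of `Span` objects; unlike the diff, this sketch indexes into the per-doc result and copies it into a list:

```python
from typing import Iterator, List


def stream_lookup(mentions_per_doc: Iterator[List[str]]) -> Iterator[List[List[str]]]:
    """Hypothetical streaming lookup: one list of entity IDs per mention, per doc."""
    table = {"adam": ["Q2"], "douglas": ["Q2", "Q3"]}
    for mentions in mentions_per_doc:
        yield [table.get(mention, []) for mention in mentions]


def lookup_single(mention: str) -> List[str]:
    """Single-mention interface implemented on top of the streaming one."""
    # Wrap the mention as a stream containing one doc with one mention.
    cands_for_doc = next(stream_lookup(iter([[mention]])))
    # The only doc has exactly one mention, so its candidates sit at index 0.
    return list(cands_for_doc[0])
```

For example, `lookup_single("douglas")` goes through the same code path as a multi-doc call, so both interfaces stay consistent by construction.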

View File

@ -1,8 +1,19 @@
import random import random
import warnings import warnings
from itertools import islice from itertools import islice, tee
from pathlib import Path from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Union, cast from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Sequence,
Union,
cast,
)
import srsly import srsly
from numpy import dtype from numpy import dtype
@ -54,13 +65,11 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
"incl_prior": True, "incl_prior": True,
"incl_context": True, "incl_context": True,
"entity_vector_length": 64, "entity_vector_length": 64,
"get_candidates": {"@misc": "spacy.CandidateGenerator.v1"}, "get_candidates": {"@misc": "spacy.CandidateGenerator.v2"},
"get_candidates_batch": {"@misc": "spacy.CandidateBatchGenerator.v1"},
"overwrite": False, "overwrite": False,
"generate_empty_kb": {"@misc": "spacy.EmptyKB.v2"}, "generate_empty_kb": {"@misc": "spacy.EmptyKB.v2"},
"scorer": {"@scorers": "spacy.entity_linker_scorer.v1"}, "scorer": {"@scorers": "spacy.entity_linker_scorer.v1"},
"use_gold_ents": True, "use_gold_ents": True,
"candidates_batch_size": 1,
"threshold": None, "threshold": None,
"save_activations": False, "save_activations": False,
}, },
@ -80,15 +89,13 @@ def make_entity_linker(
incl_prior: bool, incl_prior: bool,
incl_context: bool, incl_context: bool,
entity_vector_length: int, entity_vector_length: int,
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], get_candidates: Callable[
get_candidates_batch: Callable[ [KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]
[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]
], ],
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase], generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
overwrite: bool, overwrite: bool,
scorer: Optional[Callable], scorer: Optional[Callable],
use_gold_ents: bool, use_gold_ents: bool,
candidates_batch_size: int,
threshold: Optional[float] = None, threshold: Optional[float] = None,
save_activations: bool, save_activations: bool,
): ):
@ -102,16 +109,13 @@ def make_entity_linker(
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model. incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
incl_context (bool): Whether or not to include the local context in the model. incl_context (bool): Whether or not to include the local context in the model.
entity_vector_length (int): Size of encoding vectors in the KB. entity_vector_length (int): Size of encoding vectors in the KB.
get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that get_candidates (Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]]):
produces a list of candidates, given a certain knowledge base and a textual mention. Function producing a list of candidates per document, given a certain knowledge base and several textual
get_candidates_batch ( documents with textual mentions.
Callable[[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]], Iterable[Candidate]]
): Function that produces a list of candidates, given a certain knowledge base and several textual mentions.
generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase. generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
scorer (Optional[Callable]): The scoring method. scorer (Optional[Callable]): The scoring method.
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
component must provide entity annotations. component must provide entity annotations.
candidates_batch_size (int): Size of batches for entity candidate generation.
threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the threshold, threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the threshold,
prediction is discarded. If None, predictions are not filtered by any threshold. prediction is discarded. If None, predictions are not filtered by any threshold.
save_activations (bool): save model activations in Doc when annotating. save_activations (bool): save model activations in Doc when annotating.
@ -129,12 +133,10 @@ def make_entity_linker(
incl_context=incl_context, incl_context=incl_context,
entity_vector_length=entity_vector_length, entity_vector_length=entity_vector_length,
get_candidates=get_candidates, get_candidates=get_candidates,
get_candidates_batch=get_candidates_batch,
generate_empty_kb=generate_empty_kb, generate_empty_kb=generate_empty_kb,
overwrite=overwrite, overwrite=overwrite,
scorer=scorer, scorer=scorer,
use_gold_ents=use_gold_ents, use_gold_ents=use_gold_ents,
candidates_batch_size=candidates_batch_size,
threshold=threshold, threshold=threshold,
save_activations=save_activations, save_activations=save_activations,
) )
@ -168,15 +170,14 @@ class EntityLinker(TrainablePipe):
incl_prior: bool, incl_prior: bool,
incl_context: bool, incl_context: bool,
entity_vector_length: int, entity_vector_length: int,
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], get_candidates: Callable[
get_candidates_batch: Callable[ [KnowledgeBase, Iterator[SpanGroup]],
[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]] Iterator[Iterable[Iterable[Candidate]]],
], ],
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase], generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
overwrite: bool = False, overwrite: bool = False,
scorer: Optional[Callable] = entity_linker_score, scorer: Optional[Callable] = entity_linker_score,
use_gold_ents: bool, use_gold_ents: bool,
candidates_batch_size: int,
threshold: Optional[float] = None, threshold: Optional[float] = None,
save_activations: bool = False, save_activations: bool = False,
) -> None: ) -> None:
@ -191,18 +192,14 @@ class EntityLinker(TrainablePipe):
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model. incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
incl_context (bool): Whether or not to include the local context in the model. incl_context (bool): Whether or not to include the local context in the model.
entity_vector_length (int): Size of encoding vectors in the KB. entity_vector_length (int): Size of encoding vectors in the KB.
get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that get_candidates (Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]]):
produces a list of candidates, given a certain knowledge base and a textual mention. Function producing a list of candidates per document, given a certain knowledge base and several textual
get_candidates_batch ( documents with textual mentions.
Callable[[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]],
Iterable[Candidate]]
): Function that produces a list of candidates, given a certain knowledge base and several textual mentions.
generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase. generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
overwrite (bool): Whether to overwrite existing non-empty annotations. overwrite (bool): Whether to overwrite existing non-empty annotations.
scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links. scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links.
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
component must provide entity annotations. component must provide entity annotations.
candidates_batch_size (int): Size of batches for entity candidate generation.
threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the
threshold, prediction is discarded. If None, predictions are not filtered by any threshold. threshold, prediction is discarded. If None, predictions are not filtered by any threshold.
save_activations (bool): save model activations in Doc when annotating. save_activations (bool): save model activations in Doc when annotating.
@ -227,18 +224,14 @@ class EntityLinker(TrainablePipe):
self.incl_prior = incl_prior self.incl_prior = incl_prior
self.incl_context = incl_context self.incl_context = incl_context
self.get_candidates = get_candidates self.get_candidates = get_candidates
self.get_candidates_batch = get_candidates_batch
self.cfg: Dict[str, Any] = {"overwrite": overwrite} self.cfg: Dict[str, Any] = {"overwrite": overwrite}
self.distance = CosineDistance(normalize=False) self.distance = CosineDistance(normalize=False)
self.kb = generate_empty_kb(self.vocab, entity_vector_length) self.kb = generate_empty_kb(self.vocab, entity_vector_length)
self.scorer = scorer self.scorer = scorer
self.use_gold_ents = use_gold_ents self.use_gold_ents = use_gold_ents
self.candidates_batch_size = candidates_batch_size
self.threshold = threshold self.threshold = threshold
self.save_activations = save_activations self.save_activations = save_activations
if candidates_batch_size < 1:
raise ValueError(Errors.E1044)
if self.incl_prior and not self.kb.supports_prior_probs: if self.incl_prior and not self.kb.supports_prior_probs:
warnings.warn(Warnings.W401) warnings.warn(Warnings.W401)
@ -318,11 +311,12 @@ class EntityLinker(TrainablePipe):
If one isn't present, then the update step needs to be skipped. If one isn't present, then the update step needs to be skipped.
""" """
for candidates_for_doc in self.get_candidates(
for eg in examples: self.kb,
for ent in eg.predicted.ents: (SpanGroup(doc=eg.predicted, spans=eg.predicted.ents) for eg in examples),
candidates = list(self.get_candidates(self.kb, ent)) ):
if candidates: for candidates_for_mention in candidates_for_doc:
if list(candidates_for_mention):
return True return True
return False return False
@ -451,40 +445,35 @@ class EntityLinker(TrainablePipe):
} }
if isinstance(docs, Doc): if isinstance(docs, Doc):
docs = [docs] docs = [docs]
for doc in docs:
docs_iters = tee(docs, 2)
# Call candidate generator.
all_ent_cands = self.get_candidates(
self.kb,
(
SpanGroup(
doc,
spans=[
ent for ent in doc.ents if ent.label_ not in self.labels_discard
],
)
for doc in docs_iters[0]
),
)
for doc in docs_iters[1]:
doc_ents: List[Ints1d] = [] doc_ents: List[Ints1d] = []
doc_scores: List[Floats1d] = [] doc_scores: List[Floats1d] = []
if len(doc) == 0: if len(doc) == 0 or len(doc.ents) == 0:
docs_scores.append(Ragged(ops.alloc1f(0), ops.alloc1i(0))) docs_scores.append(Ragged(ops.alloc1f(0), ops.alloc1i(0)))
docs_ents.append(Ragged(xp.zeros(0, dtype="uint64"), ops.alloc1i(0))) docs_ents.append(Ragged(xp.zeros(0, dtype="uint64"), ops.alloc1i(0)))
continue continue
sentences = [s for s in doc.sents] sentences = [s for s in doc.sents]
doc_ent_cands = list(next(all_ent_cands))
# Loop over entities in batches. # Looping over candidate entities for this doc. (TODO: rewrite)
for ent_idx in range(0, len(doc.ents), self.candidates_batch_size): for ent_cand_idx, ent in enumerate(doc.ents):
ent_batch = doc.ents[ent_idx : ent_idx + self.candidates_batch_size]
# Look up candidate entities.
valid_ent_idx = [
idx
for idx in range(len(ent_batch))
if ent_batch[idx].label_ not in self.labels_discard
]
batch_candidates = list(
self.get_candidates_batch(
self.kb,
SpanGroup(doc, spans=[ent_batch[idx] for idx in valid_ent_idx]),
)
if self.candidates_batch_size > 1
else [
self.get_candidates(self.kb, ent_batch[idx])
for idx in valid_ent_idx
]
)
# Looping through each entity in batch (TODO: rewrite)
for j, ent in enumerate(ent_batch):
assert hasattr(ent, "sents") assert hasattr(ent, "sents")
sents = list(ent.sents) sents = list(ent.sents)
sent_indices = ( sent_indices = (
@ -502,7 +491,6 @@ class EntityLinker(TrainablePipe):
start_token = sentences[start_sentence].start start_token = sentences[start_sentence].start
end_token = sentences[end_sentence].end end_token = sentences[end_sentence].end
sent_doc = doc[start_token:end_token].as_doc() sent_doc = doc[start_token:end_token].as_doc()
# currently, the context is the same for each entity in a sentence (should be refined) # currently, the context is the same for each entity in a sentence (should be refined)
sentence_encoding = self.model.predict([sent_doc])[0] sentence_encoding = self.model.predict([sent_doc])[0]
sentence_encoding_t = sentence_encoding.T sentence_encoding_t = sentence_encoding.T
@ -518,7 +506,7 @@ class EntityLinker(TrainablePipe):
ents=[0], ents=[0],
) )
else: else:
candidates = list(batch_candidates[j]) candidates = list(doc_ent_cands[ent_cand_idx])
if not candidates: if not candidates:
# no prediction possible for this entity - setting to NIL # no prediction possible for this entity - setting to NIL
final_kb_ids.append(self.NIL) final_kb_ids.append(self.NIL)
@ -540,11 +528,12 @@ class EntityLinker(TrainablePipe):
else: else:
random.shuffle(candidates) random.shuffle(candidates)
# set all prior probabilities to 0 if incl_prior=False # set all prior probabilities to 0 if incl_prior=False
if self.incl_prior and self.kb.supports_prior_probs: scores = prior_probs = xp.asarray(
prior_probs = xp.asarray([c.prior_prob for c in candidates]) # type: ignore [
else: c.prior_prob if self.incl_prior else 0.0
prior_probs = xp.asarray([0.0 for _ in candidates]) for c in candidates
scores = prior_probs ]
)
# add in similarity from the context # add in similarity from the context
if self.incl_context: if self.incl_context:
entity_encodings = xp.asarray( entity_encodings = xp.asarray(
@ -567,8 +556,7 @@ class EntityLinker(TrainablePipe):
scores = prior_probs + sims - (prior_probs * sims) scores = prior_probs + sims - (prior_probs * sims)
final_kb_ids.append( final_kb_ids.append(
candidates[scores.argmax().item()].entity_id_ candidates[scores.argmax().item()].entity_id_
if self.threshold is None if self.threshold is None or scores.max() >= self.threshold
or scores.max() >= self.threshold
else EntityLinker.NIL else EntityLinker.NIL
) )
self._add_activations( self._add_activations(
@ -577,6 +565,7 @@ class EntityLinker(TrainablePipe):
scores=scores, scores=scores,
ents=[c.entity_id for c in candidates], ents=[c.entity_id for c in candidates],
) )
self._add_doc_activations( self._add_doc_activations(
docs_scores=docs_scores, docs_scores=docs_scores,
docs_ents=docs_ents, docs_ents=docs_ents,
@ -588,6 +577,7 @@ class EntityLinker(TrainablePipe):
method="predict", msg="result variables not of equal length" method="predict", msg="result variables not of equal length"
) )
raise RuntimeError(err) raise RuntimeError(err)
return { return {
KNOWLEDGE_BASE_IDS: final_kb_ids, KNOWLEDGE_BASE_IDS: final_kb_ids,
"ents": docs_ents, "ents": docs_ents,

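`predict()` now splits the incoming docs with `itertools.tee`: one copy feeds the candidate generator, the other drives the per-doc loop. A simplified sketch of that pattern, assuming hypothetical helpers (`toy_candidates`, `link`) and lists of strings in place of `Doc` objects. As a choice of this sketch only, the two streams are advanced together with `zip`, so every doc consumes exactly one item from the candidate stream, including docs without mentions:

```python
from itertools import tee
from typing import Iterable, Iterator, List


def toy_candidates(mentions_per_doc: Iterable[List[str]]) -> Iterator[List[List[str]]]:
    """Hypothetical candidate generator: per doc, one list of entity IDs per mention."""
    table = {"adam": ["Q2"], "douglas": ["Q2", "Q3"]}
    for mentions in mentions_per_doc:
        yield [table.get(mention, []) for mention in mentions]


def link(docs: Iterable[List[str]]) -> List[List[str]]:
    """Pick the first candidate per mention, or "NIL" if there is none."""
    # tee() gives two independent iterators over the same (possibly lazy) stream.
    docs_for_cands, docs_for_loop = tee(docs, 2)
    all_cands = toy_candidates(docs_for_cands)

    linked: List[List[str]] = []
    # zip() pulls one doc and its candidates per step, keeping both in lockstep.
    for doc, cands_for_doc in zip(docs_for_loop, all_cands):
        linked.append(
            [cands[0] if cands else "NIL" for _, cands in zip(doc, cands_for_doc)]
        )
    return linked


output = link(iter([["adam", "shrubbery"], [], ["douglas"]]))
```

Since both iterators advance at the same pace, `tee` only has to buffer about one doc at a time rather than the whole input.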
View File

@ -1,21 +1,21 @@
from typing import Any, Callable, Dict, Iterable, cast from typing import Any, Callable, Dict, Iterable, Iterator, cast
import pytest import pytest
from numpy.testing import assert_equal from numpy.testing import assert_equal
from thinc.types import Ragged from thinc.types import Ragged
from spacy import Language, registry, util from spacy import registry, util
from spacy.attrs import ENT_KB_ID from spacy.attrs import ENT_KB_ID
from spacy.compat import pickle from spacy.compat import pickle
from spacy.kb import Candidate, InMemoryLookupKB, KnowledgeBase from spacy.kb import Candidate, InMemoryLookupKB, KnowledgeBase
from spacy.lang.en import English from spacy.lang.en import English
from spacy.ml import load_kb from spacy.ml import load_kb
from spacy.ml.models.entity_linker import build_span_maker, get_candidates from spacy.ml.models.entity_linker import build_span_maker, get_candidates_v2
from spacy.pipeline import EntityLinker, TrainablePipe from spacy.pipeline import EntityLinker, TrainablePipe
from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
from spacy.scorer import Scorer from spacy.scorer import Scorer
from spacy.tests.util import make_tempdir from spacy.tests.util import make_tempdir
from spacy.tokens import Doc, Span from spacy.tokens import Doc, Span, SpanGroup
from spacy.training import Example from spacy.training import Example
from spacy.util import ensure_path from spacy.util import ensure_path
from spacy.vocab import Vocab from spacy.vocab import Vocab
@ -453,11 +453,21 @@ def test_candidate_generation(nlp):
mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9]) mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
# test the size of the relevant candidates # test the size of the relevant candidates
adam_ent_cands = get_candidates(mykb, adam_ent) adam_ent_cands = next(
assert len(get_candidates(mykb, douglas_ent)) == 2 get_candidates_v2(mykb, SpanGroup(doc=doc, spans=[adam_ent]))
)[0]
assert len(adam_ent_cands) == 1 assert len(adam_ent_cands) == 1
assert len(get_candidates(mykb, Adam_ent)) == 0 # default case sensitive assert (
assert len(get_candidates(mykb, shrubbery_ent)) == 0 len(next(get_candidates_v2(mykb, SpanGroup(doc=doc, spans=[douglas_ent])))[0])
== 2
)
assert (
len(next(get_candidates_v2(mykb, SpanGroup(doc=doc, spans=[Adam_ent])))[0]) == 0
) # default case sensitive
assert (
len(next(get_candidates_v2(mykb, SpanGroup(doc=doc, spans=[shrubbery_ent])))[0])
== 0
)
# test the content of the candidates # test the content of the candidates
assert adam_ent_cands[0].entity_id_ == "Q2" assert adam_ent_cands[0].entity_id_ == "Q2"
@ -466,6 +476,86 @@ def test_candidate_generation(nlp):
assert_almost_equal(adam_ent_cands[0].prior_prob, 0.9) assert_almost_equal(adam_ent_cands[0].prior_prob, 0.9)
def test_candidate_generation_multiple_docs(nlp):
"""Test correct candidate generation with multiple docs."""
mykb = InMemoryLookupKB(nlp.vocab, entity_vector_length=1)
docs = [nlp("douglas adam Adam shrubbery"), nlp("shrubbery Adam douglas adam")]
douglas_ents = [docs[0][0:1], docs[1][2:3]]
adam_ents = [docs[0][1:2], docs[1][3:4]]
Adam_ents = [docs[0][2:3], docs[1][1:2]]
shrubbery_ents = [docs[0][3:4], docs[1][0:1]]
# adding entities
mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
mykb.add_entity(entity="Q2", freq=12, entity_vector=[2])
mykb.add_entity(entity="Q3", freq=5, entity_vector=[3])
# adding aliases
mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1])
mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
# test the size of the relevant candidates
adam_ent_cands = list(
get_candidates_v2(
mykb,
[
SpanGroup(doc=docs[0], spans=[adam_ents[0]]),
SpanGroup(doc=docs[1], spans=[adam_ents[1]]),
],
)
)
assert len(adam_ent_cands) == 2
assert (
len(
list(
get_candidates_v2(
mykb,
[
SpanGroup(doc=docs[0], spans=[douglas_ents[0]]),
SpanGroup(doc=docs[1], spans=[douglas_ents[1]]),
],
)
)
)
== 2
)
Adam_ent_cands = list(
get_candidates_v2(
mykb,
[
SpanGroup(doc=docs[0], spans=[Adam_ents[0]]),
SpanGroup(doc=docs[1], spans=[Adam_ents[1]]),
],
)
)
assert len(Adam_ent_cands) == 2
assert (
len(Adam_ent_cands[0][0]) == 0 and len(Adam_ent_cands[1][0]) == 0
) # default case sensitive
shrubbery_ents_cands = list(
get_candidates_v2(
mykb,
[
SpanGroup(doc=docs[0], spans=[shrubbery_ents[0]]),
SpanGroup(doc=docs[1], spans=[shrubbery_ents[1]]),
],
)
)
assert len(shrubbery_ents_cands) == 2
assert len(shrubbery_ents_cands[0][0]) == 0 and len(shrubbery_ents_cands[1][0]) == 0
# test the content of the candidates
assert (
adam_ent_cands[0][0][0].entity_id_ == adam_ent_cands[1][0][0].entity_id_ == "Q2"
)
assert adam_ent_cands[0][0][0].alias == adam_ent_cands[1][0][0].alias == "adam"
assert_almost_equal(adam_ent_cands[0][0][0].entity_freq, 12)
assert_almost_equal(adam_ent_cands[1][0][0].entity_freq, 12)
assert_almost_equal(adam_ent_cands[0][0][0].prior_prob, 0.9)
assert_almost_equal(adam_ent_cands[1][0][0].prior_prob, 0.9)
def test_el_pipe_configuration(nlp): def test_el_pipe_configuration(nlp):
"""Test correct candidate generation as part of the EL pipe""" """Test correct candidate generation as part of the EL pipe"""
nlp.add_pipe("sentencizer") nlp.add_pipe("sentencizer")
@ -490,24 +580,20 @@ def test_el_pipe_configuration(nlp):
assert doc[1].ent_kb_id_ == "" assert doc[1].ent_kb_id_ == ""
assert doc[2].ent_kb_id_ == "Q2" assert doc[2].ent_kb_id_ == "Q2"
def get_lowercased_candidates(kb, span): def get_lowercased_candidates(kb: InMemoryLookupKB, mentions: Iterator[SpanGroup]):
return kb._get_alias_candidates(span.text.lower()) for mentions_for_doc in mentions:
yield [
def get_lowercased_candidates_batch(kb, spans): kb._get_alias_candidates(ent_span.text.lower())
return [get_lowercased_candidates(kb, span) for span in spans] for ent_span in mentions_for_doc
]
@registry.misc("spacy.LowercaseCandidateGenerator.v1") @registry.misc("spacy.LowercaseCandidateGenerator.v1")
def create_candidates() -> ( def create_candidates() -> Callable[
Callable[[InMemoryLookupKB, "Span"], Iterable[Candidate]] [InMemoryLookupKB, Iterator[SpanGroup]],
): Iterator[Iterable[Iterable[Candidate]]],
]:
return get_lowercased_candidates return get_lowercased_candidates
@registry.misc("spacy.LowercaseCandidateBatchGenerator.v1")
def create_candidates_batch() -> (
Callable[[InMemoryLookupKB, Iterable["Span"]], Iterable[Iterable[Candidate]]]
):
return get_lowercased_candidates_batch
# replace the pipe with a new one with a different candidate generator # replace the pipe with a new one with a different candidate generator
entity_linker = nlp.replace_pipe( entity_linker = nlp.replace_pipe(
"entity_linker", "entity_linker",
@ -515,9 +601,6 @@ def test_el_pipe_configuration(nlp):
config={ config={
"incl_context": False, "incl_context": False,
"get_candidates": {"@misc": "spacy.LowercaseCandidateGenerator.v1"}, "get_candidates": {"@misc": "spacy.LowercaseCandidateGenerator.v1"},
"get_candidates_batch": {
"@misc": "spacy.LowercaseCandidateBatchGenerator.v1"
},
}, },
) )
entity_linker.set_kb(create_kb) entity_linker.set_kb(create_kb)


@ -1255,6 +1255,15 @@ A function that reads an existing `KnowledgeBase` from file.
| --------- | -------------------------------------------------------- | | --------- | -------------------------------------------------------- |
| `kb_path` | The location of the KB that was stored to file. ~~Path~~ | | `kb_path` | The location of the KB that was stored to file. ~~Path~~ |
### spacy.CandidateGenerator.v2 {id="CandidateGenerator-v2"}
A function that takes as input a [`KnowledgeBase`](/api/kb) and an
`Iterator[SpanGroup]` object, where each `SpanGroup` holds the named entity
mentions of one [`Doc`](/api/doc), and returns, per `Doc`, an iterable of
plausible [`Candidate`](/api/kb/#candidate) objects for each mention. The default
`CandidateGenerator` uses the text of a mention to find its potential aliases in
the `KnowledgeBase`. Note that this function is case-dependent.
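
The call shape described above can be sketched in plain Python, so that it runs without spaCy. `ToyCandidate`, `ToyKB` (a `dict`) and the lists of strings standing in for `SpanGroup` objects are hypothetical stand-ins, not spaCy classes. Only the nesting (one list per document, one inner list per mention) is meant to mirror the documented `Iterator[Iterable[Iterable[Candidate]]]` return type.

```python
from dataclasses import dataclass
from typing import Dict, Iterable, Iterator, List


@dataclass
class ToyCandidate:
    # Hypothetical stand-in for a KB candidate; not the spaCy Candidate class.
    entity_id: str
    prior_prob: float


# Hypothetical stand-in for a knowledge base: alias text -> candidates.
ToyKB = Dict[str, List[ToyCandidate]]


def toy_get_candidates(
    kb: ToyKB, mentions: Iterator[List[str]]
) -> Iterator[Iterable[Iterable[ToyCandidate]]]:
    # Each item of `mentions` holds all mention texts of one document.
    for mentions_for_doc in mentions:
        # Exact, case-dependent lookup; unknown mentions get an empty list.
        yield [kb.get(text, []) for text in mentions_for_doc]


kb: ToyKB = {
    "douglas": [ToyCandidate("Q2", 0.8), ToyCandidate("Q3", 0.1)],
    "adam": [ToyCandidate("Q2", 0.9)],
}
docs = iter([["douglas", "adam"], ["Adam", "shrubbery"]])
per_doc = list(toy_get_candidates(kb, docs))
```

Here `per_doc[d][m]` is the candidate list for mention `m` of document `d`, which is the same `[doc][mention][candidate]` indexing the multi-document test uses.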
### spacy.CandidateGenerator.v1 {id="CandidateGenerator"} ### spacy.CandidateGenerator.v1 {id="CandidateGenerator"}
A function that takes as input a [`KnowledgeBase`](/api/kb) and a A function that takes as input a [`KnowledgeBase`](/api/kb) and a


@ -47,29 +47,42 @@ architectures and their arguments and hyperparameters.
> "incl_context": True, > "incl_context": True,
> "model": DEFAULT_NEL_MODEL, > "model": DEFAULT_NEL_MODEL,
> "entity_vector_length": 64, > "entity_vector_length": 64,
> "get_candidates": {'@misc': 'spacy.CandidateGenerator.v1'}, > "get_candidates": {'@misc': 'spacy.CandidateGenerator.v2'},
> "threshold": None, > "threshold": None,
> } > }
> nlp.add_pipe("entity_linker", config=config) > nlp.add_pipe("entity_linker", config=config)
> ``` > ```
| Setting | Description | | Setting | Description |
| --------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | ------------------------------------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `labels_discard` | NER labels that will automatically get a "NIL" prediction. Defaults to `[]`. ~~Iterable[str]~~ | | `labels_discard` | NER labels that will automatically get an "NIL" prediction. Defaults to `[]`. ~~Iterable[str]~~ |
| `n_sents` | The number of neighbouring sentences to take into account. Defaults to 0. ~~int~~ | | `n_sents` | The number of neighbouring sentences to take into account. Defaults to `0`. ~~int~~ |
| `incl_prior` | Whether or not to include prior probabilities from the KB in the model. Defaults to `True`. ~~bool~~ | | `incl_prior` | Whether prior probabilities from the KB are included in the model. Defaults to `True`. ~~bool~~ |
| `incl_context` | Whether or not to include the local context in the model. Defaults to `True`. ~~bool~~ | | `incl_context` | Whether the local context is included in the model. Defaults to `True`. ~~bool~~ |
| `model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. Defaults to [EntityLinker](/api/architectures#EntityLinker). ~~Model~~ | | `model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. Defaults to [`EntityLinker`](/api/architectures#EntityLinker). ~~Model~~ |
| `entity_vector_length` | Size of encoding vectors in the KB. Defaults to `64`. ~~int~~ | | `entity_vector_length` | Size of encoding vectors in the KB. Defaults to `64`. ~~int~~ |
| `use_gold_ents` | Whether to copy entities from the gold docs or not. Defaults to `True`. If `False`, entities must be set in the training data or by an annotating component in the pipeline. ~~int~~ | | `use_gold_ents` | Whether entities are copied from the gold docs. Defaults to `True`. If `False`, entities must be set in the training data or by an annotating component in the pipeline. ~~int~~ |
| `get_candidates` | Function that generates plausible candidates for a given `Span` object. Defaults to [CandidateGenerator](/api/architectures#CandidateGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Span], Iterable[Candidate]]~~ | | `get_candidates` <Tag variant="new">4.0</Tag> | Function that retrieves plausible candidates per entity mention in a given `Iterator[SpanGroup]` (one `SpanGroup` includes all mentions found in a given `Doc` instance). Defaults to [CandidateGenerator](/api/architectures#CandidateGenerator). ~~Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]]~~ |
| `get_candidates_batch` <Tag variant="new">3.5</Tag> | Function that generates plausible candidates for a given batch of `Span` objects. Defaults to [CandidateBatchGenerator](/api/architectures#CandidateBatchGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]]~~ | | `generate_empty_kb` <Tag variant="new">3.6</Tag> | Function that generates an empty `KnowledgeBase` object. Defaults to [`spacy.EmptyKB.v2`](/api/architectures#EmptyKB), which generates an empty [`InMemoryLookupKB`](/api/inmemorylookupkb). ~~Callable[[Vocab, int], KnowledgeBase]~~ |
| `generate_empty_kb` <Tag variant="new">3.5.1</Tag> | Function that generates an empty `KnowledgeBase` object. Defaults to [`spacy.EmptyKB.v2`](/api/architectures#EmptyKB), which generates an empty [`InMemoryLookupKB`](/api/inmemorylookupkb). ~~Callable[[Vocab, int], KnowledgeBase]~~ |
| `overwrite` <Tag variant="new">3.2</Tag> | Whether existing annotation is overwritten. Defaults to `True`. ~~bool~~ | | `overwrite` <Tag variant="new">3.2</Tag> | Whether existing annotation is overwritten. Defaults to `True`. ~~bool~~ |
| `scorer` <Tag variant="new">3.2</Tag> | The scoring method. Defaults to [`Scorer.score_links`](/api/scorer#score_links). ~~Optional[Callable]~~ | | `scorer` <Tag variant="new">3.2</Tag> | The scoring method. Defaults to [`Scorer.score_links`](/api/scorer#score_links). ~~Optional[Callable]~~ |
| `save_activations` <Tag variant="new">4.0</Tag> | Save activations in `Doc` when annotating. Saved activations are `"ents"` and `"scores"`. ~~Union[bool, list[str]]~~ | | `save_activations` <Tag variant="new">4.0</Tag> | Save activations in `Doc` when annotating. Saved activations are `"ents"` and `"scores"`. ~~Union[bool, list[str]]~~ |
| `threshold` <Tag variant="new">3.4</Tag> | Confidence threshold for entity predictions. The default of `None` implies that all predictions are accepted; otherwise those with a score beneath the threshold are discarded. If there are no predictions with scores above the threshold, the linked entity is `NIL`. ~~Optional[float]~~ | | `threshold` <Tag variant="new">3.4</Tag> | Confidence threshold for entity predictions. The default of `None` implies that all predictions are accepted; otherwise those with a score beneath the threshold are discarded. If there are no predictions with scores above the threshold, the linked entity is `NIL`. ~~Optional[float]~~ |
<Infobox variant="warning">
Prior to spaCy v4.0, `get_candidates()` returned a single `Iterable` of candidates
for one specific mention, i.e. the function was typed as
`Callable[[KnowledgeBase, Span], Iterable[Candidate]]`. To retrieve candidates
batch-wise, spaCy >= 3.5 exposed `get_candidates_batch()`, which identifies
candidates for an arbitrary number of spans:
`Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]]`. The
main difference between `get_candidates_batch()` and `get_candidates()` in
spaCy >= 4.0 is that the latter considers the grouping of provided mention spans
per `Doc` instance.
</Infobox>
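
One way to carry an existing per-mention lookup over to the doc-wise signature is a thin wrapper. The following is a minimal sketch under stated assumptions: `streamify`, `lookup` and the toy data are hypothetical names for illustration, the "KB" is a plain `dict`, and mentions are plain strings rather than `Span` objects.

```python
from typing import Callable, Iterable, Iterator, List, TypeVar

KB = TypeVar("KB")
Mention = TypeVar("Mention")
Cand = TypeVar("Cand")


def streamify(
    per_mention: Callable[[KB, Mention], Iterable[Cand]],
) -> Callable[[KB, Iterator[Iterable[Mention]]], Iterator[List[List[Cand]]]]:
    """Wrap a one-mention lookup so it accepts one group of mentions per doc."""

    def get_candidates(kb, mentions):
        for mentions_for_doc in mentions:
            # One inner list per mention, one outer list per document.
            yield [list(per_mention(kb, m)) for m in mentions_for_doc]

    return get_candidates


# Toy per-mention lookup in the pre-v4 style (hypothetical).
def lookup(kb, mention):
    return kb.get(mention, [])


toy_kb = {"adam": ["Q2"], "douglas": ["Q2", "Q3"]}
stream = streamify(lookup)(toy_kb, iter([["adam"], ["douglas", "Adam"]]))
first_doc = next(stream)   # candidates for the first document only
second_doc = next(stream)  # computed when requested, not before
```

Because the wrapper is a generator, candidates for a document are only looked up when the consumer asks for them.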
```python ```python
%%GITHUB_SPACY/spacy/pipeline/entity_linker.py %%GITHUB_SPACY/spacy/pipeline/entity_linker.py
``` ```
@ -108,7 +121,7 @@ custom knowledge base, you should either call
| `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ | | `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ |
| _keyword-only_ | | | _keyword-only_ | |
| `entity_vector_length` | Size of encoding vectors in the KB. ~~int~~ | | `entity_vector_length` | Size of encoding vectors in the KB. ~~int~~ |
| `get_candidates` | Function that generates plausible candidates for a given `Span` object. ~~Callable[[KnowledgeBase, Span], Iterable[Candidate]]~~ | | `get_candidates` | Function that retrieves plausible candidates per entity mention in a given `SpanGroup`. Defaults to [CandidateGenerator](/api/architectures#CandidateGenerator). ~~Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]]~~ |
| `labels_discard` | NER labels that will automatically get a `"NIL"` prediction. ~~Iterable[str]~~ | | `labels_discard` | NER labels that will automatically get a `"NIL"` prediction. ~~Iterable[str]~~ |
| `n_sents` | The number of neighbouring sentences to take into account. ~~int~~ | | `n_sents` | The number of neighbouring sentences to take into account. ~~int~~ |
| `incl_prior` | Whether or not to include prior probabilities from the KB in the model. ~~bool~~ | | `incl_prior` | Whether or not to include prior probabilities from the KB in the model. ~~bool~~ |


@ -155,35 +155,12 @@ Get a list of all aliases in the knowledge base.
## InMemoryLookupKB.get_candidates {id="get_candidates",tag="method"} ## InMemoryLookupKB.get_candidates {id="get_candidates",tag="method"}
Given a certain textual mention as input, retrieve a list of candidate entities Given textual mentions for an arbitrary number of documents as input, retrieve a
of type [`InMemoryCandidate`](/api/kb#candidate). Wraps list of candidate entities of type [`InMemoryCandidate`](/api/kb#candidate) for
[`get_alias_candidates()`](/api/inmemorylookupkb#get_alias_candidates). each mention. The [`EntityLinker`](/api/entitylinker) component passes a
generator that yields mentions as [`SpanGroup`](/api/spangroup)s per document. The decision of how to batch
> #### Example candidate retrieval lookups over multiple documents is left up to the
> implementation of `KnowledgeBase.get_candidates()`.
> ```python
> from spacy.lang.en import English
> nlp = English()
> doc = nlp("Douglas Adams wrote 'The Hitchhiker's Guide to the Galaxy'.")
> candidates = kb.get_candidates(doc[0:2])
> ```
| Name | Description |
| ----------- | ------------------------------------------------------------------------------------ |
| `mention` | The textual mention or alias. ~~Span~~ |
| **RETURNS** | An iterable of relevant `InMemoryCandidate` objects. ~~Iterable[InMemoryCandidate]~~ |
## InMemoryLookupKB.get_candidates_batch {id="get_candidates_batch",tag="method"}
Same as [`get_candidates()`](/api/inmemorylookupkb#get_candidates), but for an
arbitrary number of mentions. The [`EntityLinker`](/api/entitylinker) component
will call `get_candidates_batch()` instead of `get_candidates()`, if the config
parameter `candidates_batch_size` is greater or equal than 1.
The default implementation of `get_candidates_batch()` executes
`get_candidates()` in a loop. We recommend implementing a more efficient way to
retrieve candidates for multiple mentions at once, if performance is of concern
to you.
> #### Example > #### Example
> >
@ -192,13 +169,13 @@ to you.
> from spacy.tokens import SpanGroup > from spacy.tokens import SpanGroup
> nlp = English() > nlp = English()
> doc = nlp("Douglas Adams wrote 'The Hitchhiker's Guide to the Galaxy'.") > doc = nlp("Douglas Adams wrote 'The Hitchhiker's Guide to the Galaxy'.")
> candidates = kb.get_candidates_batch([SpanGroup(doc, spans=[doc[0:2], doc[3:]]]) > candidates = kb.get_candidates([SpanGroup(doc, spans=[doc[0:2], doc[3:]])])
> ``` > ```
| Name | Description | | Name | Description |
| ----------- | ------------------------------------------------------------------------------------------------------------ | | ----------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `mentions` | The textual mentions. ~~SpanGroup~~ | | `mentions` | The textual mentions or aliases (one `SpanGroup` per `Doc` instance). ~~Iterator[SpanGroup]~~ |
| **RETURNS** | An iterable of iterable with relevant `InMemoryCandidate` objects. ~~Iterable[Iterable[InMemoryCandidate]]~~ | | **RETURNS** | An iterator over iterables of iterables with relevant [`InMemoryCandidate`](/api/kb#candidate) objects (per mention and doc). ~~Iterator[Iterable[Iterable[InMemoryCandidate]]]~~ |
## InMemoryLookupKB.get_vector {id="get_vector",tag="method"} ## InMemoryLookupKB.get_vector {id="get_vector",tag="method"}


@ -60,34 +60,13 @@ The length of the fixed-size entity vectors in the knowledge base.
## KnowledgeBase.get_candidates {id="get_candidates",tag="method"} ## KnowledgeBase.get_candidates {id="get_candidates",tag="method"}
Given a certain textual mention as input, retrieve a list of candidate entities Given textual mentions for an arbitrary number of documents as input, retrieve a
of type [`Candidate`](/api/kb#candidate). list of candidate entities of type [`Candidate`](/api/kb#candidate) for each
mention. The [`EntityLinker`](/api/entitylinker) component passes a generator
> #### Example that yields mentions as [`SpanGroup`](/api/spangroup)s per document.
> The decision of how to batch
> ```python candidate retrieval lookups over multiple documents is left up to the
> from spacy.lang.en import English implementation of `KnowledgeBase.get_candidates()`.
> nlp = English()
> doc = nlp("Douglas Adams wrote 'The Hitchhiker's Guide to the Galaxy'.")
> candidates = kb.get_candidates(doc[0:2])
> ```
| Name | Description |
| ----------- | -------------------------------------------------------------------- |
| `mention` | The textual mention or alias. ~~Span~~ |
| **RETURNS** | An iterable of relevant `Candidate` objects. ~~Iterable[Candidate]~~ |
## KnowledgeBase.get_candidates_batch {id="get_candidates_batch",tag="method"}
Same as [`get_candidates()`](/api/kb#get_candidates), but for an arbitrary
number of mentions. The [`EntityLinker`](/api/entitylinker) component will call
`get_candidates_batch()` instead of `get_candidates()`, if the config parameter
`candidates_batch_size` is greater or equal than 1.
The default implementation of `get_candidates_batch()` executes
`get_candidates()` in a loop. We recommend implementing a more efficient way to
retrieve candidates for multiple mentions at once, if performance is of concern
to you.
> #### Example > #### Example
> >
@ -100,9 +79,9 @@ to you.
> ``` > ```
| Name | Description | | Name | Description |
| ----------- | -------------------------------------------------------------------------------------------- | | ----------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `mentions` | The textual mentions. ~~SpanGroup~~ | | `mentions` | The textual mentions or aliases (one `SpanGroup` per `Doc` instance). ~~Iterator[SpanGroup]~~ |
| **RETURNS** | An iterable of iterable with relevant `Candidate` objects. ~~Iterable[Iterable[Candidate]]~~ | | **RETURNS** | An iterator (per document) over iterables (per mention) of iterables (per candidate for this mention) with relevant `Candidate` objects. ~~Iterator[Iterable[Iterable[Candidate]]]~~ |
## KnowledgeBase.get_vector {id="get_vector",tag="method"} ## KnowledgeBase.get_vector {id="get_vector",tag="method"}