Modify EL batching to doc-wise streaming approach (#12367)

* Convert Candidate from Cython to Python class.

* Format.

* Fix .entity_ typo in _add_activations() usage.

* Change type for mentions to look up entity candidates for to SpanGroup from Iterable[Span].

* Update docs.

* Update spacy/kb/candidate.py

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Update doc string of BaseCandidate.__init__().

* Update spacy/kb/candidate.py

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Rename Candidate to InMemoryCandidate, BaseCandidate to Candidate.

* Adjust Candidate to support and mandate numerical entity IDs.

* Format.

* Fix docstring and docs.

* Update website/docs/api/kb.mdx

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Rename alias -> mention.

* Refactor Candidate attribute names. Update docs and tests accordingly.

* Refacor Candidate attributes and their usage.

* Format.

* Fix mypy error.

* Update error code in line with v4 convention.

* Modify EL batching system.

* Update leftover get_candidates() mention in docs.

* Format docs.

* Format.

* Update spacy/kb/candidate.py

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Updated error code.

* Simplify interface for int/str representations.

* Update website/docs/api/kb.mdx

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Rename 'alias' to 'mention'.

* Port Candidate and InMemoryCandidate to Cython.

* Remove redundant entry in setup.py.

* Add abstract class check.

* Drop storing mention.

* Update spacy/kb/candidate.pxd

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Fix entity_id refactoring problems in docstrings.

* Drop unused InMemoryCandidate._entity_hash.

* Update docstrings.

* Move attributes out of Candidate.

* Partially fix alias/mention terminology usage. Convert Candidate to interface.

* Remove prior_prob from supported properties in Candidate. Introduce KnowledgeBase.supports_prior_probs().

* Update docstrings related to prior_prob.

* Update alias/mention usage in doc(strings).

* Update spacy/ml/models/entity_linker.py

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Update spacy/ml/models/entity_linker.py

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Mention -> alias renaming. Drop Candidate.mentions(). Drop InMemoryLookupKB.get_alias_candidates() from docs.

* Update docstrings.

* Fix InMemoryCandidate attribute names.

* Update spacy/kb/kb.pyx

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Update spacy/ml/models/entity_linker.py

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Update W401 test.

* Update spacy/errors.py

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Update spacy/kb/kb.pyx

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Use Candidate output type for toy generators in the test suite to mimick best practices

* fix docs

* fix import

* Fix merge leftovers.

* Update spacy/kb/kb.pyx

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Update spacy/kb/kb.pyx

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Update website/docs/api/kb.mdx

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Update website/docs/api/entitylinker.mdx

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Update spacy/kb/kb_in_memory.pyx

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Update website/docs/api/inmemorylookupkb.mdx

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Update get_candidates() docstring.

* Reformat imports in entity_linker.py.

* Drop valid_ent_idx_per_doc.

* Update docs.

* Format.

* Simplify doc loop in predict().

* Remove E1044 comment.

* Fix merge errors.

* Format.

* Format.

* Format.

* Fix merge error & tests.

* Format.

* Apply suggestions from code review

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>

* Use type alias.

* isort.

* isort.

* Lint.

* Add typedefs.pyx.

* Fix typedef import.

* Fix type aliases.

* Format.

* Update docstring and type usage.

* Add info on get_candidates(), get_candidates_batched().

* Readd get_candidates info to v3 changelog.

* Update website/docs/api/entitylinker.mdx

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>

* Update factory functions for backwards compatibility.

* Format.

* Ignore mypy error.

* Fix mypy error.

* Format.

* Add test for multiple docs with multiple entities.

---------

Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
Co-authored-by: svlandeg <svlandeg@github.com>
This commit is contained in:
Raphael Mitsch 2024-04-09 11:39:18 +02:00 committed by GitHub
parent afb22ad491
commit 304b9331e6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 344 additions and 296 deletions

View File

@ -238,7 +238,7 @@ grad_factor = 1.0
{% if "entity_linker" in components -%}
[components.entity_linker]
factory = "entity_linker"
get_candidates = {"@misc":"spacy.CandidateGenerator.v1"}
get_candidates = {"@misc":"spacy.CandidateGenerator.v2"}
incl_context = true
incl_prior = true
@ -517,7 +517,7 @@ width = ${components.tok2vec.model.encode.width}
{% if "entity_linker" in components -%}
[components.entity_linker]
factory = "entity_linker"
get_candidates = {"@misc":"spacy.CandidateGenerator.v1"}
get_candidates = {"@misc":"spacy.CandidateGenerator.v2"}
incl_context = true
incl_prior = true

View File

@ -950,7 +950,6 @@ class Errors(metaclass=ErrorsWithCodes):
"case pass an empty list for the previously not specified argument to avoid this error.")
E1043 = ("Expected None or a value in range [{range_start}, {range_end}] for entity linker threshold, but got "
"{value}.")
E1044 = ("Expected `candidates_batch_size` to be >= 1, but got: {value}")
E1045 = ("Encountered {parent} subclass without `{parent}.{method}` "
"method in '{name}'. If you want to use this method, make "
"sure it's overwritten on the subclass.")

View File

@ -1,14 +1,14 @@
# cython: infer_types=True
from pathlib import Path
from typing import Iterable, Tuple, Union
from typing import Iterable, Iterator, Tuple, Union
from cymem.cymem cimport Pool
from ..errors import Errors
from ..tokens import Span, SpanGroup
from ..tokens import SpanGroup
from ..util import SimpleFrozenList
from .candidate import Candidate
from .candidate cimport Candidate
cdef class KnowledgeBase:
@ -19,6 +19,8 @@ cdef class KnowledgeBase:
DOCS: https://spacy.io/api/kb
"""
CandidatesForMentionT = Iterable[Candidate]
CandidatesForDocT = Iterable[CandidatesForMentionT]
def __init__(self, vocab: Vocab, entity_vector_length: int):
"""Create a KnowledgeBase."""
@ -32,27 +34,15 @@ cdef class KnowledgeBase:
self.entity_vector_length = entity_vector_length
self.mem = Pool()
def get_candidates_batch(
self, mentions: SpanGroup
) -> Iterable[Iterable[Candidate]]:
def get_candidates(self, mentions: Iterator[SpanGroup]) -> Iterator[CandidatesForDocT]:
"""
Return candidate entities for a specified Span mention. Each candidate defines at least the entity and the
entity's embedding vector. Depending on the KB implementation, further properties - such as the prior
probability of the specified mention text resolving to that entity - might be included.
Return candidate entities for the specified groups of mentions (as SpanGroup) per Doc.
Each candidate for a mention defines at least the entity and the entity's embedding vector. Depending on the KB
implementation, further properties - such as the prior probability of the specified mention text resolving to
that entity - might be included.
If no candidates are found for a given mention, an empty list is returned.
mentions (SpanGroup): Mentions for which to get candidates.
RETURNS (Iterable[Iterable[Candidate]]): Identified candidates.
"""
return [self.get_candidates(span) for span in mentions]
def get_candidates(self, mention: Span) -> Iterable[Candidate]:
"""
Return candidate entities for a specific mention. Each candidate defines at least the entity and the
entity's embedding vector. Depending on the KB implementation, further properties - such as the prior
probability of the specified mention text resolving to that entity - might be included.
If no candidate is found for the given mention, an empty list is returned.
mention (Span): Mention for which to get candidates.
RETURNS (Iterable[Candidate]): Identified candidates.
mentions (Iterator[SpanGroup]): Mentions for which to get candidates.
RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per mention/doc/doc batch.
"""
raise NotImplementedError(
Errors.E1045.format(

View File

@ -1,5 +1,5 @@
# cython: infer_types=True
from typing import Any, Callable, Dict, Iterable
from typing import Any, Callable, Dict, Iterable, Iterator
import srsly
@ -12,7 +12,7 @@ from preshed.maps cimport PreshMap
import warnings
from pathlib import Path
from ..tokens import Span
from ..tokens import SpanGroup
from ..typedefs cimport hash_t
@ -255,8 +255,9 @@ cdef class InMemoryLookupKB(KnowledgeBase):
alias_entry.probs = probs
self._aliases_table[alias_index] = alias_entry
def get_candidates(self, mention: Span) -> Iterable[InMemoryCandidate]:
return self._get_alias_candidates(mention.text) # type: ignore
def get_candidates(self, mentions: Iterator[SpanGroup]) -> Iterator[Iterable[Iterable[InMemoryCandidate]]]:
for mentions_for_doc in mentions:
yield [self._get_alias_candidates(span.text) for span in mentions_for_doc]
def _get_alias_candidates(self, str alias) -> Iterable[InMemoryCandidate]:
"""

View File

@ -1,5 +1,5 @@
from pathlib import Path
from typing import Callable, Iterable, List, Optional, Tuple
from typing import Callable, Iterable, Iterator, List, Optional, Tuple
from thinc.api import (
Linear,
@ -21,6 +21,9 @@ from ...util import registry
from ...vocab import Vocab
from ..extract_spans import extract_spans
CandidatesForMentionT = Iterable[Candidate]
CandidatesForDocT = Iterable[CandidatesForMentionT]
@registry.architectures("spacy.EntityLinker.v2")
def build_nel_encoder(
@ -117,34 +120,38 @@ def empty_kb(
@registry.misc("spacy.CandidateGenerator.v1")
def create_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]:
def create_get_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]:
return get_candidates
@registry.misc("spacy.CandidateBatchGenerator.v1")
def create_candidates_batch() -> Callable[
[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]
@registry.misc("spacy.CandidateGenerator.v2")
def create_get_candidates_v2() -> Callable[
[KnowledgeBase, Iterator[SpanGroup]], Iterator[CandidatesForDocT]
]:
return get_candidates_batch
return get_candidates_v2
def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]:
"""
Return candidate entities for a given mention and fetching appropriate entries from the index.
Return candidate entities for the given mention from the KB.
kb (KnowledgeBase): Knowledge base to query.
mention (Span): Entity mention for which to identify candidates.
RETURNS (Iterable[Candidate]): Identified candidates.
mention (Span): Entity mention.
RETURNS (Iterable[Candidate]): Identified candidates for specified mention.
"""
return kb.get_candidates(mention)
cands_per_doc = next(
get_candidates_v2(kb, iter([SpanGroup(mention.doc, spans=[mention])]))
)
assert isinstance(cands_per_doc, list)
return next(cands_per_doc[0])
def get_candidates_batch(
kb: KnowledgeBase, mentions: SpanGroup
) -> Iterable[Iterable[Candidate]]:
def get_candidates_v2(
kb: KnowledgeBase, mentions: Iterator[SpanGroup]
) -> Iterator[Iterable[Iterable[Candidate]]]:
"""
Return candidate entities for the given mentions and fetching appropriate entries from the index.
Return candidate entities for the given mentions from the KB.
kb (KnowledgeBase): Knowledge base to query.
mentions (SpanGroup): Entity mentions for which to identify candidates.
RETURNS (Iterable[Iterable[Candidate]]): Identified candidates.
mentions (Iterator[SpanGroup]): Mentions per doc.
RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per mentions in document/SpanGroup.
"""
return kb.get_candidates_batch(mentions)
return kb.get_candidates(mentions)

View File

@ -1,8 +1,19 @@
import random
import warnings
from itertools import islice
from itertools import islice, tee
from pathlib import Path
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Union, cast
from typing import (
Any,
Callable,
Dict,
Iterable,
Iterator,
List,
Optional,
Sequence,
Union,
cast,
)
import srsly
from numpy import dtype
@ -54,13 +65,11 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
"incl_prior": True,
"incl_context": True,
"entity_vector_length": 64,
"get_candidates": {"@misc": "spacy.CandidateGenerator.v1"},
"get_candidates_batch": {"@misc": "spacy.CandidateBatchGenerator.v1"},
"get_candidates": {"@misc": "spacy.CandidateGenerator.v2"},
"overwrite": False,
"generate_empty_kb": {"@misc": "spacy.EmptyKB.v2"},
"scorer": {"@scorers": "spacy.entity_linker_scorer.v1"},
"use_gold_ents": True,
"candidates_batch_size": 1,
"threshold": None,
"save_activations": False,
},
@ -80,15 +89,13 @@ def make_entity_linker(
incl_prior: bool,
incl_context: bool,
entity_vector_length: int,
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
get_candidates_batch: Callable[
[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]
get_candidates: Callable[
[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]
],
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
overwrite: bool,
scorer: Optional[Callable],
use_gold_ents: bool,
candidates_batch_size: int,
threshold: Optional[float] = None,
save_activations: bool,
):
@ -102,16 +109,13 @@ def make_entity_linker(
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
incl_context (bool): Whether or not to include the local context in the model.
entity_vector_length (int): Size of encoding vectors in the KB.
get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that
produces a list of candidates, given a certain knowledge base and a textual mention.
get_candidates_batch (
Callable[[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]], Iterable[Candidate]]
): Function that produces a list of candidates, given a certain knowledge base and several textual mentions.
get_candidates (Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]]):
Function producing a list of candidates per document, given a certain knowledge base and several textual
documents with textual mentions.
generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
scorer (Optional[Callable]): The scoring method.
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
component must provide entity annotations.
candidates_batch_size (int): Size of batches for entity candidate generation.
threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the threshold,
prediction is discarded. If None, predictions are not filtered by any threshold.
save_activations (bool): save model activations in Doc when annotating.
@ -129,12 +133,10 @@ def make_entity_linker(
incl_context=incl_context,
entity_vector_length=entity_vector_length,
get_candidates=get_candidates,
get_candidates_batch=get_candidates_batch,
generate_empty_kb=generate_empty_kb,
overwrite=overwrite,
scorer=scorer,
use_gold_ents=use_gold_ents,
candidates_batch_size=candidates_batch_size,
threshold=threshold,
save_activations=save_activations,
)
@ -168,15 +170,14 @@ class EntityLinker(TrainablePipe):
incl_prior: bool,
incl_context: bool,
entity_vector_length: int,
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
get_candidates_batch: Callable[
[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]
get_candidates: Callable[
[KnowledgeBase, Iterator[SpanGroup]],
Iterator[Iterable[Iterable[Candidate]]],
],
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
overwrite: bool = False,
scorer: Optional[Callable] = entity_linker_score,
use_gold_ents: bool,
candidates_batch_size: int,
threshold: Optional[float] = None,
save_activations: bool = False,
) -> None:
@ -191,18 +192,14 @@ class EntityLinker(TrainablePipe):
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
incl_context (bool): Whether or not to include the local context in the model.
entity_vector_length (int): Size of encoding vectors in the KB.
get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that
produces a list of candidates, given a certain knowledge base and a textual mention.
get_candidates_batch (
Callable[[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]],
Iterable[Candidate]]
): Function that produces a list of candidates, given a certain knowledge base and several textual mentions.
get_candidates (Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]]):
Function producing a list of candidates per document, given a certain knowledge base and several textual
documents with textual mentions.
generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
overwrite (bool): Whether to overwrite existing non-empty annotations.
scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links.
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
component must provide entity annotations.
candidates_batch_size (int): Size of batches for entity candidate generation.
threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the
threshold, prediction is discarded. If None, predictions are not filtered by any threshold.
save_activations (bool): save model activations in Doc when annotating.
@ -227,18 +224,14 @@ class EntityLinker(TrainablePipe):
self.incl_prior = incl_prior
self.incl_context = incl_context
self.get_candidates = get_candidates
self.get_candidates_batch = get_candidates_batch
self.cfg: Dict[str, Any] = {"overwrite": overwrite}
self.distance = CosineDistance(normalize=False)
self.kb = generate_empty_kb(self.vocab, entity_vector_length)
self.scorer = scorer
self.use_gold_ents = use_gold_ents
self.candidates_batch_size = candidates_batch_size
self.threshold = threshold
self.save_activations = save_activations
if candidates_batch_size < 1:
raise ValueError(Errors.E1044)
if self.incl_prior and not self.kb.supports_prior_probs:
warnings.warn(Warnings.W401)
@ -318,11 +311,12 @@ class EntityLinker(TrainablePipe):
If one isn't present, then the update step needs to be skipped.
"""
for eg in examples:
for ent in eg.predicted.ents:
candidates = list(self.get_candidates(self.kb, ent))
if candidates:
for candidates_for_doc in self.get_candidates(
self.kb,
(SpanGroup(doc=eg.predicted, spans=eg.predicted.ents) for eg in examples),
):
for candidates_for_mention in candidates_for_doc:
if list(candidates_for_mention):
return True
return False
@ -451,40 +445,35 @@ class EntityLinker(TrainablePipe):
}
if isinstance(docs, Doc):
docs = [docs]
for doc in docs:
docs_iters = tee(docs, 2)
# Call candidate generator.
all_ent_cands = self.get_candidates(
self.kb,
(
SpanGroup(
doc,
spans=[
ent for ent in doc.ents if ent.label_ not in self.labels_discard
],
)
for doc in docs_iters[0]
),
)
for doc in docs_iters[1]:
doc_ents: List[Ints1d] = []
doc_scores: List[Floats1d] = []
if len(doc) == 0:
if len(doc) == 0 or len(doc.ents) == 0:
docs_scores.append(Ragged(ops.alloc1f(0), ops.alloc1i(0)))
docs_ents.append(Ragged(xp.zeros(0, dtype="uint64"), ops.alloc1i(0)))
continue
sentences = [s for s in doc.sents]
doc_ent_cands = list(next(all_ent_cands))
# Loop over entities in batches.
for ent_idx in range(0, len(doc.ents), self.candidates_batch_size):
ent_batch = doc.ents[ent_idx : ent_idx + self.candidates_batch_size]
# Look up candidate entities.
valid_ent_idx = [
idx
for idx in range(len(ent_batch))
if ent_batch[idx].label_ not in self.labels_discard
]
batch_candidates = list(
self.get_candidates_batch(
self.kb,
SpanGroup(doc, spans=[ent_batch[idx] for idx in valid_ent_idx]),
)
if self.candidates_batch_size > 1
else [
self.get_candidates(self.kb, ent_batch[idx])
for idx in valid_ent_idx
]
)
# Looping through each entity in batch (TODO: rewrite)
for j, ent in enumerate(ent_batch):
# Looping over candidate entities for this doc. (TODO: rewrite)
for ent_cand_idx, ent in enumerate(doc.ents):
assert hasattr(ent, "sents")
sents = list(ent.sents)
sent_indices = (
@ -502,7 +491,6 @@ class EntityLinker(TrainablePipe):
start_token = sentences[start_sentence].start
end_token = sentences[end_sentence].end
sent_doc = doc[start_token:end_token].as_doc()
# currently, the context is the same for each entity in a sentence (should be refined)
sentence_encoding = self.model.predict([sent_doc])[0]
sentence_encoding_t = sentence_encoding.T
@ -518,7 +506,7 @@ class EntityLinker(TrainablePipe):
ents=[0],
)
else:
candidates = list(batch_candidates[j])
candidates = list(doc_ent_cands[ent_cand_idx])
if not candidates:
# no prediction possible for this entity - setting to NIL
final_kb_ids.append(self.NIL)
@ -540,11 +528,12 @@ class EntityLinker(TrainablePipe):
else:
random.shuffle(candidates)
# set all prior probabilities to 0 if incl_prior=False
if self.incl_prior and self.kb.supports_prior_probs:
prior_probs = xp.asarray([c.prior_prob for c in candidates]) # type: ignore
else:
prior_probs = xp.asarray([0.0 for _ in candidates])
scores = prior_probs
scores = prior_probs = xp.asarray(
[
c.prior_prob if self.incl_prior else 0.0
for c in candidates
]
)
# add in similarity from the context
if self.incl_context:
entity_encodings = xp.asarray(
@ -567,8 +556,7 @@ class EntityLinker(TrainablePipe):
scores = prior_probs + sims - (prior_probs * sims)
final_kb_ids.append(
candidates[scores.argmax().item()].entity_id_
if self.threshold is None
or scores.max() >= self.threshold
if self.threshold is None or scores.max() >= self.threshold
else EntityLinker.NIL
)
self._add_activations(
@ -577,6 +565,7 @@ class EntityLinker(TrainablePipe):
scores=scores,
ents=[c.entity_id for c in candidates],
)
self._add_doc_activations(
docs_scores=docs_scores,
docs_ents=docs_ents,
@ -588,6 +577,7 @@ class EntityLinker(TrainablePipe):
method="predict", msg="result variables not of equal length"
)
raise RuntimeError(err)
return {
KNOWLEDGE_BASE_IDS: final_kb_ids,
"ents": docs_ents,

View File

@ -1,21 +1,21 @@
from typing import Any, Callable, Dict, Iterable, cast
from typing import Any, Callable, Dict, Iterable, Iterator, cast
import pytest
from numpy.testing import assert_equal
from thinc.types import Ragged
from spacy import Language, registry, util
from spacy import registry, util
from spacy.attrs import ENT_KB_ID
from spacy.compat import pickle
from spacy.kb import Candidate, InMemoryLookupKB, KnowledgeBase
from spacy.lang.en import English
from spacy.ml import load_kb
from spacy.ml.models.entity_linker import build_span_maker, get_candidates
from spacy.ml.models.entity_linker import build_span_maker, get_candidates_v2
from spacy.pipeline import EntityLinker, TrainablePipe
from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
from spacy.scorer import Scorer
from spacy.tests.util import make_tempdir
from spacy.tokens import Doc, Span
from spacy.tokens import Doc, Span, SpanGroup
from spacy.training import Example
from spacy.util import ensure_path
from spacy.vocab import Vocab
@ -453,11 +453,21 @@ def test_candidate_generation(nlp):
mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
# test the size of the relevant candidates
adam_ent_cands = get_candidates(mykb, adam_ent)
assert len(get_candidates(mykb, douglas_ent)) == 2
adam_ent_cands = next(
get_candidates_v2(mykb, SpanGroup(doc=doc, spans=[adam_ent]))
)[0]
assert len(adam_ent_cands) == 1
assert len(get_candidates(mykb, Adam_ent)) == 0 # default case sensitive
assert len(get_candidates(mykb, shrubbery_ent)) == 0
assert (
len(next(get_candidates_v2(mykb, SpanGroup(doc=doc, spans=[douglas_ent])))[0])
== 2
)
assert (
len(next(get_candidates_v2(mykb, SpanGroup(doc=doc, spans=[Adam_ent])))[0]) == 0
) # default case sensitive
assert (
len(next(get_candidates_v2(mykb, SpanGroup(doc=doc, spans=[shrubbery_ent])))[0])
== 0
)
# test the content of the candidates
assert adam_ent_cands[0].entity_id_ == "Q2"
@ -466,6 +476,86 @@ def test_candidate_generation(nlp):
assert_almost_equal(adam_ent_cands[0].prior_prob, 0.9)
def test_candidate_generation_multiple_docs(nlp):
"""Test correct candidate generation with multiple docs."""
mykb = InMemoryLookupKB(nlp.vocab, entity_vector_length=1)
docs = [nlp("douglas adam Adam shrubbery"), nlp("shrubbery Adam douglas adam")]
douglas_ents = [docs[0][0:1], docs[1][2:3]]
adam_ents = [docs[0][1:2], docs[1][3:4]]
Adam_ents = [docs[0][2:3], docs[1][1:2]]
shrubbery_ents = [docs[0][3:4], docs[1][0:1]]
# adding entities
mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
mykb.add_entity(entity="Q2", freq=12, entity_vector=[2])
mykb.add_entity(entity="Q3", freq=5, entity_vector=[3])
# adding aliases
mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1])
mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
# test the size of the relevant candidates
adam_ent_cands = list(
get_candidates_v2(
mykb,
[
SpanGroup(doc=docs[0], spans=[adam_ents[0]]),
SpanGroup(doc=docs[1], spans=[adam_ents[1]]),
],
)
)
assert len(adam_ent_cands) == 2
assert (
len(
list(
get_candidates_v2(
mykb,
[
SpanGroup(doc=docs[0], spans=[douglas_ents[0]]),
SpanGroup(doc=docs[1], spans=[douglas_ents[1]]),
],
)
)
)
== 2
)
Adam_ent_cands = list(
get_candidates_v2(
mykb,
[
SpanGroup(doc=docs[0], spans=[Adam_ents[0]]),
SpanGroup(doc=docs[1], spans=[Adam_ents[1]]),
],
)
)
assert len(Adam_ent_cands) == 2
assert (
len(Adam_ent_cands[0][0]) == 0 and len(Adam_ent_cands[1][0]) == 0
) # default case sensitive
shrubbery_ents_cands = list(
get_candidates_v2(
mykb,
[
SpanGroup(doc=docs[0], spans=[shrubbery_ents[0]]),
SpanGroup(doc=docs[1], spans=[shrubbery_ents[1]]),
],
)
)
assert len(shrubbery_ents_cands) == 2
assert len(shrubbery_ents_cands[0][0]) == 0 and len(shrubbery_ents_cands[1][0]) == 0
# test the content of the candidates
assert (
adam_ent_cands[0][0][0].entity_id_ == adam_ent_cands[1][0][0].entity_id_ == "Q2"
)
assert adam_ent_cands[0][0][0].alias == adam_ent_cands[1][0][0].alias == "adam"
assert_almost_equal(adam_ent_cands[0][0][0].entity_freq, 12)
assert_almost_equal(adam_ent_cands[1][0][0].entity_freq, 12)
assert_almost_equal(adam_ent_cands[0][0][0].prior_prob, 0.9)
assert_almost_equal(adam_ent_cands[1][0][0].prior_prob, 0.9)
def test_el_pipe_configuration(nlp):
"""Test correct candidate generation as part of the EL pipe"""
nlp.add_pipe("sentencizer")
@ -490,24 +580,20 @@ def test_el_pipe_configuration(nlp):
assert doc[1].ent_kb_id_ == ""
assert doc[2].ent_kb_id_ == "Q2"
def get_lowercased_candidates(kb, span):
return kb._get_alias_candidates(span.text.lower())
def get_lowercased_candidates_batch(kb, spans):
return [get_lowercased_candidates(kb, span) for span in spans]
def get_lowercased_candidates(kb: InMemoryLookupKB, mentions: Iterator[SpanGroup]):
for mentions_for_doc in mentions:
yield [
kb._get_alias_candidates(ent_span.text.lower())
for ent_span in mentions_for_doc
]
@registry.misc("spacy.LowercaseCandidateGenerator.v1")
def create_candidates() -> (
Callable[[InMemoryLookupKB, "Span"], Iterable[Candidate]]
):
def create_candidates() -> Callable[
[InMemoryLookupKB, Iterator[SpanGroup]],
Iterator[Iterable[Iterable[Candidate]]],
]:
return get_lowercased_candidates
@registry.misc("spacy.LowercaseCandidateBatchGenerator.v1")
def create_candidates_batch() -> (
Callable[[InMemoryLookupKB, Iterable["Span"]], Iterable[Iterable[Candidate]]]
):
return get_lowercased_candidates_batch
# replace the pipe with a new one with with a different candidate generator
entity_linker = nlp.replace_pipe(
"entity_linker",
@ -515,9 +601,6 @@ def test_el_pipe_configuration(nlp):
config={
"incl_context": False,
"get_candidates": {"@misc": "spacy.LowercaseCandidateGenerator.v1"},
"get_candidates_batch": {
"@misc": "spacy.LowercaseCandidateBatchGenerator.v1"
},
},
)
entity_linker.set_kb(create_kb)

View File

@ -1255,6 +1255,15 @@ A function that reads an existing `KnowledgeBase` from file.
| --------- | -------------------------------------------------------- |
| `kb_path` | The location of the KB that was stored to file. ~~Path~~ |
### spacy.CandidateGenerator.v2 {id="CandidateGenerator-v2"}
A function that takes as input a [`KnowledgeBase`](/api/kb) and a
`Iterator[SpanGroup]` object denoting a collection of named entities for
multiple [`Doc`](/api/doc), and returns an iterable of plausible
[`Candidate`](/api/kb/#candidate) objects per `Doc`. The default
`CandidateGenerator` uses the text of a mention to find its potential aliases in
the `KnowledgeBase`. Note that this function is case-dependent.
### spacy.CandidateGenerator.v1 {id="CandidateGenerator"}
A function that takes as input a [`KnowledgeBase`](/api/kb) and a

View File

@ -47,29 +47,42 @@ architectures and their arguments and hyperparameters.
> "incl_context": True,
> "model": DEFAULT_NEL_MODEL,
> "entity_vector_length": 64,
> "get_candidates": {'@misc': 'spacy.CandidateGenerator.v1'},
> "get_candidates": {'@misc': 'spacy.CandidateGenerator.v2'},
> "threshold": None,
> }
> nlp.add_pipe("entity_linker", config=config)
> ```
| Setting | Description |
| --------------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `labels_discard` | NER labels that will automatically get a "NIL" prediction. Defaults to `[]`. ~~Iterable[str]~~ |
| `n_sents` | The number of neighbouring sentences to take into account. Defaults to 0. ~~int~~ |
| `incl_prior` | Whether or not to include prior probabilities from the KB in the model. Defaults to `True`. ~~bool~~ |
| `incl_context` | Whether or not to include the local context in the model. Defaults to `True`. ~~bool~~ |
| `model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. Defaults to [EntityLinker](/api/architectures#EntityLinker). ~~Model~~ |
| ------------------------------------------------ | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `labels_discard` | NER labels that will automatically get an "NIL" prediction. Defaults to `[]`. ~~Iterable[str]~~ |
| `n_sents` | The number of neighbouring sentences to take into account. Defaults to `0`. ~~int~~ |
| `incl_prior` | Whether prior probabilities from the KB are included in the model. Defaults to `True`. ~~bool~~ |
| `incl_context` | Whether the local context is included in the model. Defaults to `True`. ~~bool~~ |
| `model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. Defaults to [`EntityLinker`](/api/architectures#EntityLinker). ~~Model~~ |
| `entity_vector_length` | Size of encoding vectors in the KB. Defaults to `64`. ~~int~~ |
| `use_gold_ents` | Whether to copy entities from the gold docs or not. Defaults to `True`. If `False`, entities must be set in the training data or by an annotating component in the pipeline. ~~int~~ |
| `get_candidates` | Function that generates plausible candidates for a given `Span` object. Defaults to [CandidateGenerator](/api/architectures#CandidateGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Span], Iterable[Candidate]]~~ |
| `get_candidates_batch` <Tag variant="new">3.5</Tag> | Function that generates plausible candidates for a given batch of `Span` objects. Defaults to [CandidateBatchGenerator](/api/architectures#CandidateBatchGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]]~~ |
| `generate_empty_kb` <Tag variant="new">3.5.1</Tag> | Function that generates an empty `KnowledgeBase` object. Defaults to [`spacy.EmptyKB.v2`](/api/architectures#EmptyKB), which generates an empty [`InMemoryLookupKB`](/api/inmemorylookupkb). ~~Callable[[Vocab, int], KnowledgeBase]~~ |
| `use_gold_ents` | Whether entities are copied from the gold docs. Defaults to `True`. If `False`, entities must be set in the training data or by an annotating component in the pipeline. ~~int~~ |
| `get_candidates` <Tag variant="new">4.0</Tag> | Function that retrieves plausible candidates per entity mention in a given `Iterator[SpanGroup]` (one `SpanGroup` includes all mentions found in a given `Doc` instance). Defaults to [CandidateGenerator](/api/architectures#CandidateGenerator). ~~Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]]~~ |
| `generate_empty_kb` <Tag variant="new">3.6</Tag> | Function that generates an empty `KnowledgeBase` object. Defaults to [`spacy.EmptyKB.v2`](/api/architectures#EmptyKB), which generates an empty [`InMemoryLookupKB`](/api/inmemorylookupkb). ~~Callable[[Vocab, int], KnowledgeBase]~~ |
| `overwrite` <Tag variant="new">3.2</Tag> | Whether existing annotation is overwritten. Defaults to `True`. ~~bool~~ |
| `scorer` <Tag variant="new">3.2</Tag> | The scoring method. Defaults to [`Scorer.score_links`](/api/scorer#score_links). ~~Optional[Callable]~~ |
| `save_activations` <Tag variant="new">4.0</Tag> | Save activations in `Doc` when annotating. Saved activations are `"ents"` and `"scores"`. ~~Union[bool, list[str]]~~ |
| `threshold` <Tag variant="new">3.4</Tag> | Confidence threshold for entity predictions. The default of `None` implies that all predictions are accepted, otherwise those with a score beneath the treshold are discarded. If there are no predictions with scores above the threshold, the linked entity is `NIL`. ~~Optional[float]~~ |
<Infobox variant="warning">
Prior to spaCy v4.0 `get_candidates()` returns a single `Iterable` of candidates
for one specific mention, i. e. the function was typed as
`Callable[[KnowledgeBase, Span], Iterable[Candidate]]`. To retrieve candidates
batch-wise, spaCy >= 3.5 exposes `get_candidates_batched()`, which identifies
candidates for an arbitrary number of spans:
`Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]]`. The
main difference between `get_candidates_batched()` and `get_candidates()` in
spaCy >= 4.0 is that the latter considers the grouping of provided mention spans
per `Doc` instance.
</Infobox>
```python
%%GITHUB_SPACY/spacy/pipeline/entity_linker.py
```
@ -108,7 +121,7 @@ custom knowledge base, you should either call
| `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ |
| _keyword-only_ | |
| `entity_vector_length` | Size of encoding vectors in the KB. ~~int~~ |
| `get_candidates` | Function that generates plausible candidates for a given `Span` object. ~~Callable[[KnowledgeBase, Span], Iterable[Candidate]]~~ |
| `get_candidates` | Function that retrieves plausible candidates per entity mention in a given `SpanGroup`. Defaults to [CandidateGenerator](/api/architectures#CandidateGenerator). ~~Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]]~~ |
| `labels_discard` | NER labels that will automatically get a `"NIL"` prediction. ~~Iterable[str]~~ |
| `n_sents` | The number of neighbouring sentences to take into account. ~~int~~ |
| `incl_prior` | Whether or not to include prior probabilities from the KB in the model. ~~bool~~ |

View File

@ -155,35 +155,12 @@ Get a list of all aliases in the knowledge base.
## InMemoryLookupKB.get_candidates {id="get_candidates",tag="method"}
Given a certain textual mention as input, retrieve a list of candidate entities
of type [`InMemoryCandidate`](/api/kb#candidate). Wraps
[`get_alias_candidates()`](/api/inmemorylookupkb#get_alias_candidates).
> #### Example
>
> ```python
> from spacy.lang.en import English
> nlp = English()
> doc = nlp("Douglas Adams wrote 'The Hitchhiker's Guide to the Galaxy'.")
> candidates = kb.get_candidates(doc[0:2])
> ```
| Name | Description |
| ----------- | ------------------------------------------------------------------------------------ |
| `mention` | The textual mention or alias. ~~Span~~ |
| **RETURNS** | An iterable of relevant `InMemoryCandidate` objects. ~~Iterable[InMemoryCandidate]~~ |
## InMemoryLookupKB.get_candidates_batch {id="get_candidates_batch",tag="method"}
Same as [`get_candidates()`](/api/inmemorylookupkb#get_candidates), but for an
arbitrary number of mentions. The [`EntityLinker`](/api/entitylinker) component
will call `get_candidates_batch()` instead of `get_candidates()`, if the config
parameter `candidates_batch_size` is greater or equal than 1.
The default implementation of `get_candidates_batch()` executes
`get_candidates()` in a loop. We recommend implementing a more efficient way to
retrieve candidates for multiple mentions at once, if performance is of concern
to you.
Given textual mentions for an arbitrary number of documents as input, retrieve a
list of candidate entities of type [`InMemoryCandidate`](/api/kb#candidate) for
each mention. The [`EntityLinker`](/api/entitylinker) component passes a
generator that yields mentions as [`SpanGroup`](/api/spangroup))s per document. The decision of how to batch
candidate retrieval lookups over multiple documents is left up to the
implementation of `KnowledgeBase.get_candidates()`.
> #### Example
>
@ -192,13 +169,13 @@ to you.
> from spacy.tokens import SpanGroup
> nlp = English()
> doc = nlp("Douglas Adams wrote 'The Hitchhiker's Guide to the Galaxy'.")
> candidates = kb.get_candidates_batch([SpanGroup(doc, spans=[doc[0:2], doc[3:]]])
> candidates = kb.get_candidates([SpanGroup(doc, spans=[doc[0:2], doc[3:]]])
> ```
| Name | Description |
| ----------- | ------------------------------------------------------------------------------------------------------------ |
| `mentions` | The textual mentions. ~~SpanGroup~~ |
| **RETURNS** | An iterable of iterable with relevant `InMemoryCandidate` objects. ~~Iterable[Iterable[InMemoryCandidate]]~~ |
| ----------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `mentions` | The textual mentions or aliases (one `SpanGroup` per `Doc` instance). ~~Iterator[SpanGroup]~~ |
| **RETURNS** | An iterator over iterables of iterables with relevant [`InMemoryCandidate`](/api/kb#candidate) objects (per mention and doc). ~~Iterator[Iterable[Iterable[InMemoryCandidate]]]~~ |
## InMemoryLookupKB.get_vector {id="get_vector",tag="method"}

View File

@ -60,34 +60,13 @@ The length of the fixed-size entity vectors in the knowledge base.
## KnowledgeBase.get_candidates {id="get_candidates",tag="method"}
Given a certain textual mention as input, retrieve a list of candidate entities
of type [`Candidate`](/api/kb#candidate).
> #### Example
>
> ```python
> from spacy.lang.en import English
> nlp = English()
> doc = nlp("Douglas Adams wrote 'The Hitchhiker's Guide to the Galaxy'.")
> candidates = kb.get_candidates(doc[0:2])
> ```
| Name | Description |
| ----------- | -------------------------------------------------------------------- |
| `mention` | The textual mention or alias. ~~Span~~ |
| **RETURNS** | An iterable of relevant `Candidate` objects. ~~Iterable[Candidate]~~ |
## KnowledgeBase.get_candidates_batch {id="get_candidates_batch",tag="method"}
Same as [`get_candidates()`](/api/kb#get_candidates), but for an arbitrary
number of mentions. The [`EntityLinker`](/api/entitylinker) component will call
`get_candidates_batch()` instead of `get_candidates()`, if the config parameter
`candidates_batch_size` is greater or equal than 1.
The default implementation of `get_candidates_batch()` executes
`get_candidates()` in a loop. We recommend implementing a more efficient way to
retrieve candidates for multiple mentions at once, if performance is of concern
to you.
Given textual mentions for an arbitrary number of documents as input, retrieve a
list of candidate entities of type [`Candidate`](/api/kb#candidate) for each
mention. The [`EntityLinker`](/api/entitylinker) component passes a generator
that yields mentions as [`SpanGroup`](/api/spangroup))s per document.
The decision of how to batch
candidate retrieval lookups over multiple documents is left up to the
implementation of `KnowledgeBase.get_candidates()`.
> #### Example
>
@ -100,9 +79,9 @@ to you.
> ```
| Name | Description |
| ----------- | -------------------------------------------------------------------------------------------- |
| `mentions` | The textual mentions. ~~SpanGroup~~ |
| **RETURNS** | An iterable of iterable with relevant `Candidate` objects. ~~Iterable[Iterable[Candidate]]~~ |
| ----------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
| `mentions` | The textual mentions or aliases (one `SpanGroup` per `Doc` instance). ~~Iterator[SpanGroup]~~ |
| **RETURNS** | An iterator (per document) over iterables (per mention) of iterables (per candidate for this mention) with relevant `Candidate` objects. ~~Iterator[Iterable[Iterable[Candidate]]]~~ |
## KnowledgeBase.get_vector {id="get_vector",tag="method"}