Convert batched into doc-wise batched candidate generation.

This commit is contained in:
Raphael Mitsch 2022-10-18 15:31:15 +02:00
parent 2ce6aadda2
commit 7c28424f47
7 changed files with 182 additions and 155 deletions

View File

@ -946,11 +946,10 @@ class Errors(metaclass=ErrorsWithCodes):
"case pass an empty list for the previously not specified argument to avoid this error.") "case pass an empty list for the previously not specified argument to avoid this error.")
E1043 = ("Expected None or a value in range [{range_start}, {range_end}] for entity linker threshold, but got " E1043 = ("Expected None or a value in range [{range_start}, {range_end}] for entity linker threshold, but got "
"{value}.") "{value}.")
E1044 = ("Expected `candidates_batch_size` to be >= 1, but got: {value}") E1044 = ("Encountered {parent} subclass without `{parent}.{method}` "
E1045 = ("Encountered {parent} subclass without `{parent}.{method}` "
"method in '{name}'. If you want to use this method, make " "method in '{name}'. If you want to use this method, make "
"sure it's overwritten on the subclass.") "sure it's overwritten on the subclass.")
E1046 = ("{cls_name} is an abstract class and cannot be instantiated. If you are looking for spaCy's default " E1045 = ("{cls_name} is an abstract class and cannot be instantiated. If you are looking for spaCy's default "
"knowledge base, use `InMemoryLookupKB`.") "knowledge base, use `InMemoryLookupKB`.")

View File

@ -1,3 +1,3 @@
from .kb import KnowledgeBase from .kb import KnowledgeBase
from .kb_in_memory import InMemoryLookupKB from .kb_in_memory import InMemoryLookupKB
from .candidate import Candidate, get_candidates, get_candidates_batch from .candidate import Candidate, get_candidates, get_candidates_all

View File

@ -1,6 +1,6 @@
# cython: infer_types=True, profile=True # cython: infer_types=True, profile=True
from typing import Iterable from typing import Iterable, Generator
from .kb cimport KnowledgeBase from .kb cimport KnowledgeBase
from ..tokens import Span from ..tokens import Span
@ -64,11 +64,13 @@ def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]:
return kb.get_candidates(mention) return kb.get_candidates(mention)
def get_candidates_batch(kb: KnowledgeBase, mentions: Iterable[Span]) -> Iterable[Iterable[Candidate]]: def get_candidates_all(
kb: KnowledgeBase, mentions: Generator[Iterable[Span], None, None]
) -> Generator[Iterable[Iterable[Candidate]], None, None]:
""" """
Return candidate entities for the given mentions and fetching appropriate entries from the index. Return candidate entities for the given mentions and fetching appropriate entries from the index.
kb (KnowledgeBase): Knowledge base to query. kb (KnowledgeBase): Knowledge base to query.
mention (Iterable[Span]): Entity mentions for which to identify candidates. mention (Generator[Iterable[Span]]): Entity mentions per document for which to identify candidates.
RETURNS (Iterable[Iterable[Candidate]]): Identified candidates. RETURNS (Generator[Iterable[Iterable[Candidate]]]): Identified candidates per document.
""" """
return kb.get_candidates_batch(mentions) return kb.get_candidates_all(mentions)

View File

@ -1,7 +1,7 @@
# cython: infer_types=True, profile=True # cython: infer_types=True, profile=True
from pathlib import Path from pathlib import Path
from typing import Iterable, Tuple, Union from typing import Iterable, Tuple, Union, Generator
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
from .candidate import Candidate from .candidate import Candidate
@ -23,22 +23,24 @@ cdef class KnowledgeBase:
# Make sure abstract KB is not instantiated. # Make sure abstract KB is not instantiated.
if self.__class__ == KnowledgeBase: if self.__class__ == KnowledgeBase:
raise TypeError( raise TypeError(
Errors.E1046.format(cls_name=self.__class__.__name__) Errors.E1045.format(cls_name=self.__class__.__name__)
) )
self.vocab = vocab self.vocab = vocab
self.entity_vector_length = entity_vector_length self.entity_vector_length = entity_vector_length
self.mem = Pool() self.mem = Pool()
def get_candidates_batch(self, mentions: Iterable[Span]) -> Iterable[Iterable[Candidate]]: def get_candidates_all(self, mentions: Generator[Iterable[Span]]) -> Generator[Iterable[Iterable[Candidate]]]:
""" """
Return candidate entities for specified texts. Each candidate defines the entity, the original alias, Return candidate entities for specified texts. Each candidate defines the entity, the original alias,
and the prior probability of that alias resolving to that entity. and the prior probability of that alias resolving to that entity.
If no candidate is found for a given text, an empty list is returned. If no candidate is found for a given text, an empty list is returned.
mentions (Iterable[Span]): Mentions for which to get candidates. mentions (Generator[Iterable[Span]]): Mentions per documents for which to get candidates.
RETURNS (Iterable[Iterable[Candidate]]): Identified candidates. RETURNS (Generator[Iterable[Iterable[Candidate]]]): Identified candidates per document.
""" """
return [self.get_candidates(span) for span in mentions]
for doc_mentions in mentions:
yield [self.get_candidates(span) for span in doc_mentions]
def get_candidates(self, mention: Span) -> Iterable[Candidate]: def get_candidates(self, mention: Span) -> Iterable[Candidate]:
""" """
@ -49,7 +51,7 @@ cdef class KnowledgeBase:
RETURNS (Iterable[Candidate]): Identified candidates. RETURNS (Iterable[Candidate]): Identified candidates.
""" """
raise NotImplementedError( raise NotImplementedError(
Errors.E1045.format(parent="KnowledgeBase", method="get_candidates", name=self.__name__) Errors.E1044.format(parent="KnowledgeBase", method="get_candidates", name=self.__name__)
) )
def get_vectors(self, entities: Iterable[str]) -> Iterable[Iterable[float]]: def get_vectors(self, entities: Iterable[str]) -> Iterable[Iterable[float]]:
@ -67,7 +69,7 @@ cdef class KnowledgeBase:
RETURNS (Iterable[float]): Vector for specified entity. RETURNS (Iterable[float]): Vector for specified entity.
""" """
raise NotImplementedError( raise NotImplementedError(
Errors.E1045.format(parent="KnowledgeBase", method="get_vector", name=self.__name__) Errors.E1044.format(parent="KnowledgeBase", method="get_vector", name=self.__name__)
) )
def to_bytes(self, **kwargs) -> bytes: def to_bytes(self, **kwargs) -> bytes:
@ -75,7 +77,7 @@ cdef class KnowledgeBase:
RETURNS (bytes): Current state as binary string. RETURNS (bytes): Current state as binary string.
""" """
raise NotImplementedError( raise NotImplementedError(
Errors.E1045.format(parent="KnowledgeBase", method="to_bytes", name=self.__name__) Errors.E1044.format(parent="KnowledgeBase", method="to_bytes", name=self.__name__)
) )
def from_bytes(self, bytes_data: bytes, *, exclude: Tuple[str] = tuple()): def from_bytes(self, bytes_data: bytes, *, exclude: Tuple[str] = tuple()):
@ -84,7 +86,7 @@ cdef class KnowledgeBase:
exclude (Tuple[str]): Properties to exclude when restoring KB. exclude (Tuple[str]): Properties to exclude when restoring KB.
""" """
raise NotImplementedError( raise NotImplementedError(
Errors.E1045.format(parent="KnowledgeBase", method="from_bytes", name=self.__name__) Errors.E1044.format(parent="KnowledgeBase", method="from_bytes", name=self.__name__)
) )
def to_disk(self, path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList()) -> None: def to_disk(self, path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList()) -> None:
@ -94,7 +96,7 @@ cdef class KnowledgeBase:
exclude (Iterable[str]): List of components to exclude. exclude (Iterable[str]): List of components to exclude.
""" """
raise NotImplementedError( raise NotImplementedError(
Errors.E1045.format(parent="KnowledgeBase", method="to_disk", name=self.__name__) Errors.E1044.format(parent="KnowledgeBase", method="to_disk", name=self.__name__)
) )
def from_disk(self, path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList()) -> None: def from_disk(self, path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList()) -> None:
@ -104,5 +106,5 @@ cdef class KnowledgeBase:
exclude (Iterable[str]): List of components to exclude. exclude (Iterable[str]): List of components to exclude.
""" """
raise NotImplementedError( raise NotImplementedError(
Errors.E1045.format(parent="KnowledgeBase", method="from_disk", name=self.__name__) Errors.E1044.format(parent="KnowledgeBase", method="from_disk", name=self.__name__)
) )

View File

@ -1,12 +1,12 @@
from pathlib import Path from pathlib import Path
from typing import Optional, Callable, Iterable, List, Tuple from typing import Optional, Callable, Iterable, List, Tuple, Generator
from thinc.types import Floats2d from thinc.types import Floats2d
from thinc.api import chain, list2ragged, reduce_mean, residual from thinc.api import chain, list2ragged, reduce_mean, residual
from thinc.api import Model, Maxout, Linear, tuplify, Ragged from thinc.api import Model, Maxout, Linear, tuplify, Ragged
from ...util import registry from ...util import registry
from ...kb import KnowledgeBase, InMemoryLookupKB from ...kb import KnowledgeBase, InMemoryLookupKB
from ...kb import Candidate, get_candidates, get_candidates_batch from ...kb import Candidate, get_candidates, get_candidates_all
from ...vocab import Vocab from ...vocab import Vocab
from ...tokens import Span, Doc from ...tokens import Span, Doc
from ..extract_spans import extract_spans from ..extract_spans import extract_spans
@ -105,8 +105,8 @@ def create_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]:
return get_candidates return get_candidates
@registry.misc("spacy.CandidateBatchGenerator.v1") @registry.misc("spacy.CandidateAllGenerator.v1")
def create_candidates_batch() -> Callable[ def create_candidates_all() -> Callable[
[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] [KnowledgeBase, Generator[Iterable[Span], None, None]], Generator[Iterable[Iterable[Candidate]], None, None]
]: ]:
return get_candidates_batch return get_candidates_all

View File

@ -1,4 +1,4 @@
from typing import Optional, Iterable, Callable, Dict, Union, List, Any from typing import Optional, Iterable, Callable, Dict, Union, List, Any, Generator
from thinc.types import Floats2d from thinc.types import Floats2d
from pathlib import Path from pathlib import Path
from itertools import islice from itertools import islice
@ -53,11 +53,11 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
"incl_context": True, "incl_context": True,
"entity_vector_length": 64, "entity_vector_length": 64,
"get_candidates": {"@misc": "spacy.CandidateGenerator.v1"}, "get_candidates": {"@misc": "spacy.CandidateGenerator.v1"},
"get_candidates_batch": {"@misc": "spacy.CandidateBatchGenerator.v1"}, "get_candidates_all": {"@misc": "spacy.CandidateAllGenerator.v1"},
"overwrite": True, "overwrite": True,
"scorer": {"@scorers": "spacy.entity_linker_scorer.v1"}, "scorer": {"@scorers": "spacy.entity_linker_scorer.v1"},
"use_gold_ents": True, "use_gold_ents": True,
"candidates_batch_size": 1, "candidates_doc_mode": False,
"threshold": None, "threshold": None,
}, },
default_score_weights={ default_score_weights={
@ -77,13 +77,14 @@ def make_entity_linker(
incl_context: bool, incl_context: bool,
entity_vector_length: int, entity_vector_length: int,
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
get_candidates_batch: Callable[ get_candidates_all: Callable[
[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] [KnowledgeBase, Generator[Iterable[Span], None, None]],
Generator[Iterable[Iterable[Candidate]], None, None]
], ],
overwrite: bool, overwrite: bool,
scorer: Optional[Callable], scorer: Optional[Callable],
use_gold_ents: bool, use_gold_ents: bool,
candidates_batch_size: int, candidates_doc_mode: bool,
threshold: Optional[float] = None, threshold: Optional[float] = None,
): ):
"""Construct an EntityLinker component. """Construct an EntityLinker component.
@ -98,13 +99,18 @@ def make_entity_linker(
entity_vector_length (int): Size of encoding vectors in the KB. entity_vector_length (int): Size of encoding vectors in the KB.
get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that
produces a list of candidates, given a certain knowledge base and a textual mention. produces a list of candidates, given a certain knowledge base and a textual mention.
get_candidates_batch ( get_candidates_all (
Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]], Iterable[Candidate]] Callable[
): Function that produces a list of candidates, given a certain knowledge base and several textual mentions. [KnowledgeBase, Generator[Iterable[Span], None, None]],
Generator[Iterable[Iterable[Candidate]], None, None]
]): Function that produces a list of candidates per document, given a certain knowledge base and several textual
documents with textual mentions.
scorer (Optional[Callable]): The scoring method. scorer (Optional[Callable]): The scoring method.
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
component must provide entity annotations. component must provide entity annotations.
candidates_batch_size (int): Size of batches for entity candidate generation. candidates_doc_mode (bool): Whether or not to operate candidate generation in doc mode, i.e. to provide a generator
yielding entities per document (candidate generator callable is called only once in this case). If False,
the candidate generator is called once per entity.
threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the threshold, threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the threshold,
prediction is discarded. If None, predictions are not filtered by any threshold. prediction is discarded. If None, predictions are not filtered by any threshold.
""" """
@ -134,11 +140,11 @@ def make_entity_linker(
incl_context=incl_context, incl_context=incl_context,
entity_vector_length=entity_vector_length, entity_vector_length=entity_vector_length,
get_candidates=get_candidates, get_candidates=get_candidates,
get_candidates_batch=get_candidates_batch, get_candidates_all=get_candidates_all,
overwrite=overwrite, overwrite=overwrite,
scorer=scorer, scorer=scorer,
use_gold_ents=use_gold_ents, use_gold_ents=use_gold_ents,
candidates_batch_size=candidates_batch_size, candidates_doc_mode=candidates_doc_mode,
threshold=threshold, threshold=threshold,
) )
@ -172,13 +178,14 @@ class EntityLinker(TrainablePipe):
incl_context: bool, incl_context: bool,
entity_vector_length: int, entity_vector_length: int,
get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
get_candidates_batch: Callable[ get_candidates_all: Callable[
[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] [KnowledgeBase, Generator[Iterable[Span], None, None]],
Generator[Iterable[Iterable[Candidate]], None, None]
], ],
overwrite: bool = BACKWARD_OVERWRITE, overwrite: bool = BACKWARD_OVERWRITE,
scorer: Optional[Callable] = entity_linker_score, scorer: Optional[Callable] = entity_linker_score,
use_gold_ents: bool, use_gold_ents: bool,
candidates_batch_size: int, candidates_doc_mode: bool,
threshold: Optional[float] = None, threshold: Optional[float] = None,
) -> None: ) -> None:
"""Initialize an entity linker. """Initialize an entity linker.
@ -194,14 +201,18 @@ class EntityLinker(TrainablePipe):
entity_vector_length (int): Size of encoding vectors in the KB. entity_vector_length (int): Size of encoding vectors in the KB.
get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that
produces a list of candidates, given a certain knowledge base and a textual mention. produces a list of candidates, given a certain knowledge base and a textual mention.
get_candidates_batch ( get_candidates_all (
Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]], Callable[
Iterable[Candidate]] [KnowledgeBase, Generator[Iterable[Span], None, None]],
): Function that produces a list of candidates, given a certain knowledge base and several textual mentions. Generator[Iterable[Iterable[Candidate]], None, None]
]): Function that produces a list of candidates per document, given a certain knowledge base and several textual
documents with textual mentions.
scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links. scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links.
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
component must provide entity annotations. component must provide entity annotations.
candidates_batch_size (int): Size of batches for entity candidate generation. candidates_doc_mode (bool): Whether or not to operate candidate generation in doc mode, i.e. to provide a generator
yielding entities per document (candidate generator callable is called only once in this case). If False,
the candidate generator is called once per entity.
threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the
threshold, prediction is discarded. If None, predictions are not filtered by any threshold. threshold, prediction is discarded. If None, predictions are not filtered by any threshold.
DOCS: https://spacy.io/api/entitylinker#init DOCS: https://spacy.io/api/entitylinker#init
@ -224,7 +235,7 @@ class EntityLinker(TrainablePipe):
self.incl_prior = incl_prior self.incl_prior = incl_prior
self.incl_context = incl_context self.incl_context = incl_context
self.get_candidates = get_candidates self.get_candidates = get_candidates
self.get_candidates_batch = get_candidates_batch self.get_candidates_all = get_candidates_all
self.cfg: Dict[str, Any] = {"overwrite": overwrite} self.cfg: Dict[str, Any] = {"overwrite": overwrite}
self.distance = CosineDistance(normalize=False) self.distance = CosineDistance(normalize=False)
# how many neighbour sentences to take into account # how many neighbour sentences to take into account
@ -232,12 +243,9 @@ class EntityLinker(TrainablePipe):
self.kb = empty_kb(entity_vector_length)(self.vocab) self.kb = empty_kb(entity_vector_length)(self.vocab)
self.scorer = scorer self.scorer = scorer
self.use_gold_ents = use_gold_ents self.use_gold_ents = use_gold_ents
self.candidates_batch_size = candidates_batch_size self.candidates_doc_mode = candidates_doc_mode
self.threshold = threshold self.threshold = threshold
if candidates_batch_size < 1:
raise ValueError(Errors.E1044)
def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]): def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
"""Define the KB of this pipe by providing a function that will """Define the KB of this pipe by providing a function that will
create it using this object's vocab.""" create it using this object's vocab."""
@ -440,35 +448,37 @@ class EntityLinker(TrainablePipe):
return final_kb_ids return final_kb_ids
if isinstance(docs, Doc): if isinstance(docs, Doc):
docs = [docs] docs = [docs]
for i, doc in enumerate(docs):
# Determine which entities are to be ignored due to labels_discard.
valid_ent_idx_per_doc = (
[
idx
for idx in range(len(doc.ents))
if doc.ents[idx].label_ not in self.labels_discard
]
for doc in docs if len(doc.ents)
)
# Call candidate generator.
if self.candidates_doc_mode:
all_ent_cands = self.get_candidates_all(
self.kb,
([doc.ents[idx] for idx in next(valid_ent_idx_per_doc)] for doc in docs if len(doc.ents))
)
else:
# Alternative: collect entities the old-fashioned way - by retrieving entities individually.
all_ent_cands = (
[self.get_candidates(self.kb, doc.ents[idx]) for idx in next(valid_ent_idx_per_doc)]
for doc in docs if len(doc.ents)
)
for doc_idx, doc in enumerate(docs):
if len(doc) == 0: if len(doc) == 0:
continue continue
sentences = [s for s in doc.sents] sentences = [s for s in doc.sents]
doc_ent_cands = list(next(all_ent_cands)) if len(doc.ents) else []
# Loop over entities in batches. # Looping over candidate entities for this doc. (TODO: rewrite)
for ent_idx in range(0, len(doc.ents), self.candidates_batch_size): for ent_cand_idx, ent in enumerate(doc.ents):
ent_batch = doc.ents[ent_idx : ent_idx + self.candidates_batch_size]
# Look up candidate entities.
valid_ent_idx = [
idx
for idx in range(len(ent_batch))
if ent_batch[idx].label_ not in self.labels_discard
]
batch_candidates = list(
self.get_candidates_batch(
self.kb, [ent_batch[idx] for idx in valid_ent_idx]
)
if self.candidates_batch_size > 1
else [
self.get_candidates(self.kb, ent_batch[idx])
for idx in valid_ent_idx
]
)
# Looping through each entity in batch (TODO: rewrite)
for j, ent in enumerate(ent_batch):
sent_index = sentences.index(ent.sent) sent_index = sentences.index(ent.sent)
assert sent_index >= 0 assert sent_index >= 0
@ -490,7 +500,7 @@ class EntityLinker(TrainablePipe):
# ignoring this entity - setting to NIL # ignoring this entity - setting to NIL
final_kb_ids.append(self.NIL) final_kb_ids.append(self.NIL)
else: else:
candidates = list(batch_candidates[j]) candidates = list(doc_ent_cands[ent_cand_idx])
if not candidates: if not candidates:
# no prediction possible for this entity - setting to NIL # no prediction possible for this entity - setting to NIL
final_kb_ids.append(self.NIL) final_kb_ids.append(self.NIL)

View File

@ -1,4 +1,4 @@
from typing import Callable, Iterable, Dict, Any from typing import Callable, Iterable, Dict, Any, Generator
import pytest import pytest
from numpy.testing import assert_equal from numpy.testing import assert_equal
@ -497,11 +497,14 @@ def test_el_pipe_configuration(nlp):
assert doc[1].ent_kb_id_ == "" assert doc[1].ent_kb_id_ == ""
assert doc[2].ent_kb_id_ == "Q2" assert doc[2].ent_kb_id_ == "Q2"
# Replace the pipe with a new one with with a different candidate generator.
def get_lowercased_candidates(kb, span): def get_lowercased_candidates(kb, span):
return kb.get_alias_candidates(span.text.lower()) return kb.get_alias_candidates(span.text.lower())
def get_lowercased_candidates_batch(kb, spans): def get_lowercased_candidates_all(kb, spans_per_doc):
return [get_lowercased_candidates(kb, span) for span in spans] for doc_spans in spans_per_doc:
yield [get_lowercased_candidates(kb, span) for span in doc_spans]
@registry.misc("spacy.LowercaseCandidateGenerator.v1") @registry.misc("spacy.LowercaseCandidateGenerator.v1")
def create_candidates() -> Callable[ def create_candidates() -> Callable[
@ -509,29 +512,39 @@ def test_el_pipe_configuration(nlp):
]: ]:
return get_lowercased_candidates return get_lowercased_candidates
@registry.misc("spacy.LowercaseCandidateBatchGenerator.v1") @registry.misc("spacy.LowercaseCandidateAllGenerator.v1")
def create_candidates_batch() -> Callable[ def create_candidates_batch() -> Callable[
[InMemoryLookupKB, Iterable["Span"]], Iterable[Iterable[Candidate]] [InMemoryLookupKB, Generator[Iterable["Span"], None, None]],
Generator[Iterable[Iterable[Candidate]], None, None]
]: ]:
return get_lowercased_candidates_batch return get_lowercased_candidates_all
# replace the pipe with a new one with with a different candidate generator def test_reconfigured_el(candidates_doc_mode: bool, doc_text: str) -> None:
entity_linker = nlp.replace_pipe( """Test reconfigured EL for correct results.
candidates_doc_mode (bool): candidates_doc_mode in pipe config.
doc_text (str): Text to infer.
"""
_entity_linker = nlp.replace_pipe(
"entity_linker", "entity_linker",
"entity_linker", "entity_linker",
config={ config={
"incl_context": False, "incl_context": False,
"candidates_doc_mode": candidates_doc_mode,
"get_candidates": {"@misc": "spacy.LowercaseCandidateGenerator.v1"}, "get_candidates": {"@misc": "spacy.LowercaseCandidateGenerator.v1"},
"get_candidates_batch": { "get_candidates_all": {
"@misc": "spacy.LowercaseCandidateBatchGenerator.v1" "@misc": "spacy.LowercaseCandidateAllGenerator.v1"
}, },
}, },
) )
entity_linker.set_kb(create_kb) _entity_linker.set_kb(create_kb)
doc = nlp(text) _doc = nlp(doc_text)
assert doc[0].ent_kb_id_ == "Q2" assert _doc[0].ent_kb_id_ == "Q2"
assert doc[1].ent_kb_id_ == "" assert _doc[1].ent_kb_id_ == ""
assert doc[2].ent_kb_id_ == "Q2" assert _doc[2].ent_kb_id_ == "Q2"
# Test individual and doc-wise candidate generation.
test_reconfigured_el(False, text)
test_reconfigured_el(True, text)
def test_nel_nsents(nlp): def test_nel_nsents(nlp):
@ -670,6 +683,7 @@ def test_preserving_links_asdoc(nlp):
assert s_ent.kb_id_ == orig_kb_id assert s_ent.kb_id_ == orig_kb_id
def test_preserving_links_ents(nlp): def test_preserving_links_ents(nlp):
"""Test that doc.ents preserves KB annotations""" """Test that doc.ents preserves KB annotations"""
text = "She lives in Boston. He lives in Denver." text = "She lives in Boston. He lives in Denver."