From 7c28424f478c14f5e1dac523ae57ee6d4b207835 Mon Sep 17 00:00:00 2001
From: Raphael Mitsch
Date: Tue, 18 Oct 2022 15:31:15 +0200
Subject: [PATCH] Convert batched candidate generation into doc-wise candidate
 generation.

---
 spacy/errors.py                            |   5 +-
 spacy/kb/__init__.py                       |   2 +-
 spacy/kb/candidate.pyx                     |  12 +-
 spacy/kb/kb.pyx                            |  26 +--
 spacy/ml/models/entity_linker.py           |  12 +-
 spacy/pipeline/entity_linker.py            | 222 +++++++++++----------
 spacy/tests/pipeline/test_entity_linker.py |  58 ++++--
 7 files changed, 182 insertions(+), 155 deletions(-)

diff --git a/spacy/errors.py b/spacy/errors.py
index e0628819d..958859569 100644
--- a/spacy/errors.py
+++ b/spacy/errors.py
@@ -946,11 +946,10 @@ class Errors(metaclass=ErrorsWithCodes):
             "case pass an empty list for the previously not specified argument to avoid this error.")
     E1043 = ("Expected None or a value in range [{range_start}, {range_end}] for entity linker threshold, but got "
              "{value}.")
-    E1044 = ("Expected `candidates_batch_size` to be >= 1, but got: {value}")
-    E1045 = ("Encountered {parent} subclass without `{parent}.{method}` "
+    E1044 = ("Encountered {parent} subclass without `{parent}.{method}` "
              "method in '{name}'. If you want to use this method, make "
              "sure it's overwritten on the subclass.")
-    E1046 = ("{cls_name} is an abstract class and cannot be instantiated. If you are looking for spaCy's default "
+    E1045 = ("{cls_name} is an abstract class and cannot be instantiated. If you are looking for spaCy's default "
              "knowledge base, use `InMemoryLookupKB`.")

diff --git a/spacy/kb/__init__.py b/spacy/kb/__init__.py
index 1d70a9b34..b61cb5447 100644
--- a/spacy/kb/__init__.py
+++ b/spacy/kb/__init__.py
@@ -1,3 +1,3 @@
 from .kb import KnowledgeBase
 from .kb_in_memory import InMemoryLookupKB
-from .candidate import Candidate, get_candidates, get_candidates_batch
+from .candidate import Candidate, get_candidates, get_candidates_all
diff --git a/spacy/kb/candidate.pyx b/spacy/kb/candidate.pyx
index c89efeb03..5ad52618a 100644
--- a/spacy/kb/candidate.pyx
+++ b/spacy/kb/candidate.pyx
@@ -1,6 +1,6 @@
 # cython: infer_types=True, profile=True

-from typing import Iterable
+from typing import Iterable, Generator
 from .kb cimport KnowledgeBase
 from ..tokens import Span

@@ -64,11 +64,13 @@ def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]:
     return kb.get_candidates(mention)


-def get_candidates_batch(kb: KnowledgeBase, mentions: Iterable[Span]) -> Iterable[Iterable[Candidate]]:
+def get_candidates_all(
+    kb: KnowledgeBase, mentions: Generator[Iterable[Span], None, None]
+) -> Generator[Iterable[Iterable[Candidate]], None, None]:
     """
     Return candidate entities for the given mentions and fetching appropriate entries from the index.
     kb (KnowledgeBase): Knowledge base to query.
-    mention (Iterable[Span]): Entity mentions for which to identify candidates.
-    RETURNS (Iterable[Iterable[Candidate]]): Identified candidates.
+    mentions (Generator[Iterable[Span]]): Entity mentions per document for which to identify candidates.
+    RETURNS (Generator[Iterable[Iterable[Candidate]]]): Identified candidates per document.
     """
-    return kb.get_candidates_batch(mentions)
+    return kb.get_candidates_all(mentions)
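For orientation, the new contract takes one iterable of mention Spans per document and yields one list of candidate lists per document. A minimal sketch of driving it by hand; `nlp`, `kb` and `texts` are assumed to exist already and are not part of this patch:

    from spacy.kb import get_candidates_all

    docs = list(nlp.pipe(texts))
    docs_with_ents = [doc for doc in docs if len(doc.ents)]
    mentions_per_doc = (doc.ents for doc in docs_with_ents)

    # One list of candidate lists comes back per document, lazily.
    for doc, doc_candidates in zip(docs_with_ents, get_candidates_all(kb, mentions_per_doc)):
        for ent, candidates in zip(doc.ents, doc_candidates):
            print(ent.text, [c.entity_ for c in candidates])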
""" - return kb.get_candidates_batch(mentions) + return kb.get_candidates_all(mentions) diff --git a/spacy/kb/kb.pyx b/spacy/kb/kb.pyx index ce4bc0138..2e99ea493 100644 --- a/spacy/kb/kb.pyx +++ b/spacy/kb/kb.pyx @@ -1,7 +1,7 @@ # cython: infer_types=True, profile=True from pathlib import Path -from typing import Iterable, Tuple, Union +from typing import Iterable, Tuple, Union, Generator from cymem.cymem cimport Pool from .candidate import Candidate @@ -23,22 +23,24 @@ cdef class KnowledgeBase: # Make sure abstract KB is not instantiated. if self.__class__ == KnowledgeBase: raise TypeError( - Errors.E1046.format(cls_name=self.__class__.__name__) + Errors.E1045.format(cls_name=self.__class__.__name__) ) self.vocab = vocab self.entity_vector_length = entity_vector_length self.mem = Pool() - def get_candidates_batch(self, mentions: Iterable[Span]) -> Iterable[Iterable[Candidate]]: + def get_candidates_all(self, mentions: Generator[Iterable[Span]]) -> Generator[Iterable[Iterable[Candidate]]]: """ Return candidate entities for specified texts. Each candidate defines the entity, the original alias, and the prior probability of that alias resolving to that entity. If no candidate is found for a given text, an empty list is returned. - mentions (Iterable[Span]): Mentions for which to get candidates. - RETURNS (Iterable[Iterable[Candidate]]): Identified candidates. + mentions (Generator[Iterable[Span]]): Mentions per documents for which to get candidates. + RETURNS (Generator[Iterable[Iterable[Candidate]]]): Identified candidates per document. """ - return [self.get_candidates(span) for span in mentions] + + for doc_mentions in mentions: + yield [self.get_candidates(span) for span in doc_mentions] def get_candidates(self, mention: Span) -> Iterable[Candidate]: """ @@ -49,7 +51,7 @@ cdef class KnowledgeBase: RETURNS (Iterable[Candidate]): Identified candidates. """ raise NotImplementedError( - Errors.E1045.format(parent="KnowledgeBase", method="get_candidates", name=self.__name__) + Errors.E1044.format(parent="KnowledgeBase", method="get_candidates", name=self.__name__) ) def get_vectors(self, entities: Iterable[str]) -> Iterable[Iterable[float]]: @@ -67,7 +69,7 @@ cdef class KnowledgeBase: RETURNS (Iterable[float]): Vector for specified entity. """ raise NotImplementedError( - Errors.E1045.format(parent="KnowledgeBase", method="get_vector", name=self.__name__) + Errors.E1044.format(parent="KnowledgeBase", method="get_vector", name=self.__name__) ) def to_bytes(self, **kwargs) -> bytes: @@ -75,7 +77,7 @@ cdef class KnowledgeBase: RETURNS (bytes): Current state as binary string. """ raise NotImplementedError( - Errors.E1045.format(parent="KnowledgeBase", method="to_bytes", name=self.__name__) + Errors.E1044.format(parent="KnowledgeBase", method="to_bytes", name=self.__name__) ) def from_bytes(self, bytes_data: bytes, *, exclude: Tuple[str] = tuple()): @@ -84,7 +86,7 @@ cdef class KnowledgeBase: exclude (Tuple[str]): Properties to exclude when restoring KB. """ raise NotImplementedError( - Errors.E1045.format(parent="KnowledgeBase", method="from_bytes", name=self.__name__) + Errors.E1044.format(parent="KnowledgeBase", method="from_bytes", name=self.__name__) ) def to_disk(self, path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList()) -> None: @@ -94,7 +96,7 @@ cdef class KnowledgeBase: exclude (Iterable[str]): List of components to exclude. 
""" raise NotImplementedError( - Errors.E1045.format(parent="KnowledgeBase", method="to_disk", name=self.__name__) + Errors.E1044.format(parent="KnowledgeBase", method="to_disk", name=self.__name__) ) def from_disk(self, path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList()) -> None: @@ -104,5 +106,5 @@ cdef class KnowledgeBase: exclude (Iterable[str]): List of components to exclude. """ raise NotImplementedError( - Errors.E1045.format(parent="KnowledgeBase", method="from_disk", name=self.__name__) + Errors.E1044.format(parent="KnowledgeBase", method="from_disk", name=self.__name__) ) diff --git a/spacy/ml/models/entity_linker.py b/spacy/ml/models/entity_linker.py index 4d18d216a..9aac71d40 100644 --- a/spacy/ml/models/entity_linker.py +++ b/spacy/ml/models/entity_linker.py @@ -1,12 +1,12 @@ from pathlib import Path -from typing import Optional, Callable, Iterable, List, Tuple +from typing import Optional, Callable, Iterable, List, Tuple, Generator from thinc.types import Floats2d from thinc.api import chain, list2ragged, reduce_mean, residual from thinc.api import Model, Maxout, Linear, tuplify, Ragged from ...util import registry from ...kb import KnowledgeBase, InMemoryLookupKB -from ...kb import Candidate, get_candidates, get_candidates_batch +from ...kb import Candidate, get_candidates, get_candidates_all from ...vocab import Vocab from ...tokens import Span, Doc from ..extract_spans import extract_spans @@ -105,8 +105,8 @@ def create_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]: return get_candidates -@registry.misc("spacy.CandidateBatchGenerator.v1") -def create_candidates_batch() -> Callable[ - [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] +@registry.misc("spacy.CandidateAllGenerator.v1") +def create_candidates_all() -> Callable[ + [KnowledgeBase, Generator[Iterable[Span], None, None]], Generator[Iterable[Iterable[Candidate]], None, None] ]: - return get_candidates_batch + return get_candidates_all diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index 62845287b..4d3baf2f3 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -1,4 +1,4 @@ -from typing import Optional, Iterable, Callable, Dict, Union, List, Any +from typing import Optional, Iterable, Callable, Dict, Union, List, Any, Generator from thinc.types import Floats2d from pathlib import Path from itertools import islice @@ -53,11 +53,11 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"] "incl_context": True, "entity_vector_length": 64, "get_candidates": {"@misc": "spacy.CandidateGenerator.v1"}, - "get_candidates_batch": {"@misc": "spacy.CandidateBatchGenerator.v1"}, + "get_candidates_all": {"@misc": "spacy.CandidateAllGenerator.v1"}, "overwrite": True, "scorer": {"@scorers": "spacy.entity_linker_scorer.v1"}, "use_gold_ents": True, - "candidates_batch_size": 1, + "candidates_doc_mode": False, "threshold": None, }, default_score_weights={ @@ -77,13 +77,14 @@ def make_entity_linker( incl_context: bool, entity_vector_length: int, get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], - get_candidates_batch: Callable[ - [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]] + get_candidates_all: Callable[ + [KnowledgeBase, Generator[Iterable[Span], None, None]], + Generator[Iterable[Iterable[Candidate]], None, None] ], overwrite: bool, scorer: Optional[Callable], use_gold_ents: bool, - candidates_batch_size: int, + candidates_doc_mode: bool, threshold: Optional[float] = 
diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py
index 62845287b..4d3baf2f3 100644
--- a/spacy/pipeline/entity_linker.py
+++ b/spacy/pipeline/entity_linker.py
@@ -1,4 +1,4 @@
-from typing import Optional, Iterable, Callable, Dict, Union, List, Any
+from typing import Optional, Iterable, Callable, Dict, Union, List, Any, Generator
 from thinc.types import Floats2d
 from pathlib import Path
 from itertools import islice
@@ -53,11 +53,11 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
         "incl_context": True,
         "entity_vector_length": 64,
         "get_candidates": {"@misc": "spacy.CandidateGenerator.v1"},
-        "get_candidates_batch": {"@misc": "spacy.CandidateBatchGenerator.v1"},
+        "get_candidates_all": {"@misc": "spacy.CandidateAllGenerator.v1"},
         "overwrite": True,
         "scorer": {"@scorers": "spacy.entity_linker_scorer.v1"},
         "use_gold_ents": True,
-        "candidates_batch_size": 1,
+        "candidates_doc_mode": False,
         "threshold": None,
     },
     default_score_weights={
@@ -77,13 +77,14 @@ def make_entity_linker(
     incl_context: bool,
     entity_vector_length: int,
     get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
-    get_candidates_batch: Callable[
-        [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]
+    get_candidates_all: Callable[
+        [KnowledgeBase, Generator[Iterable[Span], None, None]],
+        Generator[Iterable[Iterable[Candidate]], None, None]
     ],
     overwrite: bool,
     scorer: Optional[Callable],
     use_gold_ents: bool,
-    candidates_batch_size: int,
+    candidates_doc_mode: bool,
     threshold: Optional[float] = None,
 ):
     """Construct an EntityLinker component.
@@ -98,13 +99,18 @@ def make_entity_linker(
     entity_vector_length (int): Size of encoding vectors in the KB.
     get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that
         produces a list of candidates, given a certain knowledge base and a textual mention.
-    get_candidates_batch (
-        Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]], Iterable[Candidate]]
-    ): Function that produces a list of candidates, given a certain knowledge base and several textual mentions.
+    get_candidates_all (
+        Callable[
+            [KnowledgeBase, Generator[Iterable[Span], None, None]],
+            Generator[Iterable[Iterable[Candidate]], None, None]
+        ]): Function that produces one list of candidate lists per document, given a certain knowledge base
+        and the mentions of several documents.
     scorer (Optional[Callable]): The scoring method.
     use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
         component must provide entity annotations.
-    candidates_batch_size (int): Size of batches for entity candidate generation.
+    candidates_doc_mode (bool): Whether to generate candidates doc-wise, i.e. to call `get_candidates_all`
+        exactly once, with all mentions grouped by document. If False, `get_candidates` is called once
+        per mention instead.
     threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the threshold,
         prediction is discarded. If None, predictions are not filtered by any threshold.
     """
@@ -134,11 +140,11 @@ def make_entity_linker(
         incl_context=incl_context,
         entity_vector_length=entity_vector_length,
         get_candidates=get_candidates,
-        get_candidates_batch=get_candidates_batch,
+        get_candidates_all=get_candidates_all,
         overwrite=overwrite,
         scorer=scorer,
         use_gold_ents=use_gold_ents,
-        candidates_batch_size=candidates_batch_size,
+        candidates_doc_mode=candidates_doc_mode,
         threshold=threshold,
     )

@@ -172,13 +178,14 @@ class EntityLinker(TrainablePipe):
         incl_context: bool,
         entity_vector_length: int,
         get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
-        get_candidates_batch: Callable[
-            [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]
+        get_candidates_all: Callable[
+            [KnowledgeBase, Generator[Iterable[Span], None, None]],
+            Generator[Iterable[Iterable[Candidate]], None, None]
         ],
         overwrite: bool = BACKWARD_OVERWRITE,
         scorer: Optional[Callable] = entity_linker_score,
         use_gold_ents: bool,
-        candidates_batch_size: int,
+        candidates_doc_mode: bool,
         threshold: Optional[float] = None,
     ) -> None:
         """Initialize an entity linker.
@@ -194,14 +201,18 @@ class EntityLinker(TrainablePipe):
         entity_vector_length (int): Size of encoding vectors in the KB.
         get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that
             produces a list of candidates, given a certain knowledge base and a textual mention.
-        get_candidates_batch (
-            Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]],
-            Iterable[Candidate]]
-        ): Function that produces a list of candidates, given a certain knowledge base and several textual mentions.
+        get_candidates_all (
+            Callable[
+                [KnowledgeBase, Generator[Iterable[Span], None, None]],
+                Generator[Iterable[Iterable[Candidate]], None, None]
+            ]): Function that produces one list of candidate lists per document, given a certain knowledge base
+            and the mentions of several documents.
         scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links.
         use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
             component must provide entity annotations.
-        candidates_batch_size (int): Size of batches for entity candidate generation.
+        candidates_doc_mode (bool): Whether to generate candidates doc-wise, i.e. to call `get_candidates_all`
+            exactly once, with all mentions grouped by document. If False, `get_candidates` is called once
+            per mention instead.
         threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the
             threshold, prediction is discarded. If None, predictions are not filtered by any threshold.
         DOCS: https://spacy.io/api/entitylinker#init
@@ -224,7 +235,7 @@ class EntityLinker(TrainablePipe):
         self.incl_prior = incl_prior
         self.incl_context = incl_context
         self.get_candidates = get_candidates
-        self.get_candidates_batch = get_candidates_batch
+        self.get_candidates_all = get_candidates_all
         self.cfg: Dict[str, Any] = {"overwrite": overwrite}
         self.distance = CosineDistance(normalize=False)
         # how many neighbour sentences to take into account
@@ -232,12 +243,9 @@ class EntityLinker(TrainablePipe):
         self.kb = empty_kb(entity_vector_length)(self.vocab)
         self.scorer = scorer
         self.use_gold_ents = use_gold_ents
-        self.candidates_batch_size = candidates_batch_size
+        self.candidates_doc_mode = candidates_doc_mode
         self.threshold = threshold

-        if candidates_batch_size < 1:
-            raise ValueError(Errors.E1044)
-
     def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
         """Define the KB of this pipe by providing a function that will
         create it using this object's vocab."""
@@ -440,96 +448,98 @@ class EntityLinker(TrainablePipe):
             return final_kb_ids
         if isinstance(docs, Doc):
             docs = [docs]
-        for i, doc in enumerate(docs):
+
+        # Determine which entities are to be ignored due to labels_discard.
+        valid_ent_idx_per_doc = (
+            [
+                idx
+                for idx in range(len(doc.ents))
+                if doc.ents[idx].label_ not in self.labels_discard
+            ]
+            for doc in docs if len(doc.ents)
+        )
+        # Call the candidate generator.
+        if self.candidates_doc_mode:
+            all_ent_cands = self.get_candidates_all(
+                self.kb,
+                ([doc.ents[idx] for idx in next(valid_ent_idx_per_doc)] for doc in docs if len(doc.ents))
+            )
+        else:
+            # Alternative: generate candidates individually, one get_candidates call per mention.
+            all_ent_cands = (
+                [self.get_candidates(self.kb, doc.ents[idx]) for idx in next(valid_ent_idx_per_doc)]
+                for doc in docs if len(doc.ents)
+            )
+
+        for doc in docs:
             if len(doc) == 0:
                 continue
             sentences = [s for s in doc.sents]
+            doc_ent_cands = list(next(all_ent_cands)) if len(doc.ents) else []

-            # Loop over entities in batches.
-            for ent_idx in range(0, len(doc.ents), self.candidates_batch_size):
-                ent_batch = doc.ents[ent_idx : ent_idx + self.candidates_batch_size]
-
-                # Look up candidate entities.
-                valid_ent_idx = [
-                    idx
-                    for idx in range(len(ent_batch))
-                    if ent_batch[idx].label_ not in self.labels_discard
-                ]
-
-                batch_candidates = list(
-                    self.get_candidates_batch(
-                        self.kb, [ent_batch[idx] for idx in valid_ent_idx]
-                    )
-                    if self.candidates_batch_size > 1
-                    else [
-                        self.get_candidates(self.kb, ent_batch[idx])
-                        for idx in valid_ent_idx
-                    ]
-                )
-
-                # Looping through each entity in batch (TODO: rewrite)
-                for j, ent in enumerate(ent_batch):
-                    sent_index = sentences.index(ent.sent)
-                    assert sent_index >= 0
-
-                    if self.incl_context:
-                        # get n_neighbour sentences, clipped to the length of the document
-                        start_sentence = max(0, sent_index - self.n_sents)
-                        end_sentence = min(
-                            len(sentences) - 1, sent_index + self.n_sents
-                        )
-                        start_token = sentences[start_sentence].start
-                        end_token = sentences[end_sentence].end
-                        sent_doc = doc[start_token:end_token].as_doc()
-                        # currently, the context is the same for each entity in a sentence (should be refined)
-                        sentence_encoding = self.model.predict([sent_doc])[0]
-                        sentence_encoding_t = sentence_encoding.T
-                        sentence_norm = xp.linalg.norm(sentence_encoding_t)
-                    entity_count += 1
-                    if ent.label_ in self.labels_discard:
-                        # ignoring this entity - setting to NIL
-                        final_kb_ids.append(self.NIL)
-                    else:
-                        candidates = list(batch_candidates[j])
-                        if not candidates:
-                            # no prediction possible for this entity - setting to NIL
-                            final_kb_ids.append(self.NIL)
-                        elif len(candidates) == 1 and self.threshold is None:
-                            # shortcut for efficiency reasons: take the 1 candidate
-                            final_kb_ids.append(candidates[0].entity_)
-                        else:
-                            random.shuffle(candidates)
-                            # set all prior probabilities to 0 if incl_prior=False
-                            prior_probs = xp.asarray([c.prior_prob for c in candidates])
-                            if not self.incl_prior:
-                                prior_probs = xp.asarray([0.0 for _ in candidates])
-                            scores = prior_probs
-                            # add in similarity from the context
-                            if self.incl_context:
-                                entity_encodings = xp.asarray(
-                                    [c.entity_vector for c in candidates]
-                                )
-                                entity_norm = xp.linalg.norm(entity_encodings, axis=1)
-                                if len(entity_encodings) != len(prior_probs):
-                                    raise RuntimeError(
-                                        Errors.E147.format(
-                                            method="predict",
-                                            msg="vectors not of equal length",
-                                        )
-                                    )
-                                # cosine similarity
-                                sims = xp.dot(entity_encodings, sentence_encoding_t) / (
-                                    sentence_norm * entity_norm
-                                )
-                                if sims.shape != prior_probs.shape:
-                                    raise ValueError(Errors.E161)
-                                scores = prior_probs + sims - (prior_probs * sims)
-                            final_kb_ids.append(
-                                candidates[scores.argmax().item()].entity_
-                                if self.threshold is None
-                                or scores.max() >= self.threshold
-                                else EntityLinker.NIL
-                            )
+            # Looping over candidate entities for this doc. (TODO: rewrite)
+            # doc_ent_cands only holds candidates for entities whose labels are
+            # not discarded, so it is advanced with its own index rather than
+            # the entity index.
+            ent_cand_idx = 0
+            for ent in doc.ents:
+                sent_index = sentences.index(ent.sent)
+                assert sent_index >= 0
+
+                if self.incl_context:
+                    # get n_neighbour sentences, clipped to the length of the document
+                    start_sentence = max(0, sent_index - self.n_sents)
+                    end_sentence = min(
+                        len(sentences) - 1, sent_index + self.n_sents
+                    )
+                    start_token = sentences[start_sentence].start
+                    end_token = sentences[end_sentence].end
+                    sent_doc = doc[start_token:end_token].as_doc()
+                    # currently, the context is the same for each entity in a sentence (should be refined)
+                    sentence_encoding = self.model.predict([sent_doc])[0]
+                    sentence_encoding_t = sentence_encoding.T
+                    sentence_norm = xp.linalg.norm(sentence_encoding_t)
+                entity_count += 1
+                if ent.label_ in self.labels_discard:
+                    # ignoring this entity - setting to NIL
+                    final_kb_ids.append(self.NIL)
+                else:
+                    candidates = list(doc_ent_cands[ent_cand_idx])
+                    ent_cand_idx += 1
+                    if not candidates:
+                        # no prediction possible for this entity - setting to NIL
+                        final_kb_ids.append(self.NIL)
+                    elif len(candidates) == 1 and self.threshold is None:
+                        # shortcut for efficiency reasons: take the 1 candidate
+                        final_kb_ids.append(candidates[0].entity_)
+                    else:
+                        random.shuffle(candidates)
+                        # set all prior probabilities to 0 if incl_prior=False
+                        prior_probs = xp.asarray([c.prior_prob for c in candidates])
+                        if not self.incl_prior:
+                            prior_probs = xp.asarray([0.0 for _ in candidates])
+                        scores = prior_probs
+                        # add in similarity from the context
+                        if self.incl_context:
+                            entity_encodings = xp.asarray(
+                                [c.entity_vector for c in candidates]
+                            )
+                            entity_norm = xp.linalg.norm(entity_encodings, axis=1)
+                            if len(entity_encodings) != len(prior_probs):
+                                raise RuntimeError(
+                                    Errors.E147.format(
+                                        method="predict",
+                                        msg="vectors not of equal length",
+                                    )
+                                )
+                            # cosine similarity
+                            sims = xp.dot(entity_encodings, sentence_encoding_t) / (
+                                sentence_norm * entity_norm
+                            )
+                            if sims.shape != prior_probs.shape:
+                                raise ValueError(Errors.E161)
+                            scores = prior_probs + sims - (prior_probs * sims)
+                        final_kb_ids.append(
+                            candidates[scores.argmax().item()].entity_
+                            if self.threshold is None
+                            or scores.max() >= self.threshold
+                            else EntityLinker.NIL
+                        )

         if not (len(final_kb_ids) == entity_count):
             err = Errors.E147.format(
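Wiring the new options into a pipeline mirrors what the test below does; a sketch assuming an existing `nlp` object and a `create_kb` loader:

    entity_linker = nlp.replace_pipe(
        "entity_linker",
        "entity_linker",
        config={
            "candidates_doc_mode": True,
            "get_candidates_all": {"@misc": "spacy.CandidateAllGenerator.v1"},
        },
    )
    entity_linker.set_kb(create_kb)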
+ """ + _entity_linker = nlp.replace_pipe( + "entity_linker", + "entity_linker", + config={ + "incl_context": False, + "candidates_doc_mode": candidates_doc_mode, + "get_candidates": {"@misc": "spacy.LowercaseCandidateGenerator.v1"}, + "get_candidates_all": { + "@misc": "spacy.LowercaseCandidateAllGenerator.v1" + }, }, - }, - ) - entity_linker.set_kb(create_kb) - doc = nlp(text) - assert doc[0].ent_kb_id_ == "Q2" - assert doc[1].ent_kb_id_ == "" - assert doc[2].ent_kb_id_ == "Q2" + ) + _entity_linker.set_kb(create_kb) + _doc = nlp(doc_text) + assert _doc[0].ent_kb_id_ == "Q2" + assert _doc[1].ent_kb_id_ == "" + assert _doc[2].ent_kb_id_ == "Q2" + + # Test individual and doc-wise candidate generation. + test_reconfigured_el(False, text) + test_reconfigured_el(True, text) def test_nel_nsents(nlp): @@ -670,6 +683,7 @@ def test_preserving_links_asdoc(nlp): assert s_ent.kb_id_ == orig_kb_id + def test_preserving_links_ents(nlp): """Test that doc.ents preserves KB annotations""" text = "She lives in Boston. He lives in Denver."