diff --git a/spacy/errors.py b/spacy/errors.py index eadbf63d6..d0676fb02 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -940,7 +940,7 @@ class Errors(metaclass=ErrorsWithCodes): "case pass an empty list for the previously not specified argument to avoid this error.") E1043 = ("Expected None or a value in range [{range_start}, {range_end}] for entity linker threshold, but got " "{value}.") - E1044 = ("Expected `candidates_batch_size` to be >= 1, but got: {value}") + # E1044 is unused E1045 = ("Encountered {parent} subclass without `{parent}.{method}` " "method in '{name}'. If you want to use this method, make " "sure it's overwritten on the subclass.") diff --git a/spacy/kb/candidate.py b/spacy/kb/candidate.py index 3cc3a6c59..22c054ab2 100644 --- a/spacy/kb/candidate.py +++ b/spacy/kb/candidate.py @@ -1,5 +1,5 @@ import abc -from typing import List, Union, Callable +from typing import List, Callable class Candidate(abc.ABC): diff --git a/spacy/kb/kb.pyx b/spacy/kb/kb.pyx index e374cf94d..636fca74d 100644 --- a/spacy/kb/kb.pyx +++ b/spacy/kb/kb.pyx @@ -1,7 +1,7 @@ # cython: infer_types=True, profile=True from pathlib import Path -from typing import Iterable, Tuple, Union +from typing import Iterable, Tuple, Union, Iterator from cymem.cymem cimport Pool from .candidate import Candidate @@ -30,23 +30,13 @@ cdef class KnowledgeBase: self.entity_vector_length = entity_vector_length self.mem = Pool() - def get_candidates_batch(self, mentions: SpanGroup) -> Iterable[Iterable[Candidate]]: + def get_candidates(self, mentions: Iterator[SpanGroup]) -> Iterator[Iterable[Iterable[Candidate]]]: """ - Return candidate entities for specified texts. Each candidate defines the entity, the original alias, - and the prior probability of that alias resolving to that entity. - If no candidate is found for a given text, an empty list is returned. - mentions (SpanGroup): Mentions for which to get candidates. - RETURNS (Iterable[Iterable[Candidate]]): Identified candidates. - """ - return [self.get_candidates(span) for span in mentions] - - def get_candidates(self, mention: Span) -> Iterable[Candidate]: - """ - Return candidate entities for specified text. Each candidate defines the entity, the original alias, - and the prior probability of that alias resolving to that entity. - If the no candidate is found for a given text, an empty list is returned. - mention (Span): Mention for which to get candidates. - RETURNS (Iterable[Candidate]): Identified candidates. + Return candidate entities for mentions stored in `ent` attribute in passed docs. Each candidate defines the + entity, the original alias, and the prior probability of that alias resolving to that entity. + If no candidate is found for a given mention, an empty list is returned. + mentions (Iterator[SpanGroup]): Mentions per doc as SpanGroup instance. + RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per document. 
""" raise NotImplementedError( Errors.E1045.format(parent="KnowledgeBase", method="get_candidates", name=self.__name__) diff --git a/spacy/kb/kb_in_memory.pyx b/spacy/kb/kb_in_memory.pyx index f39432f5e..90726a98e 100644 --- a/spacy/kb/kb_in_memory.pyx +++ b/spacy/kb/kb_in_memory.pyx @@ -1,5 +1,5 @@ # cython: infer_types=True, profile=True -from typing import Iterable, Callable, Dict, Any, Union +from typing import Iterable, Callable, Dict, Any, Union, Iterator import srsly from preshed.maps cimport PreshMap @@ -11,7 +11,7 @@ from libcpp.vector cimport vector from pathlib import Path import warnings -from ..tokens import Span +from ..tokens import SpanGroup from ..typedefs cimport hash_t from ..errors import Errors, Warnings from .. import util @@ -223,8 +223,9 @@ cdef class InMemoryLookupKB(KnowledgeBase): alias_entry.probs = probs self._aliases_table[alias_index] = alias_entry - def get_candidates(self, mention: Span) -> Iterable[InMemoryCandidate]: - return self.get_alias_candidates(mention.text) # type: ignore + def get_candidates(self, mentions: Iterator[SpanGroup]) -> Iterator[Iterable[Iterable[InMemoryCandidate]]]: + for mentions_for_doc in mentions: + yield [self.get_alias_candidates(ent_span.text) for ent_span in mentions_for_doc] def get_alias_candidates(self, str alias) -> Iterable[InMemoryCandidate]: """ diff --git a/spacy/ml/models/entity_linker.py b/spacy/ml/models/entity_linker.py index fd19af5ab..2910ace6d 100644 --- a/spacy/ml/models/entity_linker.py +++ b/spacy/ml/models/entity_linker.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Optional, Callable, Iterable, List, Tuple +from typing import Optional, Callable, Iterable, List, Tuple, Iterator from thinc.types import Floats2d from thinc.api import chain, list2ragged, reduce_mean, residual from thinc.api import Model, Maxout, Linear, tuplify, Ragged @@ -100,34 +100,20 @@ def empty_kb( @registry.misc("spacy.CandidateGenerator.v1") -def create_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]: +def create_candidates_all() -> Callable[ + [KnowledgeBase, Iterator[SpanGroup]], + Iterator[Iterable[Iterable[Candidate]]], +]: return get_candidates -@registry.misc("spacy.CandidateBatchGenerator.v1") -def create_candidates_batch() -> Callable[ - [KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]] -]: - return get_candidates_batch - - -def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]: - """ - Return candidate entities for a given mention and fetching appropriate entries from the index. - kb (KnowledgeBase): Knowledge base to query. - mention (Span): Entity mention for which to identify candidates. - RETURNS (Iterable[InMemoryCandidate]): Identified candidates. - """ - return kb.get_candidates(mention) - - -def get_candidates_batch( - kb: KnowledgeBase, mentions: SpanGroup -) -> Iterable[Iterable[Candidate]]: +def get_candidates( + kb: KnowledgeBase, mentions: Iterator[SpanGroup] +) -> Iterator[Iterable[Iterable[Candidate]]]: """ Return candidate entities for the given mentions and fetching appropriate entries from the index. kb (KnowledgeBase): Knowledge base to query. - mention (SpanGroup): Entity mentions for which to identify candidates. - RETURNS (Iterable[Iterable[InMemoryCandidate]]): Identified candidates. + mentions (Iterator[SpanGroup]): Mentions per doc as SpanGroup instance. + RETURNS (Iterator[Iterable[Iterable[Candidate]]]): Identified candidates per document. 
""" - return kb.get_candidates_batch(mentions) + return kb.get_candidates(mentions) diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index a38f6b95b..09ee7053b 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -1,4 +1,4 @@ -from typing import Optional, Iterable, Callable, Dict, Sequence, Union, List, Any +from typing import Optional, Iterable, Callable, Dict, Sequence, Union, List, Any, Iterator from typing import cast from numpy import dtype from thinc.types import Floats1d, Floats2d, Ints1d, Ragged @@ -9,10 +9,9 @@ import random from thinc.api import CosineDistance, Model, Optimizer, Config from thinc.api import set_dropout_rate -from ..tokens import SpanGroup from ..kb import KnowledgeBase, Candidate from ..ml import empty_kb -from ..tokens import Doc, Span, SpanGroup +from ..tokens import Doc, SpanGroup from .pipe import deserialize_config from .trainable_pipe import TrainablePipe from ..language import Language @@ -57,11 +56,9 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"] "incl_context": True, "entity_vector_length": 64, "get_candidates": {"@misc": "spacy.CandidateGenerator.v1"}, - "get_candidates_batch": {"@misc": "spacy.CandidateBatchGenerator.v1"}, "overwrite": False, "scorer": {"@scorers": "spacy.entity_linker_scorer.v1"}, "use_gold_ents": True, - "candidates_batch_size": 1, "threshold": None, "save_activations": False, }, @@ -81,14 +78,10 @@ def make_entity_linker( incl_prior: bool, incl_context: bool, entity_vector_length: int, - get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], - get_candidates_batch: Callable[ - [KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]] - ], + get_candidates: Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]], overwrite: bool, scorer: Optional[Callable], use_gold_ents: bool, - candidates_batch_size: int, threshold: Optional[float] = None, save_activations: bool, ): @@ -102,15 +95,12 @@ def make_entity_linker( incl_prior (bool): Whether or not to include prior probabilities from the KB in the model. incl_context (bool): Whether or not to include the local context in the model. entity_vector_length (int): Size of encoding vectors in the KB. - get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that - produces a list of candidates, given a certain knowledge base and a textual mention. - get_candidates_batch ( - Callable[[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]], Iterable[Candidate]] - ): Function that produces a list of candidates, given a certain knowledge base and several textual mentions. + get_candidates (Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]]): + Function producing a list of candidates per document, given a certain knowledge base and several textual + documents with textual mentions. scorer (Optional[Callable]): The scoring method. use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another component must provide entity annotations. - candidates_batch_size (int): Size of batches for entity candidate generation. threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the threshold, prediction is discarded. If None, predictions are not filtered by any threshold. save_activations (bool): save model activations in Doc when annotating. 
@@ -147,11 +137,9 @@ def make_entity_linker( incl_context=incl_context, entity_vector_length=entity_vector_length, get_candidates=get_candidates, - get_candidates_batch=get_candidates_batch, overwrite=overwrite, scorer=scorer, use_gold_ents=use_gold_ents, - candidates_batch_size=candidates_batch_size, threshold=threshold, save_activations=save_activations, ) @@ -185,14 +173,10 @@ class EntityLinker(TrainablePipe): incl_prior: bool, incl_context: bool, entity_vector_length: int, - get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]], - get_candidates_batch: Callable[ - [KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]] - ], + get_candidates: Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]], overwrite: bool = False, scorer: Optional[Callable] = entity_linker_score, use_gold_ents: bool, - candidates_batch_size: int, threshold: Optional[float] = None, save_activations: bool = False, ) -> None: @@ -207,17 +191,13 @@ class EntityLinker(TrainablePipe): incl_prior (bool): Whether or not to include prior probabilities from the KB in the model. incl_context (bool): Whether or not to include the local context in the model. entity_vector_length (int): Size of encoding vectors in the KB. - get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that - produces a list of candidates, given a certain knowledge base and a textual mention. - get_candidates_batch ( - Callable[[KnowledgeBase, SpanGroup], Iterable[Iterable[Candidate]]], - Iterable[Candidate]] - ): Function that produces a list of candidates, given a certain knowledge base and several textual mentions. + get_candidates (Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]]): + Function producing a list of candidates per document, given a certain knowledge base and several textual + documents with textual mentions. overwrite (bool): Whether to overwrite existing non-empty annotations. scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links. use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another component must provide entity annotations. - candidates_batch_size (int): Size of batches for entity candidate generation. threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the threshold, prediction is discarded. If None, predictions are not filtered by any threshold. DOCS: https://spacy.io/api/entitylinker#init @@ -240,7 +220,6 @@ class EntityLinker(TrainablePipe): self.incl_prior = incl_prior self.incl_context = incl_context self.get_candidates = get_candidates - self.get_candidates_batch = get_candidates_batch self.cfg: Dict[str, Any] = {"overwrite": overwrite} self.distance = CosineDistance(normalize=False) # how many neighbour sentences to take into account @@ -248,13 +227,9 @@ class EntityLinker(TrainablePipe): self.kb = empty_kb(entity_vector_length)(self.vocab) self.scorer = scorer self.use_gold_ents = use_gold_ents - self.candidates_batch_size = candidates_batch_size self.threshold = threshold self.save_activations = save_activations - if candidates_batch_size < 1: - raise ValueError(Errors.E1044) - def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]): """Define the KB of this pipe by providing a function that will create it using this object's vocab.""" @@ -331,11 +306,12 @@ class EntityLinker(TrainablePipe): If one isn't present, then the update step needs to be skipped. 
""" - - for eg in examples: - for ent in eg.predicted.ents: - candidates = list(self.get_candidates(self.kb, ent)) - if candidates: + # todo continue here: fix get_candidates_call + for candidates_for_doc in self.get_candidates( + self.kb, (SpanGroup(doc=eg.predicted, spans=eg.predicted.ents) for eg in examples) + ): + for candidates_for_mention in candidates_for_doc: + if list(candidates_for_mention): return True return False @@ -464,59 +440,72 @@ class EntityLinker(TrainablePipe): } if isinstance(docs, Doc): docs = [docs] - for doc in docs: + + docs = list(docs) + # Determine which entities are to be ignored due to labels_discard. + valid_ent_idx_per_doc = ( + [ + idx + for idx in range(len(doc.ents)) + if doc.ents[idx].label_ not in self.labels_discard + ] + for doc in docs + if len(doc) and len(doc.ents) + ) + + # Call candidate generator. + all_ent_cands = self.get_candidates( + self.kb, + ( + SpanGroup( + doc, + spans=[doc.ents[idx] for idx in next(valid_ent_idx_per_doc)], + ) + for doc in docs + if len(doc) and len(doc.ents) + ), + ) + + for doc_idx, doc in enumerate(docs): doc_ents: List[Ints1d] = [] doc_scores: List[Floats1d] = [] - if len(doc) == 0: + if len(doc) == 0 or len(doc.ents) == 0: docs_scores.append(Ragged(ops.alloc1f(0), ops.alloc1i(0))) docs_ents.append(Ragged(xp.zeros(0, dtype="uint64"), ops.alloc1i(0))) continue sentences = [s for s in doc.sents] + doc_ent_cands = list(next(all_ent_cands)) - # Loop over entities in batches. - for ent_idx in range(0, len(doc.ents), self.candidates_batch_size): - ent_batch = doc.ents[ent_idx : ent_idx + self.candidates_batch_size] + # Looping over candidate entities for this doc. (TODO: rewrite) + for ent_cand_idx, ent in enumerate(doc.ents): + sent_index = sentences.index(ent.sent) + assert sent_index >= 0 - # Look up candidate entities. 
- valid_ent_idx = [ - idx - for idx in range(len(ent_batch)) - if ent_batch[idx].label_ not in self.labels_discard - ] - - batch_candidates = list( - self.get_candidates_batch( - self.kb, - SpanGroup(doc, spans=[ent_batch[idx] for idx in valid_ent_idx]), + if self.incl_context: + # get n_neighbour sentences, clipped to the length of the document + start_sentence = max(0, sent_index - self.n_sents) + end_sentence = min(len(sentences) - 1, sent_index + self.n_sents) + start_token = sentences[start_sentence].start + end_token = sentences[end_sentence].end + sent_doc = doc[start_token:end_token].as_doc() + # currently, the context is the same for each entity in a sentence (should be refined) + sentence_encoding = self.model.predict([sent_doc])[0] + sentence_encoding_t = sentence_encoding.T + sentence_norm = xp.linalg.norm(sentence_encoding_t) + entity_count += 1 + if ent.label_ in self.labels_discard: + # ignoring this entity - setting to NIL + final_kb_ids.append(self.NIL) + self._add_activations( + doc_scores=doc_scores, + doc_ents=doc_ents, + scores=[0.0], + ents=[0], ) - if self.candidates_batch_size > 1 - else [ - self.get_candidates(self.kb, ent_batch[idx]) - for idx in valid_ent_idx - ] - ) - - # Looping through each entity in batch (TODO: rewrite) - for j, ent in enumerate(ent_batch): - sent_index = sentences.index(ent.sent) - assert sent_index >= 0 - - if self.incl_context: - # get n_neighbour sentences, clipped to the length of the document - start_sentence = max(0, sent_index - self.n_sents) - end_sentence = min( - len(sentences) - 1, sent_index + self.n_sents - ) - start_token = sentences[start_sentence].start - end_token = sentences[end_sentence].end - sent_doc = doc[start_token:end_token].as_doc() - # currently, the context is the same for each entity in a sentence (should be refined) - sentence_encoding = self.model.predict([sent_doc])[0] - sentence_encoding_t = sentence_encoding.T - sentence_norm = xp.linalg.norm(sentence_encoding_t) - entity_count += 1 - if ent.label_ in self.labels_discard: - # ignoring this entity - setting to NIL + else: + candidates = list(doc_ent_cands[ent_cand_idx]) + if not candidates: + # no prediction possible for this entity - setting to NIL final_kb_ids.append(self.NIL) self._add_activations( doc_scores=doc_scores, @@ -524,65 +513,56 @@ class EntityLinker(TrainablePipe): scores=[0.0], ents=[0], ) + elif len(candidates) == 1 and self.threshold is None: + # shortcut for efficiency reasons: take the 1 candidate + final_kb_ids.append(candidates[0].entity_) + self._add_activations( + doc_scores=doc_scores, + doc_ents=doc_ents, + scores=[1.0], + ents=[candidates[0].entity], + ) else: - candidates = list(batch_candidates[j]) - if not candidates: - # no prediction possible for this entity - setting to NIL - final_kb_ids.append(self.NIL) - self._add_activations( - doc_scores=doc_scores, - doc_ents=doc_ents, - scores=[0.0], - ents=[0], + random.shuffle(candidates) + # set all prior probabilities to 0 if incl_prior=False + scores = prior_probs = xp.asarray( + [ + c.prior_prob if self.incl_prior else 0.0 + for c in candidates + ] + ) + # add in similarity from the context + if self.incl_context: + entity_encodings = xp.asarray( + [c.entity_vector for c in candidates] ) - elif len(candidates) == 1 and self.threshold is None: - # shortcut for efficiency reasons: take the 1 candidate - final_kb_ids.append(candidates[0].entity_) - self._add_activations( - doc_scores=doc_scores, - doc_ents=doc_ents, - scores=[1.0], - ents=[candidates[0].entity], - ) - else: - 
random.shuffle(candidates) - # set all prior probabilities to 0 if incl_prior=False - prior_probs = xp.asarray([c.prior_prob for c in candidates]) - if not self.incl_prior: - prior_probs = xp.asarray([0.0 for _ in candidates]) - scores = prior_probs - # add in similarity from the context - if self.incl_context: - entity_encodings = xp.asarray( - [c.entity_vector for c in candidates] - ) - entity_norm = xp.linalg.norm(entity_encodings, axis=1) - if len(entity_encodings) != len(prior_probs): - raise RuntimeError( - Errors.E147.format( - method="predict", - msg="vectors not of equal length", - ) + entity_norm = xp.linalg.norm(entity_encodings, axis=1) + if len(entity_encodings) != len(prior_probs): + raise RuntimeError( + Errors.E147.format( + method="predict", + msg="vectors not of equal length", ) - # cosine similarity - sims = xp.dot(entity_encodings, sentence_encoding_t) / ( - sentence_norm * entity_norm ) - if sims.shape != prior_probs.shape: - raise ValueError(Errors.E161) - scores = prior_probs + sims - (prior_probs * sims) - final_kb_ids.append( - candidates[scores.argmax().item()].entity_ - if self.threshold is None - or scores.max() >= self.threshold - else EntityLinker.NIL - ) - self._add_activations( - doc_scores=doc_scores, - doc_ents=doc_ents, - scores=scores, - ents=[c.entity for c in candidates], + # cosine similarity + sims = xp.dot(entity_encodings, sentence_encoding_t) / ( + sentence_norm * entity_norm ) + if sims.shape != prior_probs.shape: + raise ValueError(Errors.E161) + scores = prior_probs + sims - (prior_probs * sims) + final_kb_ids.append( + candidates[scores.argmax().item()].entity_ + if self.threshold is None or scores.max() >= self.threshold + else EntityLinker.NIL + ) + self._add_activations( + doc_scores=doc_scores, + doc_ents=doc_ents, + scores=scores, + ents=[c.entity for c in candidates], + ) + self._add_doc_activations( docs_scores=docs_scores, docs_ents=docs_ents, @@ -594,6 +574,7 @@ class EntityLinker(TrainablePipe): method="predict", msg="result variables not of equal length" ) raise RuntimeError(err) + return { KNOWLEDGE_BASE_IDS: final_kb_ids, "ents": docs_ents, diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py index cb1e4a733..82f620339 100644 --- a/spacy/tests/pipeline/test_entity_linker.py +++ b/spacy/tests/pipeline/test_entity_linker.py @@ -1,4 +1,4 @@ -from typing import Callable, Iterable, Dict, Any, cast +from typing import Callable, Iterable, Dict, Any, cast, Iterator import pytest from numpy.testing import assert_equal @@ -15,7 +15,7 @@ from spacy.pipeline import EntityLinker, TrainablePipe from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL from spacy.scorer import Scorer from spacy.tests.util import make_tempdir -from spacy.tokens import Span, Doc +from spacy.tokens import Doc, Span, SpanGroup from spacy.training import Example from spacy.util import ensure_path from spacy.vocab import Vocab @@ -462,16 +462,17 @@ def test_candidate_generation(nlp): mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9]) # test the size of the relevant candidates - assert len(get_candidates(mykb, douglas_ent)) == 2 - assert len(get_candidates(mykb, adam_ent)) == 1 - assert len(get_candidates(mykb, Adam_ent)) == 0 # default case sensitive - assert len(get_candidates(mykb, shrubbery_ent)) == 0 + adam_ent_cands = next(get_candidates(mykb, SpanGroup(doc=doc, spans=[adam_ent])))[0] + assert len(adam_ent_cands) == 1 + assert len(next(get_candidates(mykb, SpanGroup(doc=doc, spans=[douglas_ent])))[0]) 
== 2 + assert len(next(get_candidates(mykb, SpanGroup(doc=doc, spans=[Adam_ent])))[0]) == 0 # default case sensitive + assert len(next(get_candidates(mykb, SpanGroup(doc=doc, spans=[shrubbery_ent])))[0]) == 0 # test the content of the candidates - assert get_candidates(mykb, adam_ent)[0].entity_ == "Q2" - assert get_candidates(mykb, adam_ent)[0].alias_ == "adam" - assert_almost_equal(get_candidates(mykb, adam_ent)[0].entity_freq, 12) - assert_almost_equal(get_candidates(mykb, adam_ent)[0].prior_prob, 0.9) + assert adam_ent_cands[0].entity_ == "Q2" + assert adam_ent_cands[0].alias_ == "adam" + assert_almost_equal(adam_ent_cands[0].entity_freq, 12) + assert_almost_equal(adam_ent_cands[0].prior_prob, 0.9) def test_el_pipe_configuration(nlp): @@ -498,24 +499,16 @@ def test_el_pipe_configuration(nlp): assert doc[1].ent_kb_id_ == "" assert doc[2].ent_kb_id_ == "Q2" - def get_lowercased_candidates(kb, span): - return kb.get_alias_candidates(span.text.lower()) - - def get_lowercased_candidates_batch(kb, spans): - return [get_lowercased_candidates(kb, span) for span in spans] + def get_lowercased_candidates(kb: InMemoryLookupKB, mentions: Iterator[SpanGroup]): + for mentions_for_doc in mentions: + yield [kb.get_alias_candidates(ent_span.text.lower()) for ent_span in mentions_for_doc] @registry.misc("spacy.LowercaseCandidateGenerator.v1") def create_candidates() -> Callable[ - [InMemoryLookupKB, "Span"], Iterable[InMemoryCandidate] + [InMemoryLookupKB, Iterator[SpanGroup]], Iterator[Iterable[Iterable[InMemoryCandidate]]] ]: return get_lowercased_candidates - @registry.misc("spacy.LowercaseCandidateBatchGenerator.v1") - def create_candidates_batch() -> Callable[ - [InMemoryLookupKB, Iterable["Span"]], Iterable[Iterable[InMemoryCandidate]] - ]: - return get_lowercased_candidates_batch - # replace the pipe with a new one with with a different candidate generator entity_linker = nlp.replace_pipe( "entity_linker", @@ -523,9 +516,6 @@ def test_el_pipe_configuration(nlp): config={ "incl_context": False, "get_candidates": {"@misc": "spacy.LowercaseCandidateGenerator.v1"}, - "get_candidates_batch": { - "@misc": "spacy.LowercaseCandidateBatchGenerator.v1" - }, }, ) entity_linker.set_kb(create_kb) diff --git a/spacy/util.py b/spacy/util.py index 2ce2e5e0f..d815855d3 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -10,7 +10,7 @@ import re from pathlib import Path import thinc from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer -from thinc.api import ConfigValidationError, Model, constant as constant_schedule +from thinc.api import ConfigValidationError, Model import functools import itertools import numpy diff --git a/website/docs/api/entitylinker.mdx b/website/docs/api/entitylinker.mdx index 12b2f6bef..8948d3d29 100644 --- a/website/docs/api/entitylinker.mdx +++ b/website/docs/api/entitylinker.mdx @@ -62,7 +62,7 @@ architectures and their arguments and hyperparameters. | `model` | The [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. Defaults to [EntityLinker](/api/architectures#EntityLinker). ~~Model~~ | | `entity_vector_length` | Size of encoding vectors in the KB. Defaults to `64`. ~~int~~ | | `use_gold_ents` | Whether to copy entities from the gold docs or not. Defaults to `True`. If `False`, entities must be set in the training data or by an annotating component in the pipeline. ~~int~~ | -| `get_candidates` | Function that generates plausible candidates for a given `Span` object. 
Defaults to [CandidateGenerator](/api/architectures#CandidateGenerator), a function looking up exact, case-dependent aliases in the KB. ~~Callable[[KnowledgeBase, Span], Iterable[Candidate]]~~ | +| `get_candidates` | Function that retrieves plausible candidates per entity mention in a given `SpanGroup`. Defaults to [CandidateGenerator](/api/architectures#CandidateGenerator). ~~Callable[[KnowledgeBase, Iterator[SpanGroup]], Iterator[Iterable[Iterable[Candidate]]]]~~ | | `overwrite` 3.2 | Whether existing annotation is overwritten. Defaults to `False`. ~~bool~~ | | `scorer` 3.2 | The scoring method. Defaults to [`Scorer.score_links`](/api/scorer#score_links). ~~Optional[Callable]~~ | | `save_activations` 4.0 | Save activations in `Doc` when annotating. Saved activations are `"ents"` and `"scores"`. ~~Union[bool, list[str]]~~ | diff --git a/website/docs/api/inmemorylookupkb.mdx b/website/docs/api/inmemorylookupkb.mdx index 64ee8cc36..789c28293 100644 --- a/website/docs/api/inmemorylookupkb.mdx +++ b/website/docs/api/inmemorylookupkb.mdx @@ -155,35 +155,13 @@ Get a list of all aliases in the knowledge base. ## InMemoryLookupKB.get_candidates {id="get_candidates",tag="method"} -Given a certain textual mention as input, retrieve a list of candidate entities -of type [`InMemoryCandidate`](/api/kb#candidate). Wraps -[`get_alias_candidates()`](/api/inmemorylookupkb#get_alias_candidates). - -> #### Example -> -> ```python -> from spacy.lang.en import English -> nlp = English() -> doc = nlp("Douglas Adams wrote 'The Hitchhiker's Guide to the Galaxy'.") -> candidates = kb.get_candidates(doc[0:2]) -> ``` - -| Name | Description | -| ----------- | ------------------------------------------------------------------------------------ | -| `mention` | The textual mention or alias. ~~Span~~ | -| **RETURNS** | An iterable of relevant `InMemoryCandidate` objects. ~~Iterable[InMemoryCandidate]~~ | - -## InMemoryLookupKB.get_candidates_batch {id="get_candidates_batch",tag="method"} - -Same as [`get_candidates()`](/api/inmemorylookupkb#get_candidates), but for an -arbitrary number of mentions. The [`EntityLinker`](/api/entitylinker) component -will call `get_candidates_batch()` instead of `get_candidates()`, if the config -parameter `candidates_batch_size` is greater or equal than 1. - -The default implementation of `get_candidates_batch()` executes -`get_candidates()` in a loop. We recommend implementing a more efficient way to -retrieve candidates for multiple mentions at once, if performance is of concern -to you. +Given textual mentions for an arbitrary number of documents as input, retrieve a +list of candidate entities of type [`InMemoryCandidate`](/api/kb#candidate) for +each mention. The [`EntityLinker`](/api/entitylinker) component passes a +generator yielding all mentions to retreive candidates for as +[`SpanGroup`](/api/spangroup)) per document. The decision of how to batch +candidate retrieval lookups over multiple documents is left up to the +implementation of `KnowledgeBase.get_candidates()`. > #### Example > @@ -192,13 +170,13 @@ to you. 
> from spacy.tokens import SpanGroup > nlp = English() > doc = nlp("Douglas Adams wrote 'The Hitchhiker's Guide to the Galaxy'.") -> candidates = kb.get_candidates(SpanGroup(doc, spans=[doc[0:2], doc[3:]]) +> candidates = kb.get_candidates([SpanGroup(doc, spans=[doc[0:2], doc[3:]]]) > ``` -| Name | Description | -| ----------- | ------------------------------------------------------------------------------------------------------------ | -| `mentions` | The textual mention or alias. ~~Iterable[Span]~~ | -| **RETURNS** | An iterable of iterable with relevant `InMemoryCandidate` objects. ~~Iterable[Iterable[InMemoryCandidate]]~~ | +| Name | Description | +| ----------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `mentions` | The textual mention or alias. ~~Iterable[SpanGroup]~~ | +| **RETURNS** | An iterator over iterables of iterables with relevant [`InMemoryCandidate`](/api/kb#candidate) objects (per mention and doc). ~~Iterator[Iterable[Iterable[InMemoryCandidate]]]~~ | ## InMemoryLookupKB.get_alias_candidates {id="get_alias_candidates",tag="method"} diff --git a/website/docs/api/kb.mdx b/website/docs/api/kb.mdx index 12cca7d66..82e5979d9 100644 --- a/website/docs/api/kb.mdx +++ b/website/docs/api/kb.mdx @@ -60,34 +60,13 @@ The length of the fixed-size entity vectors in the knowledge base. ## KnowledgeBase.get_candidates {id="get_candidates",tag="method"} -Given a certain textual mention as input, retrieve a list of candidate entities -of type [`Candidate`](/api/kb#candidate). - -> #### Example -> -> ```python -> from spacy.lang.en import English -> nlp = English() -> doc = nlp("Douglas Adams wrote 'The Hitchhiker's Guide to the Galaxy'.") -> candidates = kb.get_candidates(doc[0:2]) -> ``` - -| Name | Description | -| ----------- | -------------------------------------------------------------------- | -| `mention` | The textual mention or alias. ~~Span~~ | -| **RETURNS** | An iterable of relevant `Candidate` objects. ~~Iterable[Candidate]~~ | - -## KnowledgeBase.get_candidates_batch {id="get_candidates_batch",tag="method"} - -Same as [`get_candidates()`](/api/kb#get_candidates), but for an arbitrary -number of mentions. The [`EntityLinker`](/api/entitylinker) component will call -`get_candidates_batch()` instead of `get_candidates()`, if the config parameter -`candidates_batch_size` is greater or equal than 1. - -The default implementation of `get_candidates_batch()` executes -`get_candidates()` in a loop. We recommend implementing a more efficient way to -retrieve candidates for multiple mentions at once, if performance is of concern -to you. +Given textual mentions for an arbitrary number of documents as input, retrieve a +list of candidate entities of type [`Candidate`](/api/kb#candidate) for each +mention. The [`EntityLinker`](/api/entitylinker) component passes a generator +yielding all mentions to retreive candidates for as +[`SpanGroup`](/api/spangroup)) per document. The decision of how to batch +candidate retrieval lookups over multiple documents is left up to the +implementation of `KnowledgeBase.get_candidates()`. > #### Example > @@ -96,30 +75,13 @@ to you. 
> from spacy.tokens import SpanGroup > nlp = English() > doc = nlp("Douglas Adams wrote 'The Hitchhiker's Guide to the Galaxy'.") -> candidates = kb.get_candidates(SpanGroup(doc, spans=[doc[0:2], doc[3:]]) +> candidates = kb.get_candidates([SpanGroup(doc, spans=[doc[0:2], doc[3:]]]) > ``` -| Name | Description | -| ----------- | -------------------------------------------------------------------------------------------- | -| `mentions` | The textual mention or alias. ~~SpanGroup~~ | -| **RETURNS** | An iterable of iterable with relevant `Candidate` objects. ~~Iterable[Iterable[Candidate]]~~ | - -## KnowledgeBase.get_alias_candidates {id="get_alias_candidates",tag="method"} - - - This method is _not_ available from spaCy 3.5 onwards. - - -From spaCy 3.5 on `KnowledgeBase` is an abstract class (with -[`InMemoryLookupKB`](/api/inmemorylookupkb) being a drop-in replacement) to -allow more flexibility in customizing knowledge bases. Some of its methods were -moved to [`InMemoryLookupKB`](/api/inmemorylookupkb) during this refactoring, -one of those being `get_alias_candidates()`. This method is now available as -[`InMemoryLookupKB.get_alias_candidates()`](/api/inmemorylookupkb#get_alias_candidates). -Note: -[`InMemoryLookupKB.get_candidates()`](/api/inmemorylookupkb#get_candidates) -defaults to -[`InMemoryLookupKB.get_alias_candidates()`](/api/inmemorylookupkb#get_alias_candidates). +| Name | Description | +| ----------- | -------------------------------------------------------------------------------------------------------------------------------------------- | +| `mentions` | The textual mention or alias. ~~Iterable[SpanGroup]~~ | +| **RETURNS** | An iterator over iterables of iterables with relevant `Candidate` objects (per mention and doc). ~~Iterator[Iterable[Iterable[Candidate]]]~~ | ## KnowledgeBase.get_vector {id="get_vector",tag="method"} @@ -193,11 +155,11 @@ Restore the state of the knowledge base from a given directory. Note that the ## InMemoryCandidate {id="candidate",tag="class"} -A `InMemoryCandidate` object refers to a textual mention (alias) that may or may -not be resolved to a specific entity from a `KnowledgeBase`. This will be used -as input for the entity linking algorithm which will disambiguate the various -candidates to the correct one. Each candidate `(alias, entity)` pair is assigned -to a certain prior probability. +A `InMemoryCandidate` object refers to a textual mention that may or may not be +resolved to a specific entity from a `KnowledgeBase`. This will be used as input +for the entity linking algorithm which will disambiguate the various candidates +to the correct one. Each candidate `(mention, entity)` pair is assigned to a +certain prior probability. ### InMemoryCandidate.\_\_init\_\_ {id="candidate-init",tag="method"}
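
For reference, a minimal sketch of calling the batched `get_candidates()` API described in the documentation changes above; it assumes `kb` is an `InMemoryLookupKB` that has already been populated with entities and aliases, and is not part of this patch.

```python
from spacy.lang.en import English
from spacy.tokens import SpanGroup

nlp = English()
doc = nlp("Douglas Adams wrote 'The Hitchhiker's Guide to the Galaxy'.")

# get_candidates() takes an iterable of SpanGroups, one per doc, and yields
# one collection of candidate lists per doc (one inner list per mention).
mentions = [SpanGroup(doc, spans=[doc[0:2], doc[3:]])]
for candidates_for_doc in kb.get_candidates(mentions):
    for candidates_for_mention in candidates_for_doc:
        print([c.entity_ for c in candidates_for_mention])
```
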