diff --git a/spacy/cli/templates/quickstart_training.jinja b/spacy/cli/templates/quickstart_training.jinja index 2817147f3..65b0b3358 100644 --- a/spacy/cli/templates/quickstart_training.jinja +++ b/spacy/cli/templates/quickstart_training.jinja @@ -238,7 +238,7 @@ grad_factor = 1.0 {% if "entity_linker" in components -%} [components.entity_linker] factory = "entity_linker" -get_candidates = {"@misc":"spacy.CandidateGenerator.v1"} +get_candidates = {"@misc":"spacy.CandidateGenerator.v2"} incl_context = true incl_prior = true @@ -517,7 +517,7 @@ width = ${components.tok2vec.model.encode.width} {% if "entity_linker" in components -%} [components.entity_linker] factory = "entity_linker" -get_candidates = {"@misc":"spacy.CandidateGenerator.v1"} +get_candidates = {"@misc":"spacy.CandidateGenerator.v2"} incl_context = true incl_prior = true diff --git a/spacy/ml/models/entity_linker.py b/spacy/ml/models/entity_linker.py index f06876aa4..dc2009059 100644 --- a/spacy/ml/models/entity_linker.py +++ b/spacy/ml/models/entity_linker.py @@ -120,13 +120,28 @@ def empty_kb( @registry.misc("spacy.CandidateGenerator.v1") -def create_get_candidates() -> Callable[ - [KnowledgeBase, Iterator[SpanGroup]], Iterator[CandidatesForDocT] -]: +def create_get_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]: return get_candidates -def get_candidates( +@registry.misc("spacy.CandidateGenerator.v2") +def create_get_candidates_v2() -> Callable[ + [KnowledgeBase, Iterator[SpanGroup]], Iterator[CandidatesForDocT] +]: + return get_candidates_v2 + + +def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]: + """ + Return candidate entities for the given mention from the KB. + kb (KnowledgeBase): Knowledge base to query. + mention (Span): Entity mention. + RETURNS (Iterable[Candidate]): Identified candidates for specified mention. + """ + return next(next(get_candidates_v2(kb, iter([SpanGroup(mention.doc, spans=[mention])])))[0]) + + +def get_candidates_v2( kb: KnowledgeBase, mentions: Iterator[SpanGroup] ) -> Iterator[Iterable[Iterable[Candidate]]]: """ diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index 07534c523..ecad6484e 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -65,7 +65,7 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"] "incl_prior": True, "incl_context": True, "entity_vector_length": 64, - "get_candidates": {"@misc": "spacy.CandidateGenerator.v1"}, + "get_candidates": {"@misc": "spacy.CandidateGenerator.v2"}, "overwrite": False, "generate_empty_kb": {"@misc": "spacy.EmptyKB.v2"}, "scorer": {"@scorers": "spacy.entity_linker_scorer.v1"}, diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py index 571887970..955bae922 100644 --- a/spacy/tests/pipeline/test_entity_linker.py +++ b/spacy/tests/pipeline/test_entity_linker.py @@ -10,7 +10,7 @@ from spacy.compat import pickle from spacy.kb import Candidate, InMemoryLookupKB, KnowledgeBase from spacy.lang.en import English from spacy.ml import load_kb -from spacy.ml.models.entity_linker import build_span_maker, get_candidates +from spacy.ml.models.entity_linker import build_span_maker, get_candidates_v2 from spacy.pipeline import EntityLinker, TrainablePipe from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL from spacy.scorer import Scorer @@ -453,16 +453,16 @@ def test_candidate_generation(nlp): mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9]) # test the size of the relevant candidates - adam_ent_cands = next(get_candidates(mykb, SpanGroup(doc=doc, spans=[adam_ent])))[0] + adam_ent_cands = next(get_candidates_v2(mykb, SpanGroup(doc=doc, spans=[adam_ent])))[0] assert len(adam_ent_cands) == 1 assert ( - len(next(get_candidates(mykb, SpanGroup(doc=doc, spans=[douglas_ent])))[0]) == 2 + len(next(get_candidates_v2(mykb, SpanGroup(doc=doc, spans=[douglas_ent])))[0]) == 2 ) assert ( - len(next(get_candidates(mykb, SpanGroup(doc=doc, spans=[Adam_ent])))[0]) == 0 + len(next(get_candidates_v2(mykb, SpanGroup(doc=doc, spans=[Adam_ent])))[0]) == 0 ) # default case sensitive assert ( - len(next(get_candidates(mykb, SpanGroup(doc=doc, spans=[shrubbery_ent])))[0]) + len(next(get_candidates_v2(mykb, SpanGroup(doc=doc, spans=[shrubbery_ent])))[0]) == 0 ) diff --git a/website/docs/api/architectures.mdx b/website/docs/api/architectures.mdx index 956234ac0..f2cc61fb0 100644 --- a/website/docs/api/architectures.mdx +++ b/website/docs/api/architectures.mdx @@ -1255,6 +1255,15 @@ A function that reads an existing `KnowledgeBase` from file. | --------- | -------------------------------------------------------- | | `kb_path` | The location of the KB that was stored to file. ~~Path~~ | +### spacy.CandidateGenerator.v2 {id="CandidateGenerator-v2"} + +A function that takes as input a [`KnowledgeBase`](/api/kb) and a +`Iterator[SpanGroup]` object denoting a collection of named entities for +multiple [`Doc`](/api/doc), and returns an iterable of plausible +[`Candidate`](/api/kb/#candidate) objects per `Doc`. The default +`CandidateGenerator` uses the text of a mention to find its potential aliases in +the `KnowledgeBase`. Note that this function is case-dependent. + ### spacy.CandidateGenerator.v1 {id="CandidateGenerator"} A function that takes as input a [`KnowledgeBase`](/api/kb) and a diff --git a/website/docs/api/entitylinker.mdx b/website/docs/api/entitylinker.mdx index 2ae6e1f8c..0496fe592 100644 --- a/website/docs/api/entitylinker.mdx +++ b/website/docs/api/entitylinker.mdx @@ -47,7 +47,7 @@ architectures and their arguments and hyperparameters. > "incl_context": True, > "model": DEFAULT_NEL_MODEL, > "entity_vector_length": 64, -> "get_candidates": {'@misc': 'spacy.CandidateGenerator.v1'}, +> "get_candidates": {'@misc': 'spacy.CandidateGenerator.v2'}, > "threshold": None, > } > nlp.add_pipe("entity_linker", config=config)