Update factory functions for backwards compatibility.

This commit is contained in:
Raphael Mitsch 2024-02-19 10:19:01 +01:00
parent 5d1ecf1aa1
commit 2951c19472
6 changed files with 37 additions and 13 deletions

View File

@ -238,7 +238,7 @@ grad_factor = 1.0
{% if "entity_linker" in components -%}
[components.entity_linker]
factory = "entity_linker"
get_candidates = {"@misc":"spacy.CandidateGenerator.v1"}
get_candidates = {"@misc":"spacy.CandidateGenerator.v2"}
incl_context = true
incl_prior = true
@ -517,7 +517,7 @@ width = ${components.tok2vec.model.encode.width}
{% if "entity_linker" in components -%}
[components.entity_linker]
factory = "entity_linker"
get_candidates = {"@misc":"spacy.CandidateGenerator.v1"}
get_candidates = {"@misc":"spacy.CandidateGenerator.v2"}
incl_context = true
incl_prior = true

View File

@ -120,13 +120,28 @@ def empty_kb(
@registry.misc("spacy.CandidateGenerator.v1")
def create_get_candidates() -> Callable[
[KnowledgeBase, Iterator[SpanGroup]], Iterator[CandidatesForDocT]
]:
def create_get_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]:
return get_candidates
def get_candidates(
@registry.misc("spacy.CandidateGenerator.v2")
def create_get_candidates_v2() -> Callable[
[KnowledgeBase, Iterator[SpanGroup]], Iterator[CandidatesForDocT]
]:
return get_candidates_v2
def get_candidates(kb: KnowledgeBase, mention: Span) -> Iterable[Candidate]:
"""
Return candidate entities for the given mention from the KB.
kb (KnowledgeBase): Knowledge base to query.
mention (Span): Entity mention.
RETURNS (Iterable[Candidate]): Identified candidates for specified mention.
"""
return next(next(get_candidates_v2(kb, iter([SpanGroup(mention.doc, spans=[mention])])))[0])
def get_candidates_v2(
kb: KnowledgeBase, mentions: Iterator[SpanGroup]
) -> Iterator[Iterable[Iterable[Candidate]]]:
"""

View File

@ -65,7 +65,7 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
"incl_prior": True,
"incl_context": True,
"entity_vector_length": 64,
"get_candidates": {"@misc": "spacy.CandidateGenerator.v1"},
"get_candidates": {"@misc": "spacy.CandidateGenerator.v2"},
"overwrite": False,
"generate_empty_kb": {"@misc": "spacy.EmptyKB.v2"},
"scorer": {"@scorers": "spacy.entity_linker_scorer.v1"},

View File

@ -10,7 +10,7 @@ from spacy.compat import pickle
from spacy.kb import Candidate, InMemoryLookupKB, KnowledgeBase
from spacy.lang.en import English
from spacy.ml import load_kb
from spacy.ml.models.entity_linker import build_span_maker, get_candidates
from spacy.ml.models.entity_linker import build_span_maker, get_candidates_v2
from spacy.pipeline import EntityLinker, TrainablePipe
from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
from spacy.scorer import Scorer
@ -453,16 +453,16 @@ def test_candidate_generation(nlp):
mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
# test the size of the relevant candidates
adam_ent_cands = next(get_candidates(mykb, SpanGroup(doc=doc, spans=[adam_ent])))[0]
adam_ent_cands = next(get_candidates_v2(mykb, SpanGroup(doc=doc, spans=[adam_ent])))[0]
assert len(adam_ent_cands) == 1
assert (
len(next(get_candidates(mykb, SpanGroup(doc=doc, spans=[douglas_ent])))[0]) == 2
len(next(get_candidates_v2(mykb, SpanGroup(doc=doc, spans=[douglas_ent])))[0]) == 2
)
assert (
len(next(get_candidates(mykb, SpanGroup(doc=doc, spans=[Adam_ent])))[0]) == 0
len(next(get_candidates_v2(mykb, SpanGroup(doc=doc, spans=[Adam_ent])))[0]) == 0
) # default case sensitive
assert (
len(next(get_candidates(mykb, SpanGroup(doc=doc, spans=[shrubbery_ent])))[0])
len(next(get_candidates_v2(mykb, SpanGroup(doc=doc, spans=[shrubbery_ent])))[0])
== 0
)

View File

@ -1255,6 +1255,15 @@ A function that reads an existing `KnowledgeBase` from file.
| --------- | -------------------------------------------------------- |
| `kb_path` | The location of the KB that was stored to file. ~~Path~~ |
### spacy.CandidateGenerator.v2 {id="CandidateGenerator-v2"}
A function that takes as input a [`KnowledgeBase`](/api/kb) and a
`Iterator[SpanGroup]` object denoting a collection of named entities for
multiple [`Doc`](/api/doc), and returns an iterable of plausible
[`Candidate`](/api/kb/#candidate) objects per `Doc`. The default
`CandidateGenerator` uses the text of a mention to find its potential aliases in
the `KnowledgeBase`. Note that this function is case-dependent.
### spacy.CandidateGenerator.v1 {id="CandidateGenerator"}
A function that takes as input a [`KnowledgeBase`](/api/kb) and a

View File

@ -47,7 +47,7 @@ architectures and their arguments and hyperparameters.
> "incl_context": True,
> "model": DEFAULT_NEL_MODEL,
> "entity_vector_length": 64,
> "get_candidates": {'@misc': 'spacy.CandidateGenerator.v1'},
> "get_candidates": {'@misc': 'spacy.CandidateGenerator.v2'},
> "threshold": None,
> }
> nlp.add_pipe("entity_linker", config=config)