Add empty_kb() as config argument.

This commit is contained in:
Raphael Mitsch 2022-11-28 10:46:02 +01:00
parent b1d458eca7
commit 7e6888dcd4
3 changed files with 28 additions and 14 deletions

View File

@ -18,6 +18,8 @@ cdef class KnowledgeBase:
DOCS: https://spacy.io/api/kb
"""
_KBType = TypeVar("_KBType", bound=KnowledgeBase)
def __init__(self, vocab: Vocab, entity_vector_length: int):
"""Create a KnowledgeBase."""
# Make sure abstract KB is not instantiated.
@ -107,16 +109,15 @@ cdef class KnowledgeBase:
Errors.E1044.format(parent="KnowledgeBase", method="from_disk", name=self.__name__)
)
KBType = TypeVar("KBType", bound=KnowledgeBase)
@classmethod
def generate_from_disk(
cls: Type[KBType], path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList()
) -> KBType:
cls: Type[_KBType], path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList()
) -> _KBType:
"""
Factory method for generating KnowledgeBase instance from file.
Factory method for generating KnowledgeBase subclass instance from file.
path (Union[str, Path]): Target file path.
exclude (Iterable[str]): List of components to exclude.
return (KBType): Instance of KnowledgeBase generated from file.
return (_KBType): Instance of KnowledgeBase subclass generated from file.
"""
raise NotImplementedError(
Errors.E1044.format(parent="KnowledgeBase", method="generate_from_disk", name=cls.__name__)

View File

@ -99,6 +99,14 @@ def empty_kb(
return empty_kb_factory
@registry.misc("spacy.EmptyKB.v2")
def empty_kb_for_config() -> Callable[[Vocab, int], KnowledgeBase]:
    """Provide a knowledge base factory that can be referenced from a config.

    Unlike a factory that needs the entity vector length up front, this
    function takes no arguments; the vocab and the vector length are
    supplied later, when the returned callable is invoked.

    RETURNS (Callable[[Vocab, int], KnowledgeBase]): Callable that builds an
        InMemoryLookupKB from a vocab and an entity vector length.
    """

    def empty_kb_factory(vocab: Vocab, entity_vector_length: int):
        # Construct a new in-memory KB from exactly the two values handed in.
        kb = InMemoryLookupKB(vocab=vocab, entity_vector_length=entity_vector_length)
        return kb

    return empty_kb_factory
@registry.misc("spacy.CandidateGenerator.v1")
def create_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]:
    """Expose the module's candidate generator under a registry name.

    RETURNS (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): The
        `get_candidates` function, returned as-is (not wrapped or copied).
    """
    candidate_generator = get_candidates
    return candidate_generator

View File

@ -18,7 +18,6 @@ from thinc.api import CosineDistance, Model, Optimizer, Config
from thinc.api import set_dropout_rate
from ..kb import KnowledgeBase, Candidate
from ..ml import empty_kb
from ..tokens import Doc, Span
from .pipe import deserialize_config
from .legacy.entity_linker import EntityLinker_v1
@ -64,6 +63,7 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
"entity_vector_length": 64,
"get_candidates": {"@misc": "spacy.CandidateGenerator.v1"},
"get_candidates_all": {"@misc": "spacy.CandidateAllGenerator.v1"},
"generate_empty_kb": {"@misc": "spacy.EmptyKB.v2"},
"overwrite": True,
"scorer": {"@scorers": "spacy.entity_linker_scorer.v1"},
"use_gold_ents": True,
@ -91,6 +91,7 @@ def make_entity_linker(
[KnowledgeBase, Generator[Iterable[Span], None, None]],
Iterator[Iterable[Iterable[Candidate]]],
],
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
overwrite: bool,
scorer: Optional[Callable],
use_gold_ents: bool,
@ -115,6 +116,7 @@ def make_entity_linker(
Iterator[Iterable[Iterable[Candidate]]]
]): Function that produces a list of candidates per document, given a certain knowledge base and several textual
documents with textual mentions.
generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
scorer (Optional[Callable]): The scoring method.
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
component must provide entity annotations.
@ -151,6 +153,7 @@ def make_entity_linker(
entity_vector_length=entity_vector_length,
get_candidates=get_candidates,
get_candidates_all=get_candidates_all,
generate_empty_kb=generate_empty_kb,
overwrite=overwrite,
scorer=scorer,
use_gold_ents=use_gold_ents,
@ -192,6 +195,7 @@ class EntityLinker(TrainablePipe):
[KnowledgeBase, Generator[Iterable[Span], None, None]],
Iterator[Iterable[Iterable[Candidate]]],
],
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
overwrite: bool = BACKWARD_OVERWRITE,
scorer: Optional[Callable] = entity_linker_score,
use_gold_ents: bool,
@ -215,14 +219,15 @@ class EntityLinker(TrainablePipe):
Callable[
[KnowledgeBase, Generator[Iterable[Span], None, None]],
Iterator[Iterable[Iterable[Candidate]]]
]): Function that produces a list of candidates per document, given a certain knowledge base and several textual
documents with textual mentions.
]): Function that produces a list of candidates per document, given a certain knowledge base and several
textual documents with textual mentions.
generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links.
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
component must provide entity annotations.
candidates_doc_mode (bool): Whether or not to operate candidate generation in doc mode, i.e. to provide a generator
yielding entities per document (candidate generator callable is called only once in this case). If False,
the candidate generator is called once per entity.
candidates_doc_mode (bool): Whether or not to operate candidate generation in doc mode, i.e. to provide a
generator yielding entities per document (candidate generator callable is called only once in this case). If
False, the candidate generator is called once per entity.
threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the
threshold, prediction is discarded. If None, predictions are not filtered by any threshold.
DOCS: https://spacy.io/api/entitylinker#init
@ -241,16 +246,16 @@ class EntityLinker(TrainablePipe):
self.model = model
self.name = name
self.labels_discard = list(labels_discard)
# how many neighbour sentences to take into account
self.n_sents = n_sents
self.incl_prior = incl_prior
self.incl_context = incl_context
self.get_candidates = get_candidates
self.get_candidates_all = get_candidates_all
self.generate_empty_kb = generate_empty_kb
self.cfg: Dict[str, Any] = {"overwrite": overwrite}
self.distance = CosineDistance(normalize=False)
# how many neighbour sentences to take into account
# create an empty KB by default
self.kb = empty_kb(entity_vector_length)(self.vocab)
self.kb = generate_empty_kb(self.vocab, entity_vector_length)
self.scorer = scorer
self.use_gold_ents = use_gold_ents
self.candidates_doc_mode = candidates_doc_mode