Add empty_kb() as config argument.

This commit is contained in:
Raphael Mitsch 2022-11-28 10:46:02 +01:00
parent b1d458eca7
commit 7e6888dcd4
3 changed files with 28 additions and 14 deletions

View File

@ -18,6 +18,8 @@ cdef class KnowledgeBase:
DOCS: https://spacy.io/api/kb DOCS: https://spacy.io/api/kb
""" """
_KBType = TypeVar("_KBType", bound=KnowledgeBase)
def __init__(self, vocab: Vocab, entity_vector_length: int): def __init__(self, vocab: Vocab, entity_vector_length: int):
"""Create a KnowledgeBase.""" """Create a KnowledgeBase."""
# Make sure abstract KB is not instantiated. # Make sure abstract KB is not instantiated.
@ -107,16 +109,15 @@ cdef class KnowledgeBase:
Errors.E1044.format(parent="KnowledgeBase", method="from_disk", name=self.__name__) Errors.E1044.format(parent="KnowledgeBase", method="from_disk", name=self.__name__)
) )
KBType = TypeVar("KBType", bound=KnowledgeBase)
@classmethod @classmethod
def generate_from_disk( def generate_from_disk(
cls: Type[KBType], path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList() cls: Type[_KBType], path: Union[str, Path], exclude: Iterable[str] = SimpleFrozenList()
) -> KBType: ) -> _KBType:
""" """
Factory method for generating KnowledgeBase instance from file. Factory method for generating KnowledgeBase subclass instance from file.
path (Union[str, Path]): Target file path. path (Union[str, Path]): Target file path.
exclude (Iterable[str]): List of components to exclude. exclude (Iterable[str]): List of components to exclude.
return (KBType): Instance of KnowledgeBase generated from file. return (_KBType): Instance of KnowledgeBase subclass generated from file.
""" """
raise NotImplementedError( raise NotImplementedError(
Errors.E1044.format(parent="KnowledgeBase", method="generate_from_disk", name=cls.__name__) Errors.E1044.format(parent="KnowledgeBase", method="generate_from_disk", name=cls.__name__)

View File

@ -99,6 +99,14 @@ def empty_kb(
return empty_kb_factory return empty_kb_factory
@registry.misc("spacy.EmptyKB.v2")
def empty_kb_for_config() -> Callable[[Vocab, int], KnowledgeBase]:
    """Return a factory that builds an empty InMemoryLookupKB.

    Unlike a factory that is pre-bound to a vector length, the returned
    callable receives both the vocab and the entity vector length at call
    time.

    RETURNS (Callable[[Vocab, int], KnowledgeBase]): Factory taking
        (vocab, entity_vector_length) and returning a new InMemoryLookupKB.
    """

    def _make_empty_kb(vocab: Vocab, entity_vector_length: int) -> KnowledgeBase:
        """Instantiate an InMemoryLookupKB from the given vocab and vector length."""
        kb = InMemoryLookupKB(
            vocab=vocab,
            entity_vector_length=entity_vector_length,
        )
        return kb

    return _make_empty_kb
@registry.misc("spacy.CandidateGenerator.v1") @registry.misc("spacy.CandidateGenerator.v1")
def create_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]: def create_candidates() -> Callable[[KnowledgeBase, Span], Iterable[Candidate]]:
return get_candidates return get_candidates

View File

@ -18,7 +18,6 @@ from thinc.api import CosineDistance, Model, Optimizer, Config
from thinc.api import set_dropout_rate from thinc.api import set_dropout_rate
from ..kb import KnowledgeBase, Candidate from ..kb import KnowledgeBase, Candidate
from ..ml import empty_kb
from ..tokens import Doc, Span from ..tokens import Doc, Span
from .pipe import deserialize_config from .pipe import deserialize_config
from .legacy.entity_linker import EntityLinker_v1 from .legacy.entity_linker import EntityLinker_v1
@ -64,6 +63,7 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
"entity_vector_length": 64, "entity_vector_length": 64,
"get_candidates": {"@misc": "spacy.CandidateGenerator.v1"}, "get_candidates": {"@misc": "spacy.CandidateGenerator.v1"},
"get_candidates_all": {"@misc": "spacy.CandidateAllGenerator.v1"}, "get_candidates_all": {"@misc": "spacy.CandidateAllGenerator.v1"},
"generate_empty_kb": {"@misc": "spacy.EmptyKB.v2"},
"overwrite": True, "overwrite": True,
"scorer": {"@scorers": "spacy.entity_linker_scorer.v1"}, "scorer": {"@scorers": "spacy.entity_linker_scorer.v1"},
"use_gold_ents": True, "use_gold_ents": True,
@ -91,6 +91,7 @@ def make_entity_linker(
[KnowledgeBase, Generator[Iterable[Span], None, None]], [KnowledgeBase, Generator[Iterable[Span], None, None]],
Iterator[Iterable[Iterable[Candidate]]], Iterator[Iterable[Iterable[Candidate]]],
], ],
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
overwrite: bool, overwrite: bool,
scorer: Optional[Callable], scorer: Optional[Callable],
use_gold_ents: bool, use_gold_ents: bool,
@ -115,6 +116,7 @@ def make_entity_linker(
Iterator[Iterable[Iterable[Candidate]]] Iterator[Iterable[Iterable[Candidate]]]
]): Function that produces a list of candidates per document, given a certain knowledge base and several textual ]): Function that produces a list of candidates per document, given a certain knowledge base and several textual
documents with textual mentions. documents with textual mentions.
generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
scorer (Optional[Callable]): The scoring method. scorer (Optional[Callable]): The scoring method.
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
component must provide entity annotations. component must provide entity annotations.
@ -151,6 +153,7 @@ def make_entity_linker(
entity_vector_length=entity_vector_length, entity_vector_length=entity_vector_length,
get_candidates=get_candidates, get_candidates=get_candidates,
get_candidates_all=get_candidates_all, get_candidates_all=get_candidates_all,
generate_empty_kb=generate_empty_kb,
overwrite=overwrite, overwrite=overwrite,
scorer=scorer, scorer=scorer,
use_gold_ents=use_gold_ents, use_gold_ents=use_gold_ents,
@ -192,6 +195,7 @@ class EntityLinker(TrainablePipe):
[KnowledgeBase, Generator[Iterable[Span], None, None]], [KnowledgeBase, Generator[Iterable[Span], None, None]],
Iterator[Iterable[Iterable[Candidate]]], Iterator[Iterable[Iterable[Candidate]]],
], ],
generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
overwrite: bool = BACKWARD_OVERWRITE, overwrite: bool = BACKWARD_OVERWRITE,
scorer: Optional[Callable] = entity_linker_score, scorer: Optional[Callable] = entity_linker_score,
use_gold_ents: bool, use_gold_ents: bool,
@ -215,14 +219,15 @@ class EntityLinker(TrainablePipe):
Callable[ Callable[
[KnowledgeBase, Generator[Iterable[Span], None, None]], [KnowledgeBase, Generator[Iterable[Span], None, None]],
Iterator[Iterable[Iterable[Candidate]]] Iterator[Iterable[Iterable[Candidate]]]
]): Function that produces a list of candidates per document, given a certain knowledge base and several textual ]): Function that produces a list of candidates per document, given a certain knowledge base and several
documents with textual mentions. textual documents with textual mentions.
generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links. scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links.
use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
component must provide entity annotations. component must provide entity annotations.
candidates_doc_mode (bool): Whether or not to operate candidate generation in doc mode, i.e. to provide a generator candidates_doc_mode (bool): Whether or not to operate candidate generation in doc mode, i.e. to provide a
yielding entities per document (candidate generator callable is called only once in this case). If False, generator yielding entities per document (candidate generator callable is called only once in this case). If
the candidate generator is called once per entity. False, the candidate generator is called once per entity.
threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the
threshold, prediction is discarded. If None, predictions are not filtered by any threshold. threshold, prediction is discarded. If None, predictions are not filtered by any threshold.
DOCS: https://spacy.io/api/entitylinker#init DOCS: https://spacy.io/api/entitylinker#init
@ -241,16 +246,16 @@ class EntityLinker(TrainablePipe):
self.model = model self.model = model
self.name = name self.name = name
self.labels_discard = list(labels_discard) self.labels_discard = list(labels_discard)
# how many neighbour sentences to take into account
self.n_sents = n_sents self.n_sents = n_sents
self.incl_prior = incl_prior self.incl_prior = incl_prior
self.incl_context = incl_context self.incl_context = incl_context
self.get_candidates = get_candidates self.get_candidates = get_candidates
self.get_candidates_all = get_candidates_all self.get_candidates_all = get_candidates_all
self.generate_empty_kb = generate_empty_kb
self.cfg: Dict[str, Any] = {"overwrite": overwrite} self.cfg: Dict[str, Any] = {"overwrite": overwrite}
self.distance = CosineDistance(normalize=False) self.distance = CosineDistance(normalize=False)
# how many neighbour sentences to take into account self.kb = generate_empty_kb(self.vocab, entity_vector_length)
# create an empty KB by default
self.kb = empty_kb(entity_vector_length)(self.vocab)
self.scorer = scorer self.scorer = scorer
self.use_gold_ents = use_gold_ents self.use_gold_ents = use_gold_ents
self.candidates_doc_mode = candidates_doc_mode self.candidates_doc_mode = candidates_doc_mode