move kb_loader to initialize for NEL instead of constructor

This commit is contained in:
svlandeg 2020-10-07 14:56:00 +02:00
parent bcaad28eda
commit 33c2d4af16

View File

@ -8,6 +8,7 @@ from thinc.api import set_dropout_rate
import warnings
from ..kb import KnowledgeBase, Candidate
from ..ml import empty_kb
from ..tokens import Doc
from .pipe import Pipe, deserialize_config
from ..language import Language
@ -41,11 +42,11 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
requires=["doc.ents", "doc.sents", "token.ent_iob", "token.ent_type"],
assigns=["token.ent_kb_id"],
default_config={
"kb_loader": {"@misc": "spacy.EmptyKB.v1", "entity_vector_length": 64},
"model": DEFAULT_NEL_MODEL,
"labels_discard": [],
"incl_prior": True,
"incl_context": True,
"entity_vector_length": 64,
"get_candidates": {"@misc": "spacy.CandidateGenerator.v1"},
},
default_score_weights={
@ -58,11 +59,11 @@ def make_entity_linker(
nlp: Language,
name: str,
model: Model,
kb_loader: Callable[[Vocab], KnowledgeBase],
*,
labels_discard: Iterable[str],
incl_prior: bool,
incl_context: bool,
entity_vector_length: int,
get_candidates: Callable[[KnowledgeBase, "Span"], Iterable[Candidate]],
):
"""Construct an EntityLinker component.
@ -70,19 +71,21 @@ def make_entity_linker(
model (Model[List[Doc], Floats2d]): A model that learns document vector
representations. Given a batch of Doc objects, it should return a single
array, with one row per item in the batch.
kb (KnowledgeBase): The knowledge-base to link entities to.
labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction.
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
incl_context (bool): Whether or not to include the local context in the model.
entity_vector_length (int): Size of encoding vectors in the KB.
get_candidates (Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]): Function that
produces a list of candidates, given a certain knowledge base and a textual mention.
"""
return EntityLinker(
nlp.vocab,
model,
name,
kb_loader=kb_loader,
labels_discard=labels_discard,
incl_prior=incl_prior,
incl_context=incl_context,
entity_vector_length=entity_vector_length,
get_candidates=get_candidates,
)
@ -101,10 +104,10 @@ class EntityLinker(Pipe):
model: Model,
name: str = "entity_linker",
*,
kb_loader: Callable[[Vocab], KnowledgeBase],
labels_discard: Iterable[str],
incl_prior: bool,
incl_context: bool,
entity_vector_length: int,
get_candidates: Callable[[KnowledgeBase, "Span"], Iterable[Candidate]],
) -> None:
"""Initialize an entity linker.
@ -113,10 +116,12 @@ class EntityLinker(Pipe):
model (thinc.api.Model): The Thinc Model powering the pipeline component.
name (str): The component instance name, used to add entries to the
losses during training.
kb_loader (Callable[[Vocab], KnowledgeBase]): A function that creates a KnowledgeBase from a Vocab instance.
labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction.
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
incl_context (bool): Whether or not to include the local context in the model.
entity_vector_length (int): Size of encoding vectors in the KB.
get_candidates (Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]): Function that
produces a list of candidates, given a certain knowledge base and a textual mention.
DOCS: https://nightly.spacy.io/api/entitylinker#init
"""
@ -127,15 +132,17 @@ class EntityLinker(Pipe):
"labels_discard": list(labels_discard),
"incl_prior": incl_prior,
"incl_context": incl_context,
"entity_vector_length": entity_vector_length,
}
self.kb = kb_loader(self.vocab)
self.get_candidates = get_candidates
self.cfg = dict(cfg)
self.distance = CosineDistance(normalize=False)
# how many neightbour sentences to take into account
self.n_sents = cfg.get("n_sents", 0)
# create an empty KB by default. If you want to load a predefined one, specify it in 'initialize'.
self.kb = empty_kb(entity_vector_length)(self.vocab)
def _require_kb(self) -> None:
def validate_kb(self) -> None:
# Raise an error if the knowledge base is not initialized.
if len(self.kb) == 0:
raise ValueError(Errors.E139.format(name=self.name))
@ -145,6 +152,7 @@ class EntityLinker(Pipe):
get_examples: Callable[[], Iterable[Example]],
*,
nlp: Optional[Language] = None,
kb_loader: Callable[[Vocab], KnowledgeBase] = None,
):
"""Initialize the pipe for training, using a representative set
of data examples.
@ -152,11 +160,17 @@ class EntityLinker(Pipe):
get_examples (Callable[[], Iterable[Example]]): Function that
returns a representative sample of gold-standard Example objects.
nlp (Language): The current nlp object the component is part of.
kb_loader (Callable[[Vocab], KnowledgeBase]): A function that creates a KnowledgeBase from a Vocab instance.
Note that providing this argument, will overwrite all data accumulated in the current KB.
Use this only when loading a KB as-such from file.
DOCS: https://nightly.spacy.io/api/entitylinker#initialize
"""
self._ensure_examples(get_examples)
self._require_kb()
if kb_loader is not None:
self.kb = kb_loader(self.vocab)
self.cfg["entity_vector_length"] = self.kb.entity_vector_length
self.validate_kb()
nO = self.kb.entity_vector_length
doc_sample = []
vector_sample = []
@ -192,7 +206,7 @@ class EntityLinker(Pipe):
DOCS: https://nightly.spacy.io/api/entitylinker#update
"""
self._require_kb()
self.validate_kb()
if losses is None:
losses = {}
losses.setdefault(self.name, 0.0)
@ -303,7 +317,7 @@ class EntityLinker(Pipe):
DOCS: https://nightly.spacy.io/api/entitylinker#predict
"""
self._require_kb()
self.validate_kb()
entity_count = 0
final_kb_ids = []
if not docs: