move kb_loader to initialize for NEL instead of constructor

This commit is contained in:
svlandeg 2020-10-07 14:56:00 +02:00
parent bcaad28eda
commit 33c2d4af16

View File

@ -8,6 +8,7 @@ from thinc.api import set_dropout_rate
import warnings import warnings
from ..kb import KnowledgeBase, Candidate from ..kb import KnowledgeBase, Candidate
from ..ml import empty_kb
from ..tokens import Doc from ..tokens import Doc
from .pipe import Pipe, deserialize_config from .pipe import Pipe, deserialize_config
from ..language import Language from ..language import Language
@ -41,11 +42,11 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
requires=["doc.ents", "doc.sents", "token.ent_iob", "token.ent_type"], requires=["doc.ents", "doc.sents", "token.ent_iob", "token.ent_type"],
assigns=["token.ent_kb_id"], assigns=["token.ent_kb_id"],
default_config={ default_config={
"kb_loader": {"@misc": "spacy.EmptyKB.v1", "entity_vector_length": 64},
"model": DEFAULT_NEL_MODEL, "model": DEFAULT_NEL_MODEL,
"labels_discard": [], "labels_discard": [],
"incl_prior": True, "incl_prior": True,
"incl_context": True, "incl_context": True,
"entity_vector_length": 64,
"get_candidates": {"@misc": "spacy.CandidateGenerator.v1"}, "get_candidates": {"@misc": "spacy.CandidateGenerator.v1"},
}, },
default_score_weights={ default_score_weights={
@ -58,11 +59,11 @@ def make_entity_linker(
nlp: Language, nlp: Language,
name: str, name: str,
model: Model, model: Model,
kb_loader: Callable[[Vocab], KnowledgeBase],
*, *,
labels_discard: Iterable[str], labels_discard: Iterable[str],
incl_prior: bool, incl_prior: bool,
incl_context: bool, incl_context: bool,
entity_vector_length: int,
get_candidates: Callable[[KnowledgeBase, "Span"], Iterable[Candidate]], get_candidates: Callable[[KnowledgeBase, "Span"], Iterable[Candidate]],
): ):
"""Construct an EntityLinker component. """Construct an EntityLinker component.
@ -70,19 +71,21 @@ def make_entity_linker(
model (Model[List[Doc], Floats2d]): A model that learns document vector model (Model[List[Doc], Floats2d]): A model that learns document vector
representations. Given a batch of Doc objects, it should return a single representations. Given a batch of Doc objects, it should return a single
array, with one row per item in the batch. array, with one row per item in the batch.
kb (KnowledgeBase): The knowledge-base to link entities to.
labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction. labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction.
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model. incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
incl_context (bool): Whether or not to include the local context in the model. incl_context (bool): Whether or not to include the local context in the model.
entity_vector_length (int): Size of encoding vectors in the KB.
get_candidates (Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]): Function that
produces a list of candidates, given a certain knowledge base and a textual mention.
""" """
return EntityLinker( return EntityLinker(
nlp.vocab, nlp.vocab,
model, model,
name, name,
kb_loader=kb_loader,
labels_discard=labels_discard, labels_discard=labels_discard,
incl_prior=incl_prior, incl_prior=incl_prior,
incl_context=incl_context, incl_context=incl_context,
entity_vector_length=entity_vector_length,
get_candidates=get_candidates, get_candidates=get_candidates,
) )
@ -101,10 +104,10 @@ class EntityLinker(Pipe):
model: Model, model: Model,
name: str = "entity_linker", name: str = "entity_linker",
*, *,
kb_loader: Callable[[Vocab], KnowledgeBase],
labels_discard: Iterable[str], labels_discard: Iterable[str],
incl_prior: bool, incl_prior: bool,
incl_context: bool, incl_context: bool,
entity_vector_length: int,
get_candidates: Callable[[KnowledgeBase, "Span"], Iterable[Candidate]], get_candidates: Callable[[KnowledgeBase, "Span"], Iterable[Candidate]],
) -> None: ) -> None:
"""Initialize an entity linker. """Initialize an entity linker.
@ -113,10 +116,12 @@ class EntityLinker(Pipe):
model (thinc.api.Model): The Thinc Model powering the pipeline component. model (thinc.api.Model): The Thinc Model powering the pipeline component.
name (str): The component instance name, used to add entries to the name (str): The component instance name, used to add entries to the
losses during training. losses during training.
kb_loader (Callable[[Vocab], KnowledgeBase]): A function that creates a KnowledgeBase from a Vocab instance.
labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction. labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction.
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model. incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
incl_context (bool): Whether or not to include the local context in the model. incl_context (bool): Whether or not to include the local context in the model.
entity_vector_length (int): Size of encoding vectors in the KB.
get_candidates (Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]): Function that
produces a list of candidates, given a certain knowledge base and a textual mention.
DOCS: https://nightly.spacy.io/api/entitylinker#init DOCS: https://nightly.spacy.io/api/entitylinker#init
""" """
@ -127,15 +132,17 @@ class EntityLinker(Pipe):
"labels_discard": list(labels_discard), "labels_discard": list(labels_discard),
"incl_prior": incl_prior, "incl_prior": incl_prior,
"incl_context": incl_context, "incl_context": incl_context,
"entity_vector_length": entity_vector_length,
} }
self.kb = kb_loader(self.vocab)
self.get_candidates = get_candidates self.get_candidates = get_candidates
self.cfg = dict(cfg) self.cfg = dict(cfg)
self.distance = CosineDistance(normalize=False) self.distance = CosineDistance(normalize=False)
# how many neightbour sentences to take into account # how many neightbour sentences to take into account
self.n_sents = cfg.get("n_sents", 0) self.n_sents = cfg.get("n_sents", 0)
# create an empty KB by default. If you want to load a predefined one, specify it in 'initialize'.
self.kb = empty_kb(entity_vector_length)(self.vocab)
def _require_kb(self) -> None: def validate_kb(self) -> None:
# Raise an error if the knowledge base is not initialized. # Raise an error if the knowledge base is not initialized.
if len(self.kb) == 0: if len(self.kb) == 0:
raise ValueError(Errors.E139.format(name=self.name)) raise ValueError(Errors.E139.format(name=self.name))
@ -145,6 +152,7 @@ class EntityLinker(Pipe):
get_examples: Callable[[], Iterable[Example]], get_examples: Callable[[], Iterable[Example]],
*, *,
nlp: Optional[Language] = None, nlp: Optional[Language] = None,
kb_loader: Callable[[Vocab], KnowledgeBase] = None,
): ):
"""Initialize the pipe for training, using a representative set """Initialize the pipe for training, using a representative set
of data examples. of data examples.
@ -152,11 +160,17 @@ class EntityLinker(Pipe):
get_examples (Callable[[], Iterable[Example]]): Function that get_examples (Callable[[], Iterable[Example]]): Function that
returns a representative sample of gold-standard Example objects. returns a representative sample of gold-standard Example objects.
nlp (Language): The current nlp object the component is part of. nlp (Language): The current nlp object the component is part of.
kb_loader (Callable[[Vocab], KnowledgeBase]): A function that creates a KnowledgeBase from a Vocab instance.
Note that providing this argument, will overwrite all data accumulated in the current KB.
Use this only when loading a KB as-such from file.
DOCS: https://nightly.spacy.io/api/entitylinker#initialize DOCS: https://nightly.spacy.io/api/entitylinker#initialize
""" """
self._ensure_examples(get_examples) self._ensure_examples(get_examples)
self._require_kb() if kb_loader is not None:
self.kb = kb_loader(self.vocab)
self.cfg["entity_vector_length"] = self.kb.entity_vector_length
self.validate_kb()
nO = self.kb.entity_vector_length nO = self.kb.entity_vector_length
doc_sample = [] doc_sample = []
vector_sample = [] vector_sample = []
@ -192,7 +206,7 @@ class EntityLinker(Pipe):
DOCS: https://nightly.spacy.io/api/entitylinker#update DOCS: https://nightly.spacy.io/api/entitylinker#update
""" """
self._require_kb() self.validate_kb()
if losses is None: if losses is None:
losses = {} losses = {}
losses.setdefault(self.name, 0.0) losses.setdefault(self.name, 0.0)
@ -303,7 +317,7 @@ class EntityLinker(Pipe):
DOCS: https://nightly.spacy.io/api/entitylinker#predict DOCS: https://nightly.spacy.io/api/entitylinker#predict
""" """
self._require_kb() self.validate_kb()
entity_count = 0 entity_count = 0
final_kb_ids = [] final_kb_ids = []
if not docs: if not docs: