mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 01:04:34 +03:00
move kb_loader to initialize for NEL instead of constructor
This commit is contained in:
parent
bcaad28eda
commit
33c2d4af16
|
@ -8,6 +8,7 @@ from thinc.api import set_dropout_rate
|
|||
import warnings
|
||||
|
||||
from ..kb import KnowledgeBase, Candidate
|
||||
from ..ml import empty_kb
|
||||
from ..tokens import Doc
|
||||
from .pipe import Pipe, deserialize_config
|
||||
from ..language import Language
|
||||
|
@ -41,11 +42,11 @@ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
|
|||
requires=["doc.ents", "doc.sents", "token.ent_iob", "token.ent_type"],
|
||||
assigns=["token.ent_kb_id"],
|
||||
default_config={
|
||||
"kb_loader": {"@misc": "spacy.EmptyKB.v1", "entity_vector_length": 64},
|
||||
"model": DEFAULT_NEL_MODEL,
|
||||
"labels_discard": [],
|
||||
"incl_prior": True,
|
||||
"incl_context": True,
|
||||
"entity_vector_length": 64,
|
||||
"get_candidates": {"@misc": "spacy.CandidateGenerator.v1"},
|
||||
},
|
||||
default_score_weights={
|
||||
|
@ -58,11 +59,11 @@ def make_entity_linker(
|
|||
nlp: Language,
|
||||
name: str,
|
||||
model: Model,
|
||||
kb_loader: Callable[[Vocab], KnowledgeBase],
|
||||
*,
|
||||
labels_discard: Iterable[str],
|
||||
incl_prior: bool,
|
||||
incl_context: bool,
|
||||
entity_vector_length: int,
|
||||
get_candidates: Callable[[KnowledgeBase, "Span"], Iterable[Candidate]],
|
||||
):
|
||||
"""Construct an EntityLinker component.
|
||||
|
@ -70,19 +71,21 @@ def make_entity_linker(
|
|||
model (Model[List[Doc], Floats2d]): A model that learns document vector
|
||||
representations. Given a batch of Doc objects, it should return a single
|
||||
array, with one row per item in the batch.
|
||||
kb (KnowledgeBase): The knowledge-base to link entities to.
|
||||
labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction.
|
||||
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
|
||||
incl_context (bool): Whether or not to include the local context in the model.
|
||||
entity_vector_length (int): Size of encoding vectors in the KB.
|
||||
get_candidates (Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]): Function that
|
||||
produces a list of candidates, given a certain knowledge base and a textual mention.
|
||||
"""
|
||||
return EntityLinker(
|
||||
nlp.vocab,
|
||||
model,
|
||||
name,
|
||||
kb_loader=kb_loader,
|
||||
labels_discard=labels_discard,
|
||||
incl_prior=incl_prior,
|
||||
incl_context=incl_context,
|
||||
entity_vector_length=entity_vector_length,
|
||||
get_candidates=get_candidates,
|
||||
)
|
||||
|
||||
|
@ -101,10 +104,10 @@ class EntityLinker(Pipe):
|
|||
model: Model,
|
||||
name: str = "entity_linker",
|
||||
*,
|
||||
kb_loader: Callable[[Vocab], KnowledgeBase],
|
||||
labels_discard: Iterable[str],
|
||||
incl_prior: bool,
|
||||
incl_context: bool,
|
||||
entity_vector_length: int,
|
||||
get_candidates: Callable[[KnowledgeBase, "Span"], Iterable[Candidate]],
|
||||
) -> None:
|
||||
"""Initialize an entity linker.
|
||||
|
@ -113,10 +116,12 @@ class EntityLinker(Pipe):
|
|||
model (thinc.api.Model): The Thinc Model powering the pipeline component.
|
||||
name (str): The component instance name, used to add entries to the
|
||||
losses during training.
|
||||
kb_loader (Callable[[Vocab], KnowledgeBase]): A function that creates a KnowledgeBase from a Vocab instance.
|
||||
labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction.
|
||||
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
|
||||
incl_context (bool): Whether or not to include the local context in the model.
|
||||
entity_vector_length (int): Size of encoding vectors in the KB.
|
||||
get_candidates (Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]): Function that
|
||||
produces a list of candidates, given a certain knowledge base and a textual mention.
|
||||
|
||||
DOCS: https://nightly.spacy.io/api/entitylinker#init
|
||||
"""
|
||||
|
@ -127,15 +132,17 @@ class EntityLinker(Pipe):
|
|||
"labels_discard": list(labels_discard),
|
||||
"incl_prior": incl_prior,
|
||||
"incl_context": incl_context,
|
||||
"entity_vector_length": entity_vector_length,
|
||||
}
|
||||
self.kb = kb_loader(self.vocab)
|
||||
self.get_candidates = get_candidates
|
||||
self.cfg = dict(cfg)
|
||||
self.distance = CosineDistance(normalize=False)
|
||||
# how many neightbour sentences to take into account
|
||||
self.n_sents = cfg.get("n_sents", 0)
|
||||
# create an empty KB by default. If you want to load a predefined one, specify it in 'initialize'.
|
||||
self.kb = empty_kb(entity_vector_length)(self.vocab)
|
||||
|
||||
def _require_kb(self) -> None:
|
||||
def validate_kb(self) -> None:
|
||||
# Raise an error if the knowledge base is not initialized.
|
||||
if len(self.kb) == 0:
|
||||
raise ValueError(Errors.E139.format(name=self.name))
|
||||
|
@ -145,6 +152,7 @@ class EntityLinker(Pipe):
|
|||
get_examples: Callable[[], Iterable[Example]],
|
||||
*,
|
||||
nlp: Optional[Language] = None,
|
||||
kb_loader: Callable[[Vocab], KnowledgeBase] = None,
|
||||
):
|
||||
"""Initialize the pipe for training, using a representative set
|
||||
of data examples.
|
||||
|
@ -152,11 +160,17 @@ class EntityLinker(Pipe):
|
|||
get_examples (Callable[[], Iterable[Example]]): Function that
|
||||
returns a representative sample of gold-standard Example objects.
|
||||
nlp (Language): The current nlp object the component is part of.
|
||||
kb_loader (Callable[[Vocab], KnowledgeBase]): A function that creates a KnowledgeBase from a Vocab instance.
|
||||
Note that providing this argument, will overwrite all data accumulated in the current KB.
|
||||
Use this only when loading a KB as-such from file.
|
||||
|
||||
DOCS: https://nightly.spacy.io/api/entitylinker#initialize
|
||||
"""
|
||||
self._ensure_examples(get_examples)
|
||||
self._require_kb()
|
||||
if kb_loader is not None:
|
||||
self.kb = kb_loader(self.vocab)
|
||||
self.cfg["entity_vector_length"] = self.kb.entity_vector_length
|
||||
self.validate_kb()
|
||||
nO = self.kb.entity_vector_length
|
||||
doc_sample = []
|
||||
vector_sample = []
|
||||
|
@ -192,7 +206,7 @@ class EntityLinker(Pipe):
|
|||
|
||||
DOCS: https://nightly.spacy.io/api/entitylinker#update
|
||||
"""
|
||||
self._require_kb()
|
||||
self.validate_kb()
|
||||
if losses is None:
|
||||
losses = {}
|
||||
losses.setdefault(self.name, 0.0)
|
||||
|
@ -303,7 +317,7 @@ class EntityLinker(Pipe):
|
|||
|
||||
DOCS: https://nightly.spacy.io/api/entitylinker#predict
|
||||
"""
|
||||
self._require_kb()
|
||||
self.validate_kb()
|
||||
entity_count = 0
|
||||
final_kb_ids = []
|
||||
if not docs:
|
||||
|
|
Loading…
Reference in New Issue
Block a user