mirror of
https://github.com/explosion/spaCy.git
synced 2025-11-10 21:07:53 +03:00
In order to support Python 3.13, we had to migrate to Cython 3.0. This caused some tricky interaction with our Pydantic usage, because Cython 3 uses the from __future__ import annotations semantics, which causes type annotations to be saved as strings. The end result is that we can't have Language.factory decorated functions in Cython modules anymore, as the Language.factory decorator expects to inspect the signature of the functions and build a Pydantic model. If the function is implemented in Cython, an error is raised because the type is not resolved. To address this I've moved the factory functions into a new module, spacy.pipeline.factories. I've added __getattr__ importlib hooks to the previous locations, in case anyone was importing these functions directly. The change should have no backwards compatibility implications. Along the way I've also refactored the registration of functions for the config. Previously these ran as import-time side-effects, using the registry decorator. I've created instead a new module spacy.registrations. When the registry is accessed it calls a function ensure_populated(), which causes the registrations to occur. I've made a similar change to the Language.factory registrations in the new spacy.pipeline.factories module. I want to remove these import-time side-effects so that we can speed up the loading time of the library, which can be especially painful on the CLI. I also find that I'm often working to track down the implementations of functions referenced by strings in the config. Having the registrations all happen in one place will make this easier. With these changes I've fortunately avoided the need to migrate to Pydantic v2 properly --- we're still using the v1 compatibility shim. We might not be able to hold out forever though: Pydantic (reasonably) aren't actively supporting the v1 shims. I put a lot of work into v2 migration when investigating the 3.13 support, and it's definitely challenging. 
In any case, it's a relief that we don't have to do the v2 migration at the same time as the Cython 3.0/Python 3.13 support.
582 lines
24 KiB
Python
582 lines
24 KiB
Python
import importlib
|
|
import random
|
|
import sys
|
|
from itertools import islice
|
|
from pathlib import Path
|
|
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
|
|
|
|
import srsly
|
|
from thinc.api import Config, CosineDistance, Model, Optimizer, set_dropout_rate
|
|
from thinc.types import Floats2d
|
|
|
|
from .. import util
|
|
from ..errors import Errors
|
|
from ..kb import Candidate, KnowledgeBase
|
|
from ..language import Language
|
|
from ..scorer import Scorer
|
|
from ..tokens import Doc, Span
|
|
from ..training import Example, validate_examples, validate_get_examples
|
|
from ..util import SimpleFrozenList, registry
|
|
from ..vocab import Vocab
|
|
from .legacy.entity_linker import EntityLinker_v1
|
|
from .pipe import deserialize_config
|
|
from .trainable_pipe import TrainablePipe
|
|
|
|
# See #9050
# Legacy default for the `overwrite` flag of EntityLinker.__init__; kept as a
# module constant so older serialized configs keep their original behavior.
BACKWARD_OVERWRITE = True

# Default thinc config for the NEL model: an EntityLinker.v2 architecture on
# top of a HashEmbedCNN tok2vec. This string is parsed at import time below.
default_model_config = """
[model]
@architectures = "spacy.EntityLinker.v2"

[model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v2"
pretrained_vectors = null
width = 96
depth = 2
embed_size = 2000
window_size = 1
maxout_pieces = 3
subword_features = true
"""
# The resolved [model] section used as the component's default model config.
DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
|
|
|
|
|
|
def entity_linker_score(examples, **kwargs):
    """Default scorer for the entity_linker component.

    Delegates to Scorer.score_links, treating EntityLinker.NIL as the
    negative (no-link) label.
    """
    nil_labels = [EntityLinker.NIL]
    return Scorer.score_links(examples, negative_labels=nil_labels, **kwargs)
|
|
|
|
|
|
def make_entity_linker_scorer():
    """Factory returning the default scoring callable for entity_linker."""
    scorer = entity_linker_score
    return scorer
|
|
|
|
|
|
class EntityLinker(TrainablePipe):
    """Pipeline component for named entity linking.

    DOCS: https://spacy.io/api/entitylinker
    """

    NIL = "NIL"  # string used to refer to a non-existing link

    def __init__(
        self,
        vocab: Vocab,
        model: Model,
        name: str = "entity_linker",
        *,
        labels_discard: Iterable[str],
        n_sents: int,
        incl_prior: bool,
        incl_context: bool,
        entity_vector_length: int,
        get_candidates: Callable[[KnowledgeBase, Span], Iterable[Candidate]],
        get_candidates_batch: Callable[
            [KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]
        ],
        generate_empty_kb: Callable[[Vocab, int], KnowledgeBase],
        overwrite: bool = BACKWARD_OVERWRITE,
        scorer: Optional[Callable] = entity_linker_score,
        use_gold_ents: bool,
        candidates_batch_size: int,
        threshold: Optional[float] = None,
    ) -> None:
        """Initialize an entity linker.

        vocab (Vocab): The shared vocabulary.
        model (thinc.api.Model): The Thinc Model powering the pipeline component.
        name (str): The component instance name, used to add entries to the
            losses during training.
        labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction.
        n_sents (int): The number of neighbouring sentences to take into account.
        incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
        incl_context (bool): Whether or not to include the local context in the model.
        entity_vector_length (int): Size of encoding vectors in the KB.
        get_candidates (Callable[[KnowledgeBase, Span], Iterable[Candidate]]): Function that
            produces a list of candidates, given a certain knowledge base and a textual mention.
        get_candidates_batch (
            Callable[[KnowledgeBase, Iterable[Span]], Iterable[Iterable[Candidate]]]
        ): Function that produces a list of candidates, given a certain knowledge base and several textual mentions.
        generate_empty_kb (Callable[[Vocab, int], KnowledgeBase]): Callable returning empty KnowledgeBase.
        scorer (Optional[Callable]): The scoring method. Defaults to Scorer.score_links.
        use_gold_ents (bool): Whether to copy entities from gold docs or not. If false, another
            component must provide entity annotations.
        candidates_batch_size (int): Size of batches for entity candidate generation.
        threshold (Optional[float]): Confidence threshold for entity predictions. If confidence is below the
            threshold, prediction is discarded. If None, predictions are not filtered by any threshold.

        DOCS: https://spacy.io/api/entitylinker#init
        """
        # A threshold, when given, must be a valid probability in [0, 1].
        if threshold is not None and not (0 <= threshold <= 1):
            raise ValueError(
                Errors.E1043.format(
                    range_start=0,
                    range_end=1,
                    value=threshold,
                )
            )

        self.vocab = vocab
        self.model = model
        self.name = name
        self.labels_discard = list(labels_discard)
        # how many neighbour sentences to take into account
        self.n_sents = n_sents
        self.incl_prior = incl_prior
        self.incl_context = incl_context
        self.get_candidates = get_candidates
        self.get_candidates_batch = get_candidates_batch
        self.cfg: Dict[str, Any] = {"overwrite": overwrite}
        self.distance = CosineDistance(normalize=False)
        # Start with an empty KB; a real KB is attached later via set_kb()
        # or initialize(kb_loader=...).
        self.kb = generate_empty_kb(self.vocab, entity_vector_length)
        self.use_gold_ents = use_gold_ents
        self.candidates_batch_size = candidates_batch_size
        self.threshold = threshold

        if candidates_batch_size < 1:
            raise ValueError(Errors.E1044)

        # Closure over `scorer`: wraps the user-provided scorer so entities
        # are set on the predicted docs before scoring when needed.
        def _score_with_ents_set(examples: Iterable[Example], **kwargs):
            # Because of how spaCy works, we can't just score immediately, because Language.evaluate
            # calls pipe() on the predicted docs, which won't have entities if there is no NER in the pipeline.
            if not scorer:
                return scorer
            if not self.use_gold_ents:
                return scorer(examples, **kwargs)
            else:
                # Copy gold entities onto the predictions, re-run this
                # component, and score the re-annotated docs.
                examples = self._ensure_ents(examples)
                docs = self.pipe(
                    (eg.predicted for eg in examples),
                )
                for eg, doc in zip(examples, docs):
                    eg.predicted = doc
                return scorer(examples, **kwargs)

        self.scorer = _score_with_ents_set

    def _ensure_ents(self, examples: Iterable[Example]) -> Iterable[Example]:
        """If use_gold_ents is true, set the gold entities to (a copy of) eg.predicted."""
        if not self.use_gold_ents:
            return examples

        new_examples = []
        for eg in examples:
            # Align gold entities to the predicted tokenization, then copy
            # them onto a copy of the example's predicted doc.
            ents, _ = eg.get_aligned_ents_and_ner()
            new_eg = eg.copy()
            new_eg.predicted.ents = ents
            new_examples.append(new_eg)
        return new_examples

    def set_kb(self, kb_loader: Callable[[Vocab], KnowledgeBase]):
        """Define the KB of this pipe by providing a function that will
        create it using this object's vocab."""
        if not callable(kb_loader):
            raise ValueError(Errors.E885.format(arg_type=type(kb_loader)))

        self.kb = kb_loader(self.vocab)  # type: ignore

    def validate_kb(self) -> None:
        # Raise an error if the knowledge base is not initialized.
        if self.kb is None:
            raise ValueError(Errors.E1018.format(name=self.name))
        if hasattr(self.kb, "is_empty") and self.kb.is_empty():
            raise ValueError(Errors.E139.format(name=self.name))

    def initialize(
        self,
        get_examples: Callable[[], Iterable[Example]],
        *,
        nlp: Optional[Language] = None,
        kb_loader: Optional[Callable[[Vocab], KnowledgeBase]] = None,
    ):
        """Initialize the pipe for training, using a representative set
        of data examples.

        get_examples (Callable[[], Iterable[Example]]): Function that
            returns a representative sample of gold-standard Example objects.
        nlp (Language): The current nlp object the component is part of.
        kb_loader (Callable[[Vocab], KnowledgeBase]): A function that creates a KnowledgeBase from a Vocab
            instance. Note that providing this argument will overwrite all data accumulated in the current KB.
            Use this only when loading a KB as-such from file.

        DOCS: https://spacy.io/api/entitylinker#initialize
        """
        validate_get_examples(get_examples, "EntityLinker.initialize")
        if kb_loader is not None:
            self.set_kb(kb_loader)
        self.validate_kb()
        # Output width of the model equals the KB's entity vector length.
        nO = self.kb.entity_vector_length
        doc_sample = []
        vector_sample = []
        # Only a small sample (10 examples) is needed for shape inference.
        examples = self._ensure_ents(islice(get_examples(), 10))
        for eg in examples:
            doc = eg.x
            doc_sample.append(doc)
            vector_sample.append(self.model.ops.alloc1f(nO))
        assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
        assert len(vector_sample) > 0, Errors.E923.format(name=self.name)

        # XXX In order for size estimation to work, there has to be at least
        # one entity. It's not used for training so it doesn't have to be real,
        # so we add a fake one if none are present.
        # We can't use Doc.has_annotation here because it can be True for docs
        # that have been through an NER component but got no entities.
        has_annotations = any([doc.ents for doc in doc_sample])
        if not has_annotations:
            doc = doc_sample[0]
            ent = doc[0:1]
            ent.label_ = "XXX"
            doc.ents = (ent,)

        self.model.initialize(
            X=doc_sample, Y=self.model.ops.asarray(vector_sample, dtype="float32")
        )

        if not has_annotations:
            # Clean up dummy annotation
            doc.ents = []

    def batch_has_learnable_example(self, examples):
        """Check if a batch contains a learnable example.

        An example is learnable when at least one of its predicted entities
        has at least one KB candidate. If none is present, the update step
        needs to be skipped.
        """

        for eg in examples:
            for ent in eg.predicted.ents:
                candidates = list(self.get_candidates(self.kb, ent))
                if candidates:
                    return True

        return False

    def update(
        self,
        examples: Iterable[Example],
        *,
        drop: float = 0.0,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None,
    ) -> Dict[str, float]:
        """Learn from a batch of documents and gold-standard information,
        updating the pipe's model. Delegates to predict and get_loss.

        examples (Iterable[Example]): A batch of Example objects.
        drop (float): The dropout rate.
        sgd (thinc.api.Optimizer): The optimizer.
        losses (Dict[str, float]): Optional record of the loss during training.
            Updated using the component name as the key.
        RETURNS (Dict[str, float]): The updated losses dictionary.

        DOCS: https://spacy.io/api/entitylinker#update
        """
        self.validate_kb()
        if losses is None:
            losses = {}
        losses.setdefault(self.name, 0.0)
        if not examples:
            return losses
        examples = self._ensure_ents(examples)
        validate_examples(examples, "EntityLinker.update")

        # make sure we have something to learn from, if not, short-circuit
        if not self.batch_has_learnable_example(examples):
            return losses

        set_dropout_rate(self.model, drop)
        docs = [eg.predicted for eg in examples]
        sentence_encodings, bp_context = self.model.begin_update(docs)

        loss, d_scores = self.get_loss(
            sentence_encodings=sentence_encodings, examples=examples
        )
        bp_context(d_scores)
        if sgd is not None:
            self.finish_update(sgd)
        losses[self.name] += loss

        return losses

    def get_loss(self, examples: Iterable[Example], sentence_encodings: Floats2d):
        """Compute cosine-distance loss and gradient between the predicted
        sentence encodings and the KB vectors of the gold entities."""
        validate_examples(examples, "EntityLinker.get_loss")
        entity_encodings = []
        # We assume that get_loss is called with gold ents set in the examples if need be
        eidx = 0  # indices in gold entities to keep
        keep_ents = []  # indices in sentence_encodings to keep

        for eg in examples:
            kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)

            for ent in eg.get_matching_ents():
                kb_id = kb_ids[ent.start]
                if kb_id:
                    # Only entities with a gold KB ID contribute to the loss.
                    entity_encoding = self.kb.get_vector(kb_id)
                    entity_encodings.append(entity_encoding)
                    keep_ents.append(eidx)

                # eidx counts every matching entity, kept or not, so that
                # keep_ents indexes correctly into sentence_encodings.
                eidx += 1
        entity_encodings = self.model.ops.asarray2f(entity_encodings, dtype="float32")
        selected_encodings = sentence_encodings[keep_ents]

        # if there are no matches, short circuit
        if not keep_ents:
            out = self.model.ops.alloc2f(*sentence_encodings.shape)
            return 0, out

        if selected_encodings.shape != entity_encodings.shape:
            err = Errors.E147.format(
                method="get_loss", msg="gold entities do not match up"
            )
            raise RuntimeError(err)
        gradients = self.distance.get_grad(selected_encodings, entity_encodings)
        # to match the input size, we need to give a zero gradient for items not in the kb
        out = self.model.ops.alloc2f(*sentence_encodings.shape)
        out[keep_ents] = gradients

        loss = self.distance.get_loss(selected_encodings, entity_encodings)
        loss = loss / len(entity_encodings)
        return float(loss), out

    def predict(self, docs: Iterable[Doc]) -> List[str]:
        """Apply the pipeline's model to a batch of docs, without modifying them.
        Returns the KB IDs for each entity in each doc, including NIL if there is
        no prediction.

        docs (Iterable[Doc]): The documents to predict.
        RETURNS (List[str]): The models prediction for each document.

        DOCS: https://spacy.io/api/entitylinker#predict
        """
        self.validate_kb()
        entity_count = 0
        final_kb_ids: List[str] = []
        xp = self.model.ops.xp
        if not docs:
            return final_kb_ids
        if isinstance(docs, Doc):
            docs = [docs]
        for i, doc in enumerate(docs):
            if len(doc) == 0:
                continue
            sentences = [s for s in doc.sents]

            # Loop over entities in batches.
            for ent_idx in range(0, len(doc.ents), self.candidates_batch_size):
                ent_batch = doc.ents[ent_idx : ent_idx + self.candidates_batch_size]

                # Look up candidate entities.
                valid_ent_idx = [
                    idx
                    for idx in range(len(ent_batch))
                    if ent_batch[idx].label_ not in self.labels_discard
                ]

                # Batched lookup when batch size > 1; per-entity otherwise.
                batch_candidates = list(
                    self.get_candidates_batch(
                        self.kb, [ent_batch[idx] for idx in valid_ent_idx]
                    )
                    if self.candidates_batch_size > 1
                    else [
                        self.get_candidates(self.kb, ent_batch[idx])
                        for idx in valid_ent_idx
                    ]
                )

                # Looping through each entity in batch (TODO: rewrite)
                for j, ent in enumerate(ent_batch):
                    assert hasattr(ent, "sents")
                    sents = list(ent.sents)
                    # First and last sentence index the entity spans.
                    sent_indices = (
                        sentences.index(sents[0]),
                        sentences.index(sents[-1]),
                    )
                    assert sent_indices[1] >= sent_indices[0] >= 0

                    if self.incl_context:
                        # get n_neighbour sentences, clipped to the length of the document
                        start_sentence = max(0, sent_indices[0] - self.n_sents)
                        end_sentence = min(
                            len(sentences) - 1, sent_indices[1] + self.n_sents
                        )
                        start_token = sentences[start_sentence].start
                        end_token = sentences[end_sentence].end
                        sent_doc = doc[start_token:end_token].as_doc()

                        # currently, the context is the same for each entity in a sentence (should be refined)
                        sentence_encoding = self.model.predict([sent_doc])[0]
                        sentence_encoding_t = sentence_encoding.T
                        sentence_norm = xp.linalg.norm(sentence_encoding_t)
                    entity_count += 1
                    if ent.label_ in self.labels_discard:
                        # ignoring this entity - setting to NIL
                        final_kb_ids.append(self.NIL)
                    else:
                        candidates = list(batch_candidates[j])
                        if not candidates:
                            # no prediction possible for this entity - setting to NIL
                            final_kb_ids.append(self.NIL)
                        elif len(candidates) == 1 and self.threshold is None:
                            # shortcut for efficiency reasons: take the 1 candidate
                            final_kb_ids.append(candidates[0].entity_)
                        else:
                            random.shuffle(candidates)
                            # set all prior probabilities to 0 if incl_prior=False
                            prior_probs = xp.asarray([c.prior_prob for c in candidates])
                            if not self.incl_prior:
                                prior_probs = xp.asarray([0.0 for _ in candidates])
                            scores = prior_probs
                            # add in similarity from the context
                            if self.incl_context:
                                entity_encodings = xp.asarray(
                                    [c.entity_vector for c in candidates]
                                )
                                entity_norm = xp.linalg.norm(entity_encodings, axis=1)
                                if len(entity_encodings) != len(prior_probs):
                                    raise RuntimeError(
                                        Errors.E147.format(
                                            method="predict",
                                            msg="vectors not of equal length",
                                        )
                                    )
                                # cosine similarity
                                sims = xp.dot(entity_encodings, sentence_encoding_t) / (
                                    sentence_norm * entity_norm
                                )
                                if sims.shape != prior_probs.shape:
                                    raise ValueError(Errors.E161)
                                # combine prior and similarity: p + q - p*q
                                scores = prior_probs + sims - (prior_probs * sims)
                            # Keep the best candidate unless it falls below
                            # the confidence threshold (then emit NIL).
                            final_kb_ids.append(
                                candidates[scores.argmax().item()].entity_
                                if self.threshold is None
                                or scores.max() >= self.threshold
                                else EntityLinker.NIL
                            )

        if not (len(final_kb_ids) == entity_count):
            err = Errors.E147.format(
                method="predict", msg="result variables not of equal length"
            )
            raise RuntimeError(err)
        return final_kb_ids

    def set_annotations(self, docs: Iterable[Doc], kb_ids: List[str]) -> None:
        """Modify a batch of documents, using pre-computed scores.

        docs (Iterable[Doc]): The documents to modify.
        kb_ids (List[str]): The IDs to set, produced by EntityLinker.predict.

        DOCS: https://spacy.io/api/entitylinker#set_annotations
        """
        count_ents = len([ent for doc in docs for ent in doc.ents])
        if count_ents != len(kb_ids):
            raise ValueError(Errors.E148.format(ents=count_ents, ids=len(kb_ids)))
        i = 0
        overwrite = self.cfg["overwrite"]
        for doc in docs:
            for ent in doc.ents:
                kb_id = kb_ids[i]
                i += 1
                for token in ent:
                    # Only set the KB ID if unset, unless overwrite is on.
                    if token.ent_kb_id == 0 or overwrite:
                        token.ent_kb_id_ = kb_id

    def to_bytes(self, *, exclude=tuple()):
        """Serialize the pipe to a bytestring.

        exclude (Iterable[str]): String names of serialization fields to exclude.
        RETURNS (bytes): The serialized object.

        DOCS: https://spacy.io/api/entitylinker#to_bytes
        """
        self._validate_serialization_attrs()
        serialize = {}
        if hasattr(self, "cfg") and self.cfg is not None:
            serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
        serialize["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude)
        serialize["kb"] = self.kb.to_bytes
        serialize["model"] = self.model.to_bytes
        return util.to_bytes(serialize, exclude)

    def from_bytes(self, bytes_data, *, exclude=tuple()):
        """Load the pipe from a bytestring.

        exclude (Iterable[str]): String names of serialization fields to exclude.
        RETURNS (TrainablePipe): The loaded object.

        DOCS: https://spacy.io/api/entitylinker#from_bytes
        """
        self._validate_serialization_attrs()

        def load_model(b):
            try:
                self.model.from_bytes(b)
            except AttributeError:
                raise ValueError(Errors.E149) from None

        deserialize = {}
        if hasattr(self, "cfg") and self.cfg is not None:
            deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
        deserialize["vocab"] = lambda b: self.vocab.from_bytes(b, exclude=exclude)
        deserialize["kb"] = lambda b: self.kb.from_bytes(b)
        deserialize["model"] = load_model
        util.from_bytes(bytes_data, deserialize, exclude)
        return self

    def to_disk(
        self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
    ) -> None:
        """Serialize the pipe to disk.

        path (str / Path): Path to a directory.
        exclude (Iterable[str]): String names of serialization fields to exclude.

        DOCS: https://spacy.io/api/entitylinker#to_disk
        """
        serialize = {}
        serialize["vocab"] = lambda p: self.vocab.to_disk(p, exclude=exclude)
        serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
        serialize["kb"] = lambda p: self.kb.to_disk(p)
        serialize["model"] = lambda p: self.model.to_disk(p)
        util.to_disk(path, serialize, exclude)

    def from_disk(
        self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
    ) -> "EntityLinker":
        """Load the pipe from disk. Modifies the object in place and returns it.

        path (str / Path): Path to a directory.
        exclude (Iterable[str]): String names of serialization fields to exclude.
        RETURNS (EntityLinker): The modified EntityLinker object.

        DOCS: https://spacy.io/api/entitylinker#from_disk
        """

        def load_model(p):
            try:
                with p.open("rb") as infile:
                    self.model.from_bytes(infile.read())
            except AttributeError:
                raise ValueError(Errors.E149) from None

        deserialize: Dict[str, Callable[[Any], Any]] = {}
        deserialize["cfg"] = lambda p: self.cfg.update(deserialize_config(p))
        deserialize["vocab"] = lambda p: self.vocab.from_disk(p, exclude=exclude)
        deserialize["kb"] = lambda p: self.kb.from_disk(p)
        deserialize["model"] = load_model
        util.from_disk(path, deserialize, exclude)
        return self

    def rehearse(self, examples, *, sgd=None, losses=None, **config):
        # Rehearsal (to prevent catastrophic forgetting) is not supported
        # for this component.
        raise NotImplementedError

    def add_label(self, label):
        # The entity linker has no label scheme to extend; labels come from
        # the knowledge base.
        raise NotImplementedError
|
|
|
|
|
|
# Setup backwards compatibility hook for factories
|
|
def __getattr__(name):
|
|
if name == "make_entity_linker":
|
|
module = importlib.import_module("spacy.pipeline.factories")
|
|
return module.make_entity_linker
|
|
raise AttributeError(f"module {__name__} has no attribute {name}")
|