mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 17:24:41 +03:00
Define candidate generator in EL config (#5876)
* candidate generator as separate part of EL config * update comment * ent instead of str as input for candidate generation * Span instead of str: correct type indication * fix types * unit test to create new candidate generator * fix replace_pipe argument passing * move error message, general cleanup * add vocab back to KB constructor * provide KB as callable from Vocab arg * rename to kb_loader, fix KB serialization as part of the EL pipe * fix typo * reformatting * cleanup * fix comment * fix wrongly duplicated code from merge conflict * rename dump to to_disk * from_disk instead of load_bulk * update test after recent removal of set_morphology in tagger * remove old doc
This commit is contained in:
parent
688e77562b
commit
358cbb21e3
|
@ -15,7 +15,8 @@ import spacy.util
|
|||
from bin.ud import conll17_ud_eval
|
||||
from spacy.tokens import Token, Doc
|
||||
from spacy.gold import Example
|
||||
from spacy.util import compounding, minibatch, minibatch_by_words
|
||||
from spacy.util import compounding, minibatch
|
||||
from spacy.gold.batchers import minibatch_by_words
|
||||
from spacy.pipeline._parser_internals.nonproj import projectivize
|
||||
from spacy.matcher import Matcher
|
||||
from spacy import displacy
|
||||
|
|
|
@ -48,8 +48,7 @@ def main(model, output_dir=None):
|
|||
# You can change the dimension of vectors in your KB by using an encoder that changes the dimensionality.
|
||||
# For simplicity, we'll just use the original vector dimension here instead.
|
||||
vectors_dim = nlp.vocab.vectors.shape[1]
|
||||
kb = KnowledgeBase(entity_vector_length=vectors_dim)
|
||||
kb.initialize(nlp.vocab)
|
||||
kb = KnowledgeBase(nlp.vocab, entity_vector_length=vectors_dim)
|
||||
|
||||
# set up the data
|
||||
entity_ids = []
|
||||
|
@ -81,7 +80,7 @@ def main(model, output_dir=None):
|
|||
if not output_dir.exists():
|
||||
output_dir.mkdir()
|
||||
kb_path = str(output_dir / "kb")
|
||||
kb.dump(kb_path)
|
||||
kb.to_disk(kb_path)
|
||||
print()
|
||||
print("Saved KB to", kb_path)
|
||||
|
||||
|
@ -96,9 +95,8 @@ def main(model, output_dir=None):
|
|||
print("Loading vocab from", vocab_path)
|
||||
print("Loading KB from", kb_path)
|
||||
vocab2 = Vocab().from_disk(vocab_path)
|
||||
kb2 = KnowledgeBase(entity_vector_length=1)
|
||||
kb.initialize(vocab2)
|
||||
kb2.load_bulk(kb_path)
|
||||
kb2 = KnowledgeBase(vocab2, entity_vector_length=1)
|
||||
kb2.from_disk(kb_path)
|
||||
print()
|
||||
_print_kb(kb2)
|
||||
|
||||
|
|
|
@ -83,7 +83,7 @@ def main(kb_path, vocab_path, output_dir=None, n_iter=50):
|
|||
if "entity_linker" not in nlp.pipe_names:
|
||||
print("Loading Knowledge Base from '%s'" % kb_path)
|
||||
cfg = {
|
||||
"kb": {
|
||||
"kb_loader": {
|
||||
"@assets": "spacy.KBFromFile.v1",
|
||||
"vocab_path": vocab_path,
|
||||
"kb_path": kb_path,
|
||||
|
|
|
@ -477,6 +477,10 @@ class Errors:
|
|||
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
|
||||
|
||||
# TODO: fix numbering after merging develop into master
|
||||
E928 = ("A 'KnowledgeBase' should be written to / read from a file, but the "
|
||||
"provided argument {loc} is an existing directory.")
|
||||
E929 = ("A 'KnowledgeBase' could not be read from {loc} - the path does "
|
||||
"not seem to exist.")
|
||||
E930 = ("Received invalid get_examples callback in {name}.begin_training. "
|
||||
"Expected function that returns an iterable of Example objects but "
|
||||
"got: {obj}")
|
||||
|
@ -504,8 +508,6 @@ class Errors:
|
|||
"not found in pipeline. Available components: {opts}")
|
||||
E945 = ("Can't copy pipeline component '{name}' from source. Expected loaded "
|
||||
"nlp object, but got: {source}")
|
||||
E946 = ("The Vocab for the knowledge base is not initialized. Did you forget to "
|
||||
"call kb.initialize()?")
|
||||
E947 = ("Matcher.add received invalid 'greedy' argument: expected "
|
||||
"a string value from {expected} but got: '{arg}'")
|
||||
E948 = ("Matcher.add received invalid 'patterns' argument: expected "
|
||||
|
@ -612,8 +614,6 @@ class Errors:
|
|||
"of the training data in spaCy 3.0 onwards. The 'update' "
|
||||
"function should now be called with a batch of 'Example' "
|
||||
"objects, instead of (text, annotation) tuples. ")
|
||||
E990 = ("An entity linking component needs to be initialized with a "
|
||||
"KnowledgeBase object, but found {type} instead.")
|
||||
E991 = ("The function 'select_pipes' should be called with either a "
|
||||
"'disable' argument to list the names of the pipe components "
|
||||
"that should be disabled, or with an 'enable' argument that "
|
||||
|
|
|
@ -140,7 +140,7 @@ cdef class KnowledgeBase:
|
|||
self._entries.push_back(entry)
|
||||
self._aliases_table.push_back(alias)
|
||||
|
||||
cpdef load_bulk(self, loc)
|
||||
cpdef from_disk(self, loc)
|
||||
cpdef set_entities(self, entity_list, freq_list, vector_list)
|
||||
|
||||
|
||||
|
|
54
spacy/kb.pyx
54
spacy/kb.pyx
|
@ -1,4 +1,5 @@
|
|||
# cython: infer_types=True, profile=True
|
||||
from typing import Iterator
|
||||
from cymem.cymem cimport Pool
|
||||
from preshed.maps cimport PreshMap
|
||||
from cpython.exc cimport PyErr_SetFromErrno
|
||||
|
@ -64,6 +65,16 @@ cdef class Candidate:
|
|||
return self.prior_prob
|
||||
|
||||
|
||||
def get_candidates(KnowledgeBase kb, span) -> Iterator[Candidate]:
|
||||
"""
|
||||
Return candidate entities for a given span by using the text of the span as the alias
|
||||
and fetching appropriate entries from the index.
|
||||
This particular function is optimized to work with the built-in KB functionality,
|
||||
but any other custom candidate generation method can be used in combination with the KB as well.
|
||||
"""
|
||||
return kb.get_alias_candidates(span.text)
|
||||
|
||||
|
||||
cdef class KnowledgeBase:
|
||||
"""A `KnowledgeBase` instance stores unique identifiers for entities and their textual aliases,
|
||||
to support entity linking of named entities to real-world concepts.
|
||||
|
@ -71,25 +82,16 @@ cdef class KnowledgeBase:
|
|||
DOCS: https://spacy.io/api/kb
|
||||
"""
|
||||
|
||||
def __init__(self, entity_vector_length):
|
||||
"""Create a KnowledgeBase. Make sure to call kb.initialize() before using it."""
|
||||
def __init__(self, Vocab vocab, entity_vector_length):
|
||||
"""Create a KnowledgeBase."""
|
||||
self.mem = Pool()
|
||||
self.entity_vector_length = entity_vector_length
|
||||
|
||||
self._entry_index = PreshMap()
|
||||
self._alias_index = PreshMap()
|
||||
self.vocab = None
|
||||
|
||||
|
||||
def initialize(self, Vocab vocab):
|
||||
self.vocab = vocab
|
||||
self.vocab.strings.add("")
|
||||
self._create_empty_vectors(dummy_hash=self.vocab.strings[""])
|
||||
|
||||
def require_vocab(self):
|
||||
if self.vocab is None:
|
||||
raise ValueError(Errors.E946)
|
||||
|
||||
@property
|
||||
def entity_vector_length(self):
|
||||
"""RETURNS (uint64): length of the entity vectors"""
|
||||
|
@ -102,14 +104,12 @@ cdef class KnowledgeBase:
|
|||
return len(self._entry_index)
|
||||
|
||||
def get_entity_strings(self):
|
||||
self.require_vocab()
|
||||
return [self.vocab.strings[x] for x in self._entry_index]
|
||||
|
||||
def get_size_aliases(self):
|
||||
return len(self._alias_index)
|
||||
|
||||
def get_alias_strings(self):
|
||||
self.require_vocab()
|
||||
return [self.vocab.strings[x] for x in self._alias_index]
|
||||
|
||||
def add_entity(self, unicode entity, float freq, vector[float] entity_vector):
|
||||
|
@ -117,7 +117,6 @@ cdef class KnowledgeBase:
|
|||
Add an entity to the KB, optionally specifying its log probability based on corpus frequency
|
||||
Return the hash of the entity ID/name at the end.
|
||||
"""
|
||||
self.require_vocab()
|
||||
cdef hash_t entity_hash = self.vocab.strings.add(entity)
|
||||
|
||||
# Return if this entity was added before
|
||||
|
@ -140,7 +139,6 @@ cdef class KnowledgeBase:
|
|||
return entity_hash
|
||||
|
||||
cpdef set_entities(self, entity_list, freq_list, vector_list):
|
||||
self.require_vocab()
|
||||
if len(entity_list) != len(freq_list) or len(entity_list) != len(vector_list):
|
||||
raise ValueError(Errors.E140)
|
||||
|
||||
|
@ -176,12 +174,10 @@ cdef class KnowledgeBase:
|
|||
i += 1
|
||||
|
||||
def contains_entity(self, unicode entity):
|
||||
self.require_vocab()
|
||||
cdef hash_t entity_hash = self.vocab.strings.add(entity)
|
||||
return entity_hash in self._entry_index
|
||||
|
||||
def contains_alias(self, unicode alias):
|
||||
self.require_vocab()
|
||||
cdef hash_t alias_hash = self.vocab.strings.add(alias)
|
||||
return alias_hash in self._alias_index
|
||||
|
||||
|
@ -190,7 +186,6 @@ cdef class KnowledgeBase:
|
|||
For a given alias, add its potential entities and prior probabilities to the KB.
|
||||
Return the alias_hash at the end
|
||||
"""
|
||||
self.require_vocab()
|
||||
# Throw an error if the length of entities and probabilities are not the same
|
||||
if not len(entities) == len(probabilities):
|
||||
raise ValueError(Errors.E132.format(alias=alias,
|
||||
|
@ -234,7 +229,6 @@ cdef class KnowledgeBase:
|
|||
Throw an error if this entity+prior prob would exceed the sum of 1.
|
||||
For efficiency, it's best to use the method `add_alias` as much as possible instead of this one.
|
||||
"""
|
||||
self.require_vocab()
|
||||
# Check if the alias exists in the KB
|
||||
cdef hash_t alias_hash = self.vocab.strings[alias]
|
||||
if not alias_hash in self._alias_index:
|
||||
|
@ -274,14 +268,12 @@ cdef class KnowledgeBase:
|
|||
alias_entry.probs = probs
|
||||
self._aliases_table[alias_index] = alias_entry
|
||||
|
||||
|
||||
def get_candidates(self, unicode alias):
|
||||
def get_alias_candidates(self, unicode alias) -> Iterator[Candidate]:
|
||||
"""
|
||||
Return candidate entities for an alias. Each candidate defines the entity, the original alias,
|
||||
and the prior probability of that alias resolving to that entity.
|
||||
If the alias is not known in the KB, an empty list is returned.
|
||||
"""
|
||||
self.require_vocab()
|
||||
cdef hash_t alias_hash = self.vocab.strings[alias]
|
||||
if not alias_hash in self._alias_index:
|
||||
return []
|
||||
|
@ -298,7 +290,6 @@ cdef class KnowledgeBase:
|
|||
if entry_index != 0]
|
||||
|
||||
def get_vector(self, unicode entity):
|
||||
self.require_vocab()
|
||||
cdef hash_t entity_hash = self.vocab.strings[entity]
|
||||
|
||||
# Return an empty list if this entity is unknown in this KB
|
||||
|
@ -311,7 +302,6 @@ cdef class KnowledgeBase:
|
|||
def get_prior_prob(self, unicode entity, unicode alias):
|
||||
""" Return the prior probability of a given alias being linked to a given entity,
|
||||
or return 0.0 when this combination is not known in the knowledge base"""
|
||||
self.require_vocab()
|
||||
cdef hash_t alias_hash = self.vocab.strings[alias]
|
||||
cdef hash_t entity_hash = self.vocab.strings[entity]
|
||||
|
||||
|
@ -329,8 +319,7 @@ cdef class KnowledgeBase:
|
|||
return 0.0
|
||||
|
||||
|
||||
def dump(self, loc):
|
||||
self.require_vocab()
|
||||
def to_disk(self, loc):
|
||||
cdef Writer writer = Writer(loc)
|
||||
writer.write_header(self.get_size_entities(), self.entity_vector_length)
|
||||
|
||||
|
@ -370,7 +359,7 @@ cdef class KnowledgeBase:
|
|||
|
||||
writer.close()
|
||||
|
||||
cpdef load_bulk(self, loc):
|
||||
cpdef from_disk(self, loc):
|
||||
cdef hash_t entity_hash
|
||||
cdef hash_t alias_hash
|
||||
cdef int64_t entry_index
|
||||
|
@ -462,12 +451,11 @@ cdef class KnowledgeBase:
|
|||
|
||||
cdef class Writer:
|
||||
def __init__(self, object loc):
|
||||
if path.exists(loc):
|
||||
assert not path.isdir(loc), f"{loc} is directory"
|
||||
if isinstance(loc, Path):
|
||||
loc = bytes(loc)
|
||||
if path.exists(loc):
|
||||
assert not path.isdir(loc), "%s is directory." % loc
|
||||
if path.isdir(loc):
|
||||
raise ValueError(Errors.E928.format(loc=loc))
|
||||
cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
|
||||
self._fp = fopen(<char*>bytes_loc, 'wb')
|
||||
if not self._fp:
|
||||
|
@ -511,8 +499,10 @@ cdef class Reader:
|
|||
def __init__(self, object loc):
|
||||
if isinstance(loc, Path):
|
||||
loc = bytes(loc)
|
||||
assert path.exists(loc)
|
||||
assert not path.isdir(loc)
|
||||
if not path.exists(loc):
|
||||
raise ValueError(Errors.E929.format(loc=loc))
|
||||
if path.isdir(loc):
|
||||
raise ValueError(Errors.E928.format(loc=loc))
|
||||
cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
|
||||
self._fp = fopen(<char*>bytes_loc, 'rb')
|
||||
if not self._fp:
|
||||
|
|
|
@ -772,9 +772,9 @@ class Language:
|
|||
self.remove_pipe(name)
|
||||
if not len(self.pipeline) or pipe_index == len(self.pipeline):
|
||||
# we have no components to insert before/after, or we're replacing the last component
|
||||
self.add_pipe(factory_name, name=name)
|
||||
self.add_pipe(factory_name, name=name, config=config, validate=validate)
|
||||
else:
|
||||
self.add_pipe(factory_name, name=name, before=pipe_index)
|
||||
self.add_pipe(factory_name, name=name, before=pipe_index, config=config, validate=validate)
|
||||
|
||||
def rename_pipe(self, old_name: str, new_name: str) -> None:
|
||||
"""Rename a pipeline component.
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
from typing import Optional
|
||||
from typing import Optional, Callable, Iterable
|
||||
from thinc.api import chain, clone, list2ragged, reduce_mean, residual
|
||||
from thinc.api import Model, Maxout, Linear
|
||||
|
||||
from ...util import registry
|
||||
from ...kb import KnowledgeBase
|
||||
from ...kb import KnowledgeBase, Candidate, get_candidates
|
||||
from ...vocab import Vocab
|
||||
|
||||
|
||||
|
@ -25,15 +25,21 @@ def build_nel_encoder(tok2vec: Model, nO: Optional[int] = None) -> Model:
|
|||
|
||||
|
||||
@registry.assets.register("spacy.KBFromFile.v1")
|
||||
def load_kb(vocab_path: str, kb_path: str) -> KnowledgeBase:
|
||||
vocab = Vocab().from_disk(vocab_path)
|
||||
kb = KnowledgeBase(entity_vector_length=1)
|
||||
kb.initialize(vocab)
|
||||
kb.load_bulk(kb_path)
|
||||
return kb
|
||||
def load_kb(kb_path: str) -> Callable[[Vocab], KnowledgeBase]:
|
||||
def kb_from_file(vocab):
|
||||
kb = KnowledgeBase(vocab, entity_vector_length=1)
|
||||
kb.from_disk(kb_path)
|
||||
return kb
|
||||
return kb_from_file
|
||||
|
||||
|
||||
@registry.assets.register("spacy.EmptyKB.v1")
|
||||
def empty_kb(entity_vector_length: int) -> KnowledgeBase:
|
||||
kb = KnowledgeBase(entity_vector_length=entity_vector_length)
|
||||
return kb
|
||||
def empty_kb(entity_vector_length: int) -> Callable[[Vocab], KnowledgeBase]:
|
||||
def empty_kb_factory(vocab):
|
||||
return KnowledgeBase(vocab=vocab, entity_vector_length=entity_vector_length)
|
||||
return empty_kb_factory
|
||||
|
||||
|
||||
@registry.assets.register("spacy.CandidateGenerator.v1")
|
||||
def create_candidates() -> Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]:
|
||||
return get_candidates
|
||||
|
|
|
@ -6,7 +6,7 @@ from thinc.api import CosineDistance, get_array_module, Model, Optimizer, Config
|
|||
from thinc.api import set_dropout_rate
|
||||
import warnings
|
||||
|
||||
from ..kb import KnowledgeBase
|
||||
from ..kb import KnowledgeBase, Candidate
|
||||
from ..tokens import Doc
|
||||
from .pipe import Pipe, deserialize_config
|
||||
from ..language import Language
|
||||
|
@ -32,35 +32,30 @@ subword_features = true
|
|||
"""
|
||||
DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
|
||||
|
||||
default_kb_config = """
|
||||
[kb]
|
||||
@assets = "spacy.EmptyKB.v1"
|
||||
entity_vector_length = 64
|
||||
"""
|
||||
DEFAULT_NEL_KB = Config().from_str(default_kb_config)["kb"]
|
||||
|
||||
|
||||
@Language.factory(
|
||||
"entity_linker",
|
||||
requires=["doc.ents", "doc.sents", "token.ent_iob", "token.ent_type"],
|
||||
assigns=["token.ent_kb_id"],
|
||||
default_config={
|
||||
"kb": DEFAULT_NEL_KB,
|
||||
"kb_loader": {"@assets": "spacy.EmptyKB.v1", "entity_vector_length": 64},
|
||||
"model": DEFAULT_NEL_MODEL,
|
||||
"labels_discard": [],
|
||||
"incl_prior": True,
|
||||
"incl_context": True,
|
||||
"get_candidates": {"@assets": "spacy.CandidateGenerator.v1"},
|
||||
},
|
||||
)
|
||||
def make_entity_linker(
|
||||
nlp: Language,
|
||||
name: str,
|
||||
model: Model,
|
||||
kb: KnowledgeBase,
|
||||
kb_loader: Callable[[Vocab], KnowledgeBase],
|
||||
*,
|
||||
labels_discard: Iterable[str],
|
||||
incl_prior: bool,
|
||||
incl_context: bool,
|
||||
get_candidates: Callable[[KnowledgeBase, "Span"], Iterable[Candidate]],
|
||||
):
|
||||
"""Construct an EntityLinker component.
|
||||
|
||||
|
@ -76,10 +71,11 @@ def make_entity_linker(
|
|||
nlp.vocab,
|
||||
model,
|
||||
name,
|
||||
kb=kb,
|
||||
kb_loader=kb_loader,
|
||||
labels_discard=labels_discard,
|
||||
incl_prior=incl_prior,
|
||||
incl_context=incl_context,
|
||||
get_candidates=get_candidates,
|
||||
)
|
||||
|
||||
|
||||
|
@ -97,10 +93,11 @@ class EntityLinker(Pipe):
|
|||
model: Model,
|
||||
name: str = "entity_linker",
|
||||
*,
|
||||
kb: KnowledgeBase,
|
||||
kb_loader: Callable[[Vocab], KnowledgeBase],
|
||||
labels_discard: Iterable[str],
|
||||
incl_prior: bool,
|
||||
incl_context: bool,
|
||||
get_candidates: Callable[[KnowledgeBase, "Span"], Iterable[Candidate]],
|
||||
) -> None:
|
||||
"""Initialize an entity linker.
|
||||
|
||||
|
@ -108,7 +105,7 @@ class EntityLinker(Pipe):
|
|||
model (thinc.api.Model): The Thinc Model powering the pipeline component.
|
||||
name (str): The component instance name, used to add entries to the
|
||||
losses during training.
|
||||
kb (KnowledgeBase): The KnowledgeBase holding all entities and their aliases.
|
||||
kb_loader (Callable[[Vocab], KnowledgeBase]): A function that creates a KnowledgeBase from a Vocab instance.
|
||||
labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction.
|
||||
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
|
||||
incl_context (bool): Whether or not to include the local context in the model.
|
||||
|
@ -119,17 +116,12 @@ class EntityLinker(Pipe):
|
|||
self.model = model
|
||||
self.name = name
|
||||
cfg = {
|
||||
"kb": kb,
|
||||
"labels_discard": list(labels_discard),
|
||||
"incl_prior": incl_prior,
|
||||
"incl_context": incl_context,
|
||||
}
|
||||
if not isinstance(kb, KnowledgeBase):
|
||||
raise ValueError(Errors.E990.format(type=type(self.kb)))
|
||||
kb.initialize(vocab)
|
||||
self.kb = kb
|
||||
if "kb" in cfg:
|
||||
del cfg["kb"] # we don't want to duplicate its serialization
|
||||
self.kb = kb_loader(self.vocab)
|
||||
self.get_candidates = get_candidates
|
||||
self.cfg = dict(cfg)
|
||||
self.distance = CosineDistance(normalize=False)
|
||||
# how many neighbour sentences to take into account
|
||||
|
@ -326,10 +318,11 @@ class EntityLinker(Pipe):
|
|||
end_token = sentences[end_sentence].end
|
||||
sent_doc = doc[start_token:end_token].as_doc()
|
||||
# currently, the context is the same for each entity in a sentence (should be refined)
|
||||
sentence_encoding = self.model.predict([sent_doc])[0]
|
||||
xp = get_array_module(sentence_encoding)
|
||||
sentence_encoding_t = sentence_encoding.T
|
||||
sentence_norm = xp.linalg.norm(sentence_encoding_t)
|
||||
xp = self.model.ops.xp
|
||||
if self.cfg.get("incl_context"):
|
||||
sentence_encoding = self.model.predict([sent_doc])[0]
|
||||
sentence_encoding_t = sentence_encoding.T
|
||||
sentence_norm = xp.linalg.norm(sentence_encoding_t)
|
||||
for ent in sent.ents:
|
||||
entity_count += 1
|
||||
to_discard = self.cfg.get("labels_discard", [])
|
||||
|
@ -337,7 +330,7 @@ class EntityLinker(Pipe):
|
|||
# ignoring this entity - setting to NIL
|
||||
final_kb_ids.append(self.NIL)
|
||||
else:
|
||||
candidates = self.kb.get_candidates(ent.text)
|
||||
candidates = self.get_candidates(self.kb, ent)
|
||||
if not candidates:
|
||||
# no prediction possible for this entity - setting to NIL
|
||||
final_kb_ids.append(self.NIL)
|
||||
|
@ -421,10 +414,9 @@ class EntityLinker(Pipe):
|
|||
DOCS: https://spacy.io/api/entitylinker#to_disk
|
||||
"""
|
||||
serialize = {}
|
||||
self.cfg["entity_width"] = self.kb.entity_vector_length
|
||||
serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
|
||||
serialize["vocab"] = lambda p: self.vocab.to_disk(p)
|
||||
serialize["kb"] = lambda p: self.kb.dump(p)
|
||||
serialize["kb"] = lambda p: self.kb.to_disk(p)
|
||||
serialize["model"] = lambda p: self.model.to_disk(p)
|
||||
util.to_disk(path, serialize, exclude)
|
||||
|
||||
|
@ -446,15 +438,10 @@ class EntityLinker(Pipe):
|
|||
except AttributeError:
|
||||
raise ValueError(Errors.E149) from None
|
||||
|
||||
def load_kb(p):
|
||||
self.kb = KnowledgeBase(entity_vector_length=self.cfg["entity_width"])
|
||||
self.kb.initialize(self.vocab)
|
||||
self.kb.load_bulk(p)
|
||||
|
||||
deserialize = {}
|
||||
deserialize["vocab"] = lambda p: self.vocab.from_disk(p)
|
||||
deserialize["cfg"] = lambda p: self.cfg.update(deserialize_config(p))
|
||||
deserialize["kb"] = load_kb
|
||||
deserialize["kb"] = lambda p: self.kb.from_disk(p)
|
||||
deserialize["model"] = load_model
|
||||
util.from_disk(path, deserialize, exclude)
|
||||
return self
|
||||
|
|
|
@ -68,7 +68,6 @@ class Tagger(Pipe):
|
|||
name (str): The component instance name, used to add entries to the
|
||||
losses during training.
|
||||
labels (List): The set of labels. Defaults to None.
|
||||
set_morphology (bool): Whether to set morphological features.
|
||||
|
||||
DOCS: https://spacy.io/api/tagger#init
|
||||
"""
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
from typing import Callable, Iterable
|
||||
import pytest
|
||||
|
||||
from spacy.kb import KnowledgeBase
|
||||
from spacy.kb import KnowledgeBase, get_candidates, Candidate
|
||||
|
||||
from spacy import util, registry
|
||||
from spacy.gold import Example
|
||||
|
@ -21,8 +22,7 @@ def assert_almost_equal(a, b):
|
|||
|
||||
def test_kb_valid_entities(nlp):
|
||||
"""Test the valid construction of a KB with 3 entities and two aliases"""
|
||||
mykb = KnowledgeBase(entity_vector_length=3)
|
||||
mykb.initialize(nlp.vocab)
|
||||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
|
||||
|
||||
# adding entities
|
||||
mykb.add_entity(entity="Q1", freq=19, entity_vector=[8, 4, 3])
|
||||
|
@ -51,8 +51,7 @@ def test_kb_valid_entities(nlp):
|
|||
|
||||
def test_kb_invalid_entities(nlp):
|
||||
"""Test the invalid construction of a KB with an alias linked to a non-existing entity"""
|
||||
mykb = KnowledgeBase(entity_vector_length=1)
|
||||
mykb.initialize(nlp.vocab)
|
||||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||
|
||||
# adding entities
|
||||
mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
|
||||
|
@ -68,8 +67,7 @@ def test_kb_invalid_entities(nlp):
|
|||
|
||||
def test_kb_invalid_probabilities(nlp):
|
||||
"""Test the invalid construction of a KB with wrong prior probabilities"""
|
||||
mykb = KnowledgeBase(entity_vector_length=1)
|
||||
mykb.initialize(nlp.vocab)
|
||||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||
|
||||
# adding entities
|
||||
mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
|
||||
|
@ -83,8 +81,7 @@ def test_kb_invalid_probabilities(nlp):
|
|||
|
||||
def test_kb_invalid_combination(nlp):
|
||||
"""Test the invalid construction of a KB with non-matching entity and probability lists"""
|
||||
mykb = KnowledgeBase(entity_vector_length=1)
|
||||
mykb.initialize(nlp.vocab)
|
||||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||
|
||||
# adding entities
|
||||
mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
|
||||
|
@ -100,8 +97,7 @@ def test_kb_invalid_combination(nlp):
|
|||
|
||||
def test_kb_invalid_entity_vector(nlp):
|
||||
"""Test the invalid construction of a KB with non-matching entity vector lengths"""
|
||||
mykb = KnowledgeBase(entity_vector_length=3)
|
||||
mykb.initialize(nlp.vocab)
|
||||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
|
||||
|
||||
# adding entities
|
||||
mykb.add_entity(entity="Q1", freq=19, entity_vector=[1, 2, 3])
|
||||
|
@ -117,14 +113,14 @@ def test_kb_default(nlp):
|
|||
assert len(entity_linker.kb) == 0
|
||||
assert entity_linker.kb.get_size_entities() == 0
|
||||
assert entity_linker.kb.get_size_aliases() == 0
|
||||
# default value from pipeline.entity_linker
|
||||
# 64 is the default value from pipeline.entity_linker
|
||||
assert entity_linker.kb.entity_vector_length == 64
|
||||
|
||||
|
||||
def test_kb_custom_length(nlp):
|
||||
"""Test that the default (empty) KB can be configured with a custom entity length"""
|
||||
entity_linker = nlp.add_pipe(
|
||||
"entity_linker", config={"kb": {"entity_vector_length": 35}}
|
||||
"entity_linker", config={"kb_loader": {"entity_vector_length": 35}}
|
||||
)
|
||||
assert len(entity_linker.kb) == 0
|
||||
assert entity_linker.kb.get_size_entities() == 0
|
||||
|
@ -141,7 +137,7 @@ def test_kb_undefined(nlp):
|
|||
|
||||
def test_kb_empty(nlp):
|
||||
"""Test that the EL can't train with an empty KB"""
|
||||
config = {"kb": {"@assets": "spacy.EmptyKB.v1", "entity_vector_length": 342}}
|
||||
config = {"kb_loader": {"@assets": "spacy.EmptyKB.v1", "entity_vector_length": 342}}
|
||||
entity_linker = nlp.add_pipe("entity_linker", config=config)
|
||||
assert len(entity_linker.kb) == 0
|
||||
with pytest.raises(ValueError):
|
||||
|
@ -150,8 +146,13 @@ def test_kb_empty(nlp):
|
|||
|
||||
def test_candidate_generation(nlp):
|
||||
"""Test correct candidate generation"""
|
||||
mykb = KnowledgeBase(entity_vector_length=1)
|
||||
mykb.initialize(nlp.vocab)
|
||||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||
doc = nlp("douglas adam Adam shrubbery")
|
||||
|
||||
douglas_ent = doc[0:1]
|
||||
adam_ent = doc[1:2]
|
||||
Adam_ent = doc[2:3]
|
||||
shrubbery_ent = doc[3:4]
|
||||
|
||||
# adding entities
|
||||
mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
|
||||
|
@ -163,21 +164,76 @@ def test_candidate_generation(nlp):
|
|||
mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
|
||||
|
||||
# test the size of the relevant candidates
|
||||
assert len(mykb.get_candidates("douglas")) == 2
|
||||
assert len(mykb.get_candidates("adam")) == 1
|
||||
assert len(mykb.get_candidates("shrubbery")) == 0
|
||||
assert len(get_candidates(mykb, douglas_ent)) == 2
|
||||
assert len(get_candidates(mykb, adam_ent)) == 1
|
||||
assert len(get_candidates(mykb, Adam_ent)) == 0 # default case sensitive
|
||||
assert len(get_candidates(mykb, shrubbery_ent)) == 0
|
||||
|
||||
# test the content of the candidates
|
||||
assert mykb.get_candidates("adam")[0].entity_ == "Q2"
|
||||
assert mykb.get_candidates("adam")[0].alias_ == "adam"
|
||||
assert_almost_equal(mykb.get_candidates("adam")[0].entity_freq, 12)
|
||||
assert_almost_equal(mykb.get_candidates("adam")[0].prior_prob, 0.9)
|
||||
assert get_candidates(mykb, adam_ent)[0].entity_ == "Q2"
|
||||
assert get_candidates(mykb, adam_ent)[0].alias_ == "adam"
|
||||
assert_almost_equal(get_candidates(mykb, adam_ent)[0].entity_freq, 12)
|
||||
assert_almost_equal(get_candidates(mykb, adam_ent)[0].prior_prob, 0.9)
|
||||
|
||||
|
||||
def test_el_pipe_configuration(nlp):
|
||||
"""Test correct candidate generation as part of the EL pipe"""
|
||||
nlp.add_pipe("sentencizer")
|
||||
pattern = {"label": "PERSON", "pattern": [{"LOWER": "douglas"}]}
|
||||
ruler = nlp.add_pipe("entity_ruler")
|
||||
ruler.add_patterns([pattern])
|
||||
|
||||
@registry.assets.register("myAdamKB.v1")
|
||||
def mykb() -> Callable[["Vocab"], KnowledgeBase]:
|
||||
def create_kb(vocab):
|
||||
kb = KnowledgeBase(vocab, entity_vector_length=1)
|
||||
kb.add_entity(entity="Q2", freq=12, entity_vector=[2])
|
||||
kb.add_entity(entity="Q3", freq=5, entity_vector=[3])
|
||||
kb.add_alias(
|
||||
alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1]
|
||||
)
|
||||
return kb
|
||||
|
||||
return create_kb
|
||||
|
||||
# run an EL pipe without a trained context encoder, to check the candidate generation step only
|
||||
nlp.add_pipe(
|
||||
"entity_linker",
|
||||
config={"kb_loader": {"@assets": "myAdamKB.v1"}, "incl_context": False},
|
||||
)
|
||||
# With the default get_candidates function, matching is case-sensitive
|
||||
text = "Douglas and douglas are not the same."
|
||||
doc = nlp(text)
|
||||
assert doc[0].ent_kb_id_ == "NIL"
|
||||
assert doc[1].ent_kb_id_ == ""
|
||||
assert doc[2].ent_kb_id_ == "Q2"
|
||||
|
||||
def get_lowercased_candidates(kb, span):
|
||||
return kb.get_alias_candidates(span.text.lower())
|
||||
|
||||
@registry.assets.register("spacy.LowercaseCandidateGenerator.v1")
|
||||
def create_candidates() -> Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]:
|
||||
return get_lowercased_candidates
|
||||
|
||||
# replace the pipe with a new one with a different candidate generator
|
||||
nlp.replace_pipe(
|
||||
"entity_linker",
|
||||
"entity_linker",
|
||||
config={
|
||||
"kb_loader": {"@assets": "myAdamKB.v1"},
|
||||
"incl_context": False,
|
||||
"get_candidates": {"@assets": "spacy.LowercaseCandidateGenerator.v1"},
|
||||
},
|
||||
)
|
||||
doc = nlp(text)
|
||||
assert doc[0].ent_kb_id_ == "Q2"
|
||||
assert doc[1].ent_kb_id_ == ""
|
||||
assert doc[2].ent_kb_id_ == "Q2"
|
||||
|
||||
|
||||
def test_append_alias(nlp):
|
||||
"""Test that we can append additional alias-entity pairs"""
|
||||
mykb = KnowledgeBase(entity_vector_length=1)
|
||||
mykb.initialize(nlp.vocab)
|
||||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||
|
||||
# adding entities
|
||||
mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
|
||||
|
@ -189,26 +245,25 @@ def test_append_alias(nlp):
|
|||
mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
|
||||
|
||||
# test the size of the relevant candidates
|
||||
assert len(mykb.get_candidates("douglas")) == 2
|
||||
assert len(mykb.get_alias_candidates("douglas")) == 2
|
||||
|
||||
# append an alias
|
||||
mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.2)
|
||||
|
||||
# test the size of the relevant candidates has been incremented
|
||||
assert len(mykb.get_candidates("douglas")) == 3
|
||||
assert len(mykb.get_alias_candidates("douglas")) == 3
|
||||
|
||||
# appending the same alias-entity pair again should not work (will throw a warning)
|
||||
with pytest.warns(UserWarning):
|
||||
mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.3)
|
||||
|
||||
# test the size of the relevant candidates remained unchanged
|
||||
assert len(mykb.get_candidates("douglas")) == 3
|
||||
assert len(mykb.get_alias_candidates("douglas")) == 3
|
||||
|
||||
|
||||
def test_append_invalid_alias(nlp):
|
||||
"""Test that appending an alias will throw an error if the prior probs exceed 1"""
|
||||
mykb = KnowledgeBase(entity_vector_length=1)
|
||||
mykb.initialize(nlp.vocab)
|
||||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||
|
||||
# adding entities
|
||||
mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
|
||||
|
@ -228,16 +283,18 @@ def test_preserving_links_asdoc(nlp):
|
|||
"""Test that Span.as_doc preserves the existing entity links"""
|
||||
|
||||
@registry.assets.register("myLocationsKB.v1")
|
||||
def dummy_kb() -> KnowledgeBase:
|
||||
mykb = KnowledgeBase(entity_vector_length=1)
|
||||
mykb.initialize(nlp.vocab)
|
||||
# adding entities
|
||||
mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
|
||||
mykb.add_entity(entity="Q2", freq=8, entity_vector=[1])
|
||||
# adding aliases
|
||||
mykb.add_alias(alias="Boston", entities=["Q1"], probabilities=[0.7])
|
||||
mykb.add_alias(alias="Denver", entities=["Q2"], probabilities=[0.6])
|
||||
return mykb
|
||||
def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]:
|
||||
def create_kb(vocab):
|
||||
mykb = KnowledgeBase(vocab, entity_vector_length=1)
|
||||
# adding entities
|
||||
mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
|
||||
mykb.add_entity(entity="Q2", freq=8, entity_vector=[1])
|
||||
# adding aliases
|
||||
mykb.add_alias(alias="Boston", entities=["Q1"], probabilities=[0.7])
|
||||
mykb.add_alias(alias="Denver", entities=["Q2"], probabilities=[0.6])
|
||||
return mykb
|
||||
|
||||
return create_kb
|
||||
|
||||
# set up pipeline with NER (Entity Ruler) and NEL (prior probability only, model not trained)
|
||||
nlp.add_pipe("sentencizer")
|
||||
|
@ -247,7 +304,7 @@ def test_preserving_links_asdoc(nlp):
|
|||
]
|
||||
ruler = nlp.add_pipe("entity_ruler")
|
||||
ruler.add_patterns(patterns)
|
||||
el_config = {"kb": {"@assets": "myLocationsKB.v1"}, "incl_prior": False}
|
||||
el_config = {"kb_loader": {"@assets": "myLocationsKB.v1"}, "incl_prior": False}
|
||||
el_pipe = nlp.add_pipe("entity_linker", config=el_config, last=True)
|
||||
el_pipe.begin_training(lambda: [])
|
||||
el_pipe.incl_context = False
|
||||
|
@ -331,24 +388,28 @@ def test_overfitting_IO():
|
|||
train_examples.append(Example.from_dict(doc, annotation))
|
||||
|
||||
@registry.assets.register("myOverfittingKB.v1")
|
||||
def dummy_kb() -> KnowledgeBase:
|
||||
# create artificial KB - assign same prior weight to the two russ cochran's
|
||||
# Q2146908 (Russ Cochran): American golfer
|
||||
# Q7381115 (Russ Cochran): publisher
|
||||
mykb = KnowledgeBase(entity_vector_length=3)
|
||||
mykb.initialize(nlp.vocab)
|
||||
mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
|
||||
mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7])
|
||||
mykb.add_alias(
|
||||
alias="Russ Cochran",
|
||||
entities=["Q2146908", "Q7381115"],
|
||||
probabilities=[0.5, 0.5],
|
||||
)
|
||||
return mykb
|
||||
def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]:
|
||||
def create_kb(vocab):
|
||||
# create artificial KB - assign same prior weight to the two russ cochran's
|
||||
# Q2146908 (Russ Cochran): American golfer
|
||||
# Q7381115 (Russ Cochran): publisher
|
||||
mykb = KnowledgeBase(vocab, entity_vector_length=3)
|
||||
mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
|
||||
mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7])
|
||||
mykb.add_alias(
|
||||
alias="Russ Cochran",
|
||||
entities=["Q2146908", "Q7381115"],
|
||||
probabilities=[0.5, 0.5],
|
||||
)
|
||||
return mykb
|
||||
|
||||
return create_kb
|
||||
|
||||
# Create the Entity Linker component and add it to the pipeline
|
||||
nlp.add_pipe(
|
||||
"entity_linker", config={"kb": {"@assets": "myOverfittingKB.v1"}}, last=True
|
||||
"entity_linker",
|
||||
config={"kb_loader": {"@assets": "myOverfittingKB.v1"}},
|
||||
last=True,
|
||||
)
|
||||
|
||||
# train the NEL pipe
|
||||
|
|
|
@ -78,6 +78,14 @@ def test_replace_last_pipe(nlp):
|
|||
assert nlp.pipe_names == ["sentencizer", "ner"]
|
||||
|
||||
|
||||
def test_replace_pipe_config(nlp):
    """Check that ``replace_pipe`` applies the config given for the new pipe.

    An entity linker is first added without any config and is asserted to
    have ``incl_prior`` enabled. It is then replaced, under the same name,
    by an entity linker configured with ``incl_prior=False``; the component
    retrieved afterwards must expose the overridden value.
    """
    nlp.add_pipe("entity_linker")
    nlp.add_pipe("sentencizer")
    # Identity comparison: the setting must be the actual boolean, not merely
    # something that compares equal to it.
    assert nlp.get_pipe("entity_linker").cfg["incl_prior"] is True
    nlp.replace_pipe("entity_linker", "entity_linker", config={"incl_prior": False})
    assert nlp.get_pipe("entity_linker").cfg["incl_prior"] is False
|
||||
|
||||
|
||||
@pytest.mark.parametrize("old_name,new_name", [("old_pipe", "new_pipe")])
|
||||
def test_rename_pipe(nlp, old_name, new_name):
|
||||
with pytest.raises(ValueError):
|
||||
|
|
|
@ -139,8 +139,7 @@ def test_issue4665():
|
|||
def test_issue4674():
|
||||
"""Test that setting entities with overlapping identifiers does not mess up IO"""
|
||||
nlp = English()
|
||||
kb = KnowledgeBase(entity_vector_length=3)
|
||||
kb.initialize(nlp.vocab)
|
||||
kb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
|
||||
vector1 = [0.9, 1.1, 1.01]
|
||||
vector2 = [1.8, 2.25, 2.01]
|
||||
with pytest.warns(UserWarning):
|
||||
|
@ -156,10 +155,9 @@ def test_issue4674():
|
|||
if not dir_path.exists():
|
||||
dir_path.mkdir()
|
||||
file_path = dir_path / "kb"
|
||||
kb.dump(str(file_path))
|
||||
kb2 = KnowledgeBase(entity_vector_length=3)
|
||||
kb2.initialize(nlp.vocab)
|
||||
kb2.load_bulk(str(file_path))
|
||||
kb.to_disk(str(file_path))
|
||||
kb2 = KnowledgeBase(nlp.vocab, entity_vector_length=3)
|
||||
kb2.from_disk(str(file_path))
|
||||
assert kb2.get_size_entities() == 1
|
||||
|
||||
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
from typing import Callable
|
||||
import warnings
|
||||
from unittest import TestCase
|
||||
import pytest
|
||||
|
@ -70,13 +71,14 @@ def entity_linker():
|
|||
nlp = Language()
|
||||
|
||||
@registry.assets.register("TestIssue5230KB.v1")
|
||||
def dummy_kb() -> KnowledgeBase:
|
||||
kb = KnowledgeBase(entity_vector_length=1)
|
||||
kb.initialize(nlp.vocab)
|
||||
kb.add_entity("test", 0.0, zeros((1, 1), dtype="f"))
|
||||
return kb
|
||||
def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]:
|
||||
def create_kb(vocab):
|
||||
kb = KnowledgeBase(vocab, entity_vector_length=1)
|
||||
kb.add_entity("test", 0.0, zeros((1, 1), dtype="f"))
|
||||
return kb
|
||||
return create_kb
|
||||
|
||||
config = {"kb": {"@assets": "TestIssue5230KB.v1"}}
|
||||
config = {"kb_loader": {"@assets": "TestIssue5230KB.v1"}}
|
||||
entity_linker = nlp.add_pipe("entity_linker", config=config)
|
||||
# need to add model for two reasons:
|
||||
# 1. no model leads to error in serialization,
|
||||
|
@ -121,19 +123,17 @@ def test_writer_with_path_py35():
|
|||
|
||||
def test_save_and_load_knowledge_base():
|
||||
nlp = Language()
|
||||
kb = KnowledgeBase(entity_vector_length=1)
|
||||
kb.initialize(nlp.vocab)
|
||||
kb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||
with make_tempdir() as d:
|
||||
path = d / "kb"
|
||||
try:
|
||||
kb.dump(path)
|
||||
kb.to_disk(path)
|
||||
except Exception as e:
|
||||
pytest.fail(str(e))
|
||||
|
||||
try:
|
||||
kb_loaded = KnowledgeBase(entity_vector_length=1)
|
||||
kb_loaded.initialize(nlp.vocab)
|
||||
kb_loaded.load_bulk(path)
|
||||
kb_loaded = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||
kb_loaded.from_disk(path)
|
||||
except Exception as e:
|
||||
pytest.fail(str(e))
|
||||
|
||||
|
|
|
@ -1,4 +1,8 @@
|
|||
from spacy.util import ensure_path
|
||||
from typing import Callable
|
||||
|
||||
from spacy import util
|
||||
from spacy.lang.en import English
|
||||
from spacy.util import ensure_path, registry
|
||||
from spacy.kb import KnowledgeBase
|
||||
|
||||
from ..util import make_tempdir
|
||||
|
@ -15,20 +19,16 @@ def test_serialize_kb_disk(en_vocab):
|
|||
if not dir_path.exists():
|
||||
dir_path.mkdir()
|
||||
file_path = dir_path / "kb"
|
||||
kb1.dump(str(file_path))
|
||||
|
||||
kb2 = KnowledgeBase(entity_vector_length=3)
|
||||
kb2.initialize(en_vocab)
|
||||
kb2.load_bulk(str(file_path))
|
||||
kb1.to_disk(str(file_path))
|
||||
kb2 = KnowledgeBase(vocab=en_vocab, entity_vector_length=3)
|
||||
kb2.from_disk(str(file_path))
|
||||
|
||||
# final assertions
|
||||
_check_kb(kb2)
|
||||
|
||||
|
||||
def _get_dummy_kb(vocab):
|
||||
kb = KnowledgeBase(entity_vector_length=3)
|
||||
kb.initialize(vocab)
|
||||
|
||||
kb = KnowledgeBase(vocab, entity_vector_length=3)
|
||||
kb.add_entity(entity="Q53", freq=33, entity_vector=[0, 5, 3])
|
||||
kb.add_entity(entity="Q17", freq=2, entity_vector=[7, 1, 0])
|
||||
kb.add_entity(entity="Q007", freq=7, entity_vector=[0, 0, 7])
|
||||
|
@ -61,7 +61,7 @@ def _check_kb(kb):
|
|||
assert alias_string not in kb.get_alias_strings()
|
||||
|
||||
# check candidates & probabilities
|
||||
candidates = sorted(kb.get_candidates("double07"), key=lambda x: x.entity_)
|
||||
candidates = sorted(kb.get_alias_candidates("double07"), key=lambda x: x.entity_)
|
||||
assert len(candidates) == 2
|
||||
|
||||
assert candidates[0].entity_ == "Q007"
|
||||
|
@ -75,3 +75,47 @@ def _check_kb(kb):
|
|||
assert candidates[1].entity_vector == [7, 1, 0]
|
||||
assert candidates[1].alias_ == "double07"
|
||||
assert 0.099 < candidates[1].prior_prob < 0.101
|
||||
|
||||
|
||||
def test_serialize_subclassed_kb():
    """Check that IO of a custom KB works fine as part of an EL pipe.

    A ``KnowledgeBase`` subclass with one extra attribute is registered as an
    asset and plugged into the entity linker through the ``kb_loader`` config.
    The same checks are run on the freshly created pipe and on the pipe of a
    pipeline that was written to disk and loaded back.
    """
    expected_length = 342
    expected_field = 666

    class SubKnowledgeBase(KnowledgeBase):
        def __init__(self, vocab, entity_vector_length, custom_field):
            super().__init__(vocab, entity_vector_length)
            self.custom_field = custom_field

    @registry.assets.register("spacy.CustomKB.v1")
    def custom_kb(
        entity_vector_length: int, custom_field: int
    ) -> Callable[["Vocab"], KnowledgeBase]:
        # The registered function returns a loader; the KB itself is only
        # built once the loader is called with a vocab.
        def custom_kb_factory(vocab):
            return SubKnowledgeBase(
                vocab=vocab,
                entity_vector_length=entity_vector_length,
                custom_field=custom_field,
            )

        return custom_kb_factory

    def check_kb(kb):
        # Exact type identity on purpose: an instance of the plain base class
        # must not satisfy this check.
        assert type(kb) is SubKnowledgeBase
        assert kb.entity_vector_length == expected_length
        assert kb.custom_field == expected_field

    nlp = English()
    config = {
        "kb_loader": {
            "@assets": "spacy.CustomKB.v1",
            "entity_vector_length": expected_length,
            "custom_field": expected_field,
        }
    }
    entity_linker = nlp.add_pipe("entity_linker", config=config)
    check_kb(entity_linker.kb)

    # Make sure the custom KB is serialized correctly
    with make_tempdir() as tmp_dir:
        nlp.to_disk(tmp_dir)
        nlp2 = util.load_model_from_path(tmp_dir)
        check_kb(nlp2.get_pipe("entity_linker").kb)
|
||||
|
|
|
@ -200,21 +200,21 @@ probability of the fact that the mention links to the entity ID.
|
|||
| `alias` | The textual mention or alias. ~~str~~ |
|
||||
| **RETURNS** | The prior probability of the `alias` referring to the `entity`. ~~float~~ |
|
||||
|
||||
## KnowledgeBase.dump {#dump tag="method"}
|
||||
## KnowledgeBase.to_disk {#to_disk tag="method"}
|
||||
|
||||
Save the current state of the knowledge base to a directory.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> kb.dump(loc)
|
||||
> kb.to_disk(loc)
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
| ----- | ------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| `loc` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. ~~Union[str, Path]~~ |
|
||||
|
||||
## KnowledgeBase.load_bulk {#load_bulk tag="method"}
|
||||
## KnowledgeBase.from_disk {#from_disk tag="method"}
|
||||
|
||||
Restore the state of the knowledge base from a given directory. Note that the
|
||||
[`Vocab`](/api/vocab) should also be the same as the one used to create the KB.
|
||||
|
@ -226,7 +226,7 @@ Restore the state of the knowledge base from a given directory. Note that the
|
|||
> from spacy.vocab import Vocab
|
||||
> vocab = Vocab().from_disk("/path/to/vocab")
|
||||
> kb = KnowledgeBase(vocab=vocab, entity_vector_length=64)
|
||||
> kb.load_bulk("/path/to/kb")
|
||||
> kb.from_disk("/path/to/kb")
|
||||
> ```
|
||||
|
||||
| Name | Description |
|
||||
|
|
Loading…
Reference in New Issue
Block a user