Define candidate generator in EL config (#5876)

* candidate generator as separate part of EL config

* update comment

* ent instead of str as input for candidate generation

* Span instead of str: correct type indication

* fix types

* unit test to create new candidate generator

* fix replace_pipe argument passing

* move error message, general cleanup

* add vocab back to KB constructor

* provide KB as callable from Vocab arg

* rename to kb_loader, fix KB serialization as part of the EL pipe

* fix typo

* reformatting

* cleanup

* fix comment

* fix wrongly duplicated code from merge conflict

* rename dump to to_disk

* from_disk instead of load_bulk

* update test after recent removal of set_morphology in tagger

* remove old doc
Sofie Van Landeghem 2020-08-18 16:10:36 +02:00 committed by GitHub
parent 688e77562b
commit 358cbb21e3
17 changed files with 272 additions and 180 deletions
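Taken together, the diff below turns the KB into a config-driven part of the entity linker: the pipe now receives a registered `kb_loader` callable that builds the `KnowledgeBase` from the shared `Vocab`, and candidate generation becomes a swappable `get_candidates` function. A minimal usage sketch of the resulting API, spelling out the default config values from this diff:

```python
import spacy

nlp = spacy.blank("en")
# explicit defaults from this commit: an empty KB created from the shared
# vocab, plus the built-in candidate generator, both resolved as assets
entity_linker = nlp.add_pipe(
    "entity_linker",
    config={
        "kb_loader": {"@assets": "spacy.EmptyKB.v1", "entity_vector_length": 64},
        "get_candidates": {"@assets": "spacy.CandidateGenerator.v1"},
    },
)
assert entity_linker.kb.entity_vector_length == 64
```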

View File

@@ -15,7 +15,8 @@ import spacy.util
 from bin.ud import conll17_ud_eval
 from spacy.tokens import Token, Doc
 from spacy.gold import Example
-from spacy.util import compounding, minibatch, minibatch_by_words
+from spacy.util import compounding, minibatch
+from spacy.gold.batchers import minibatch_by_words
 from spacy.pipeline._parser_internals.nonproj import projectivize
 from spacy.matcher import Matcher
 from spacy import displacy

View File

@@ -48,8 +48,7 @@ def main(model, output_dir=None):
     # You can change the dimension of vectors in your KB by using an encoder that changes the dimensionality.
     # For simplicity, we'll just use the original vector dimension here instead.
     vectors_dim = nlp.vocab.vectors.shape[1]
-    kb = KnowledgeBase(entity_vector_length=vectors_dim)
-    kb.initialize(nlp.vocab)
+    kb = KnowledgeBase(nlp.vocab, entity_vector_length=vectors_dim)

     # set up the data
     entity_ids = []
@@ -81,7 +80,7 @@
         if not output_dir.exists():
             output_dir.mkdir()
         kb_path = str(output_dir / "kb")
-        kb.dump(kb_path)
+        kb.to_disk(kb_path)
         print()
         print("Saved KB to", kb_path)
@@ -96,9 +95,8 @@
         print("Loading vocab from", vocab_path)
         print("Loading KB from", kb_path)
         vocab2 = Vocab().from_disk(vocab_path)
-        kb2 = KnowledgeBase(entity_vector_length=1)
-        kb.initialize(vocab2)
-        kb2.load_bulk(kb_path)
+        kb2 = KnowledgeBase(vocab2, entity_vector_length=1)
+        kb2.from_disk(kb_path)
         print()
         _print_kb(kb2)
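As a quick reference for the renamed API used throughout this example: the `Vocab` is now a constructor argument and serialization follows spaCy's usual `to_disk`/`from_disk` naming. A round-trip sketch, with illustrative entity values and path:

```python
from spacy.kb import KnowledgeBase
from spacy.vocab import Vocab

vocab = Vocab()
kb = KnowledgeBase(vocab, entity_vector_length=3)  # vocab arg replaces kb.initialize()
kb.add_entity(entity="Q42", freq=12, entity_vector=[1, 2, 3])
kb.add_alias(alias="Douglas", entities=["Q42"], probabilities=[0.8])
kb.to_disk("/tmp/my_kb")  # formerly kb.dump()

kb2 = KnowledgeBase(vocab, entity_vector_length=3)
kb2.from_disk("/tmp/my_kb")  # formerly kb.load_bulk()
assert kb2.get_size_entities() == 1
```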

View File

@@ -83,7 +83,7 @@ def main(kb_path, vocab_path, output_dir=None, n_iter=50):
     if "entity_linker" not in nlp.pipe_names:
         print("Loading Knowledge Base from '%s'" % kb_path)
         cfg = {
-            "kb": {
+            "kb_loader": {
                 "@assets": "spacy.KBFromFile.v1",
                 "vocab_path": vocab_path,
                 "kb_path": kb_path,

View File

@@ -477,6 +477,10 @@ class Errors:
     E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")

     # TODO: fix numbering after merging develop into master
+    E928 = ("A 'KnowledgeBase' should be written to / read from a file, but the "
+            "provided argument {loc} is an existing directory.")
+    E929 = ("A 'KnowledgeBase' could not be read from {loc} - the path does "
+            "not seem to exist.")
     E930 = ("Received invalid get_examples callback in {name}.begin_training. "
             "Expected function that returns an iterable of Example objects but "
             "got: {obj}")
@@ -504,8 +508,6 @@
             "not found in pipeline. Available components: {opts}")
     E945 = ("Can't copy pipeline component '{name}' from source. Expected loaded "
             "nlp object, but got: {source}")
-    E946 = ("The Vocab for the knowledge base is not initialized. Did you forget to "
-            "call kb.initialize()?")
     E947 = ("Matcher.add received invalid 'greedy' argument: expected "
             "a string value from {expected} but got: '{arg}'")
     E948 = ("Matcher.add received invalid 'patterns' argument: expected "
@@ -612,8 +614,6 @@
             "of the training data in spaCy 3.0 onwards. The 'update' "
             "function should now be called with a batch of 'Example' "
             "objects, instead of (text, annotation) tuples. ")
-    E990 = ("An entity linking component needs to be initialized with a "
-            "KnowledgeBase object, but found {type} instead.")
     E991 = ("The function 'select_pipes' should be called with either a "
             "'disable' argument to list the names of the pipe components "
             "that should be disabled, or with an 'enable' argument that "

View File

@@ -140,7 +140,7 @@ cdef class KnowledgeBase:
         self._entries.push_back(entry)
         self._aliases_table.push_back(alias)

-    cpdef load_bulk(self, loc)
+    cpdef from_disk(self, loc)
     cpdef set_entities(self, entity_list, freq_list, vector_list)

View File

@@ -1,4 +1,5 @@
 # cython: infer_types=True, profile=True
+from typing import Iterator
 from cymem.cymem cimport Pool
 from preshed.maps cimport PreshMap
 from cpython.exc cimport PyErr_SetFromErrno
@@ -64,6 +65,16 @@ cdef class Candidate:
         return self.prior_prob


+def get_candidates(KnowledgeBase kb, span) -> Iterator[Candidate]:
+    """
+    Return candidate entities for a given span by using the text of the span as the alias
+    and fetching appropriate entries from the index.
+    This particular function is optimized to work with the built-in KB functionality,
+    but any other custom candidate generation method can be used in combination with the KB as well.
+    """
+    return kb.get_alias_candidates(span.text)
+
+
 cdef class KnowledgeBase:
     """A `KnowledgeBase` instance stores unique identifiers for entities and their textual aliases,
     to support entity linking of named entities to real-world concepts.
@@ -71,25 +82,16 @@
     DOCS: https://spacy.io/api/kb
     """

-    def __init__(self, entity_vector_length):
-        """Create a KnowledgeBase. Make sure to call kb.initialize() before using it."""
+    def __init__(self, Vocab vocab, entity_vector_length):
+        """Create a KnowledgeBase."""
         self.mem = Pool()
         self.entity_vector_length = entity_vector_length
         self._entry_index = PreshMap()
         self._alias_index = PreshMap()
-        self.vocab = None
-
-    def initialize(self, Vocab vocab):
         self.vocab = vocab
         self.vocab.strings.add("")
         self._create_empty_vectors(dummy_hash=self.vocab.strings[""])

-    def require_vocab(self):
-        if self.vocab is None:
-            raise ValueError(Errors.E946)
-
     @property
     def entity_vector_length(self):
         """RETURNS (uint64): length of the entity vectors"""
@@ -102,14 +104,12 @@
         return len(self._entry_index)

     def get_entity_strings(self):
-        self.require_vocab()
         return [self.vocab.strings[x] for x in self._entry_index]

     def get_size_aliases(self):
         return len(self._alias_index)

     def get_alias_strings(self):
-        self.require_vocab()
         return [self.vocab.strings[x] for x in self._alias_index]

     def add_entity(self, unicode entity, float freq, vector[float] entity_vector):
@@ -117,7 +117,6 @@
         Add an entity to the KB, optionally specifying its log probability based on corpus frequency
         Return the hash of the entity ID/name at the end.
         """
-        self.require_vocab()
         cdef hash_t entity_hash = self.vocab.strings.add(entity)

         # Return if this entity was added before
@@ -140,7 +139,6 @@
         return entity_hash

     cpdef set_entities(self, entity_list, freq_list, vector_list):
-        self.require_vocab()
         if len(entity_list) != len(freq_list) or len(entity_list) != len(vector_list):
             raise ValueError(Errors.E140)
@@ -176,12 +174,10 @@
             i += 1

     def contains_entity(self, unicode entity):
-        self.require_vocab()
         cdef hash_t entity_hash = self.vocab.strings.add(entity)
         return entity_hash in self._entry_index

     def contains_alias(self, unicode alias):
-        self.require_vocab()
         cdef hash_t alias_hash = self.vocab.strings.add(alias)
         return alias_hash in self._alias_index
@@ -190,7 +186,6 @@
         For a given alias, add its potential entities and prior probabilies to the KB.
         Return the alias_hash at the end
         """
-        self.require_vocab()
         # Throw an error if the length of entities and probabilities are not the same
         if not len(entities) == len(probabilities):
             raise ValueError(Errors.E132.format(alias=alias,
@@ -234,7 +229,6 @@
         Throw an error if this entity+prior prob would exceed the sum of 1.
         For efficiency, it's best to use the method `add_alias` as much as possible instead of this one.
         """
-        self.require_vocab()
         # Check if the alias exists in the KB
         cdef hash_t alias_hash = self.vocab.strings[alias]
         if not alias_hash in self._alias_index:
@@ -274,14 +268,12 @@
             alias_entry.probs = probs
             self._aliases_table[alias_index] = alias_entry

-    def get_candidates(self, unicode alias):
+    def get_alias_candidates(self, unicode alias) -> Iterator[Candidate]:
         """
         Return candidate entities for an alias. Each candidate defines the entity, the original alias,
         and the prior probability of that alias resolving to that entity.
         If the alias is not known in the KB, and empty list is returned.
         """
-        self.require_vocab()
         cdef hash_t alias_hash = self.vocab.strings[alias]
         if not alias_hash in self._alias_index:
             return []
@@ -298,7 +290,6 @@
                 if entry_index != 0]

     def get_vector(self, unicode entity):
-        self.require_vocab()
         cdef hash_t entity_hash = self.vocab.strings[entity]

         # Return an empty list if this entity is unknown in this KB
@@ -311,7 +302,6 @@
     def get_prior_prob(self, unicode entity, unicode alias):
         """ Return the prior probability of a given alias being linked to a given entity,
         or return 0.0 when this combination is not known in the knowledge base"""
-        self.require_vocab()
         cdef hash_t alias_hash = self.vocab.strings[alias]
         cdef hash_t entity_hash = self.vocab.strings[entity]
@@ -329,8 +319,7 @@
         return 0.0

-    def dump(self, loc):
-        self.require_vocab()
+    def to_disk(self, loc):
         cdef Writer writer = Writer(loc)
         writer.write_header(self.get_size_entities(), self.entity_vector_length)
@@ -370,7 +359,7 @@
         writer.close()

-    cpdef load_bulk(self, loc):
+    cpdef from_disk(self, loc):
         cdef hash_t entity_hash
         cdef hash_t alias_hash
         cdef int64_t entry_index
@@ -462,12 +451,11 @@
 cdef class Writer:
     def __init__(self, object loc):
-        if path.exists(loc):
-            assert not path.isdir(loc), f"{loc} is directory"
         if isinstance(loc, Path):
             loc = bytes(loc)
         if path.exists(loc):
-            assert not path.isdir(loc), "%s is directory." % loc
+            if path.isdir(loc):
+                raise ValueError(Errors.E928.format(loc=loc))
         cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
         self._fp = fopen(<char*>bytes_loc, 'wb')
         if not self._fp:
@@ -511,8 +499,10 @@ cdef class Reader:
     def __init__(self, object loc):
         if isinstance(loc, Path):
             loc = bytes(loc)
-        assert path.exists(loc)
-        assert not path.isdir(loc)
+        if not path.exists(loc):
+            raise ValueError(Errors.E929.format(loc=loc))
+        if path.isdir(loc):
+            raise ValueError(Errors.E928.format(loc=loc))
         cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
         self._fp = fopen(<char*>bytes_loc, 'rb')
         if not self._fp:
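The net effect of these hunks: the string-based lookup is renamed to `KnowledgeBase.get_alias_candidates`, while the module-level `get_candidates(kb, span)` is the pipe-facing hook that works on `Span` objects. A short sketch of both, assuming an `nlp` pipeline and a populated `kb` as in the examples above:

```python
from spacy.kb import get_candidates

doc = nlp("douglas lives in Boston")
span = doc[0:1]
# pipe-facing hook: uses span.text as the alias under the hood
for c in get_candidates(kb, span):
    print(c.entity_, c.alias_, c.entity_freq, c.prior_prob)
# direct string lookup on the KB (formerly kb.get_candidates)
candidates = kb.get_alias_candidates("douglas")
```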

View File

@@ -772,9 +772,9 @@ class Language:
         self.remove_pipe(name)
         if not len(self.pipeline) or pipe_index == len(self.pipeline):
             # we have no components to insert before/after, or we're replacing the last component
-            self.add_pipe(factory_name, name=name)
+            self.add_pipe(factory_name, name=name, config=config, validate=validate)
         else:
-            self.add_pipe(factory_name, name=name, before=pipe_index)
+            self.add_pipe(factory_name, name=name, before=pipe_index, config=config, validate=validate)

     def rename_pipe(self, old_name: str, new_name: str) -> None:
         """Rename a pipeline component.

View File

@@ -1,9 +1,9 @@
-from typing import Optional
+from typing import Optional, Callable, Iterable
 from thinc.api import chain, clone, list2ragged, reduce_mean, residual
 from thinc.api import Model, Maxout, Linear

 from ...util import registry
-from ...kb import KnowledgeBase
+from ...kb import KnowledgeBase, Candidate, get_candidates
 from ...vocab import Vocab
@@ -25,15 +25,21 @@ def build_nel_encoder(tok2vec: Model, nO: Optional[int] = None) -> Model:

 @registry.assets.register("spacy.KBFromFile.v1")
-def load_kb(vocab_path: str, kb_path: str) -> KnowledgeBase:
-    vocab = Vocab().from_disk(vocab_path)
-    kb = KnowledgeBase(entity_vector_length=1)
-    kb.initialize(vocab)
-    kb.load_bulk(kb_path)
-    return kb
+def load_kb(kb_path: str) -> Callable[[Vocab], KnowledgeBase]:
+    def kb_from_file(vocab):
+        kb = KnowledgeBase(vocab, entity_vector_length=1)
+        kb.from_disk(kb_path)
+        return kb
+    return kb_from_file


 @registry.assets.register("spacy.EmptyKB.v1")
-def empty_kb(entity_vector_length: int) -> KnowledgeBase:
-    kb = KnowledgeBase(entity_vector_length=entity_vector_length)
-    return kb
+def empty_kb(entity_vector_length: int) -> Callable[[Vocab], KnowledgeBase]:
+    def empty_kb_factory(vocab):
+        return KnowledgeBase(vocab=vocab, entity_vector_length=entity_vector_length)
+    return empty_kb_factory


+@registry.assets.register("spacy.CandidateGenerator.v1")
+def create_candidates() -> Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]:
+    return get_candidates
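Since registered assets now return factories instead of instantiated KBs, a custom loader follows the same closure pattern; a sketch with a hypothetical asset name and parameters:

```python
from typing import Callable

from spacy import registry
from spacy.kb import KnowledgeBase
from spacy.vocab import Vocab

@registry.assets.register("myCustomKB.v1")  # hypothetical asset name
def my_kb_loader(kb_path: str, entity_vector_length: int) -> Callable[[Vocab], KnowledgeBase]:
    def create_kb(vocab: Vocab) -> KnowledgeBase:
        # the EntityLinker injects its Vocab when it calls this factory
        kb = KnowledgeBase(vocab, entity_vector_length=entity_vector_length)
        kb.from_disk(kb_path)
        return kb
    return create_kb
```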

View File

@@ -6,7 +6,7 @@ from thinc.api import CosineDistance, get_array_module, Model, Optimizer, Config
 from thinc.api import set_dropout_rate
 import warnings

-from ..kb import KnowledgeBase
+from ..kb import KnowledgeBase, Candidate
 from ..tokens import Doc
 from .pipe import Pipe, deserialize_config
 from ..language import Language
@@ -32,35 +32,30 @@ subword_features = true
 """
 DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]

-default_kb_config = """
-[kb]
-@assets = "spacy.EmptyKB.v1"
-entity_vector_length = 64
-"""
-DEFAULT_NEL_KB = Config().from_str(default_kb_config)["kb"]
-

 @Language.factory(
     "entity_linker",
     requires=["doc.ents", "doc.sents", "token.ent_iob", "token.ent_type"],
     assigns=["token.ent_kb_id"],
     default_config={
-        "kb": DEFAULT_NEL_KB,
+        "kb_loader": {"@assets": "spacy.EmptyKB.v1", "entity_vector_length": 64},
         "model": DEFAULT_NEL_MODEL,
         "labels_discard": [],
         "incl_prior": True,
         "incl_context": True,
+        "get_candidates": {"@assets": "spacy.CandidateGenerator.v1"},
     },
 )
 def make_entity_linker(
     nlp: Language,
     name: str,
     model: Model,
-    kb: KnowledgeBase,
+    kb_loader: Callable[[Vocab], KnowledgeBase],
     *,
     labels_discard: Iterable[str],
     incl_prior: bool,
     incl_context: bool,
+    get_candidates: Callable[[KnowledgeBase, "Span"], Iterable[Candidate]],
 ):
     """Construct an EntityLinker component.
@@ -76,10 +71,11 @@ def make_entity_linker(
         nlp.vocab,
         model,
         name,
-        kb=kb,
+        kb_loader=kb_loader,
         labels_discard=labels_discard,
         incl_prior=incl_prior,
         incl_context=incl_context,
+        get_candidates=get_candidates,
     )
@@ -97,10 +93,11 @@ class EntityLinker(Pipe):
         model: Model,
         name: str = "entity_linker",
         *,
-        kb: KnowledgeBase,
+        kb_loader: Callable[[Vocab], KnowledgeBase],
         labels_discard: Iterable[str],
         incl_prior: bool,
         incl_context: bool,
+        get_candidates: Callable[[KnowledgeBase, "Span"], Iterable[Candidate]],
     ) -> None:
         """Initialize an entity linker.
@@ -108,7 +105,7 @@ class EntityLinker(Pipe):
         model (thinc.api.Model): The Thinc Model powering the pipeline component.
         name (str): The component instance name, used to add entries to the
             losses during training.
-        kb (KnowledgeBase): The KnowledgeBase holding all entities and their aliases.
+        kb_loader (Callable[[Vocab], KnowledgeBase]): A function that creates a KnowledgeBase from a Vocab instance.
         labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction.
         incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
         incl_context (bool): Whether or not to include the local context in the model.
@@ -119,17 +116,12 @@ class EntityLinker(Pipe):
         self.model = model
         self.name = name
         cfg = {
-            "kb": kb,
             "labels_discard": list(labels_discard),
             "incl_prior": incl_prior,
             "incl_context": incl_context,
         }
-        if not isinstance(kb, KnowledgeBase):
-            raise ValueError(Errors.E990.format(type=type(self.kb)))
-        kb.initialize(vocab)
-        self.kb = kb
-        if "kb" in cfg:
-            del cfg["kb"]  # we don't want to duplicate its serialization
+        self.kb = kb_loader(self.vocab)
+        self.get_candidates = get_candidates
         self.cfg = dict(cfg)
         self.distance = CosineDistance(normalize=False)
         # how many neightbour sentences to take into account
@@ -326,10 +318,11 @@ class EntityLinker(Pipe):
                 end_token = sentences[end_sentence].end
                 sent_doc = doc[start_token:end_token].as_doc()
                 # currently, the context is the same for each entity in a sentence (should be refined)
-                sentence_encoding = self.model.predict([sent_doc])[0]
-                xp = get_array_module(sentence_encoding)
-                sentence_encoding_t = sentence_encoding.T
-                sentence_norm = xp.linalg.norm(sentence_encoding_t)
+                xp = self.model.ops.xp
+                if self.cfg.get("incl_context"):
+                    sentence_encoding = self.model.predict([sent_doc])[0]
+                    sentence_encoding_t = sentence_encoding.T
+                    sentence_norm = xp.linalg.norm(sentence_encoding_t)
                 for ent in sent.ents:
                     entity_count += 1
                     to_discard = self.cfg.get("labels_discard", [])
@@ -337,7 +330,7 @@ class EntityLinker(Pipe):
                         # ignoring this entity - setting to NIL
                         final_kb_ids.append(self.NIL)
                     else:
-                        candidates = self.kb.get_candidates(ent.text)
+                        candidates = self.get_candidates(self.kb, ent)
                         if not candidates:
                             # no prediction possible for this entity - setting to NIL
                             final_kb_ids.append(self.NIL)
@@ -421,10 +414,9 @@ class EntityLinker(Pipe):
         DOCS: https://spacy.io/api/entitylinker#to_disk
         """
         serialize = {}
-        self.cfg["entity_width"] = self.kb.entity_vector_length
         serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
         serialize["vocab"] = lambda p: self.vocab.to_disk(p)
-        serialize["kb"] = lambda p: self.kb.dump(p)
+        serialize["kb"] = lambda p: self.kb.to_disk(p)
         serialize["model"] = lambda p: self.model.to_disk(p)
         util.to_disk(path, serialize, exclude)
@@ -446,15 +438,10 @@ class EntityLinker(Pipe):
             except AttributeError:
                 raise ValueError(Errors.E149) from None

-        def load_kb(p):
-            self.kb = KnowledgeBase(entity_vector_length=self.cfg["entity_width"])
-            self.kb.initialize(self.vocab)
-            self.kb.load_bulk(p)
-
         deserialize = {}
         deserialize["vocab"] = lambda p: self.vocab.from_disk(p)
         deserialize["cfg"] = lambda p: self.cfg.update(deserialize_config(p))
-        deserialize["kb"] = load_kb
+        deserialize["kb"] = lambda p: self.kb.from_disk(p)
         deserialize["model"] = load_model
         util.from_disk(path, deserialize, exclude)
         return self
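Because `kb_loader(self.vocab)` runs in `__init__` and the KB is no longer smuggled through `cfg`, the pipe serializes the KB under its own "kb" key and restores it with `from_disk` on the already-constructed instance. A round-trip sketch (directory name illustrative, pipeline assumed to contain an `entity_linker`), mirroring the serialization test further below:

```python
from spacy import util

nlp.to_disk("/tmp/nel_model")                       # calls self.kb.to_disk(p) for the "kb" key
nlp2 = util.load_model_from_path("/tmp/nel_model")  # calls self.kb.from_disk(p) on load
assert nlp2.get_pipe("entity_linker").kb.entity_vector_length == \
       nlp.get_pipe("entity_linker").kb.entity_vector_length
```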

View File

@@ -68,7 +68,6 @@ class Tagger(Pipe):
         name (str): The component instance name, used to add entries to the
             losses during training.
         labels (List): The set of labels. Defaults to None.
-        set_morphology (bool): Whether to set morphological features.

         DOCS: https://spacy.io/api/tagger#init
         """

View File

@@ -1,6 +1,7 @@
+from typing import Callable, Iterable
 import pytest

-from spacy.kb import KnowledgeBase
+from spacy.kb import KnowledgeBase, get_candidates, Candidate
 from spacy import util, registry
 from spacy.gold import Example
@@ -21,8 +22,7 @@ def assert_almost_equal(a, b):
 def test_kb_valid_entities(nlp):
     """Test the valid construction of a KB with 3 entities and two aliases"""
-    mykb = KnowledgeBase(entity_vector_length=3)
-    mykb.initialize(nlp.vocab)
+    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)

     # adding entities
     mykb.add_entity(entity="Q1", freq=19, entity_vector=[8, 4, 3])
@@ -51,8 +51,7 @@ def test_kb_valid_entities(nlp):
 def test_kb_invalid_entities(nlp):
     """Test the invalid construction of a KB with an alias linked to a non-existing entity"""
-    mykb = KnowledgeBase(entity_vector_length=1)
-    mykb.initialize(nlp.vocab)
+    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)

     # adding entities
     mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
@@ -68,8 +67,7 @@ def test_kb_invalid_entities(nlp):
 def test_kb_invalid_probabilities(nlp):
     """Test the invalid construction of a KB with wrong prior probabilities"""
-    mykb = KnowledgeBase(entity_vector_length=1)
-    mykb.initialize(nlp.vocab)
+    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)

     # adding entities
     mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
@@ -83,8 +81,7 @@ def test_kb_invalid_probabilities(nlp):
 def test_kb_invalid_combination(nlp):
     """Test the invalid construction of a KB with non-matching entity and probability lists"""
-    mykb = KnowledgeBase(entity_vector_length=1)
-    mykb.initialize(nlp.vocab)
+    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)

     # adding entities
     mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
@@ -100,8 +97,7 @@ def test_kb_invalid_combination(nlp):
 def test_kb_invalid_entity_vector(nlp):
     """Test the invalid construction of a KB with non-matching entity vector lengths"""
-    mykb = KnowledgeBase(entity_vector_length=3)
-    mykb.initialize(nlp.vocab)
+    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)

     # adding entities
     mykb.add_entity(entity="Q1", freq=19, entity_vector=[1, 2, 3])
@@ -117,14 +113,14 @@ def test_kb_default(nlp):
     assert len(entity_linker.kb) == 0
     assert entity_linker.kb.get_size_entities() == 0
     assert entity_linker.kb.get_size_aliases() == 0
-    # default value from pipeline.entity_linker
+    # 64 is the default value from pipeline.entity_linker
     assert entity_linker.kb.entity_vector_length == 64


 def test_kb_custom_length(nlp):
     """Test that the default (empty) KB can be configured with a custom entity length"""
     entity_linker = nlp.add_pipe(
-        "entity_linker", config={"kb": {"entity_vector_length": 35}}
+        "entity_linker", config={"kb_loader": {"entity_vector_length": 35}}
     )
     assert len(entity_linker.kb) == 0
     assert entity_linker.kb.get_size_entities() == 0
@@ -141,7 +137,7 @@ def test_kb_undefined(nlp):
 def test_kb_empty(nlp):
     """Test that the EL can't train with an empty KB"""
-    config = {"kb": {"@assets": "spacy.EmptyKB.v1", "entity_vector_length": 342}}
+    config = {"kb_loader": {"@assets": "spacy.EmptyKB.v1", "entity_vector_length": 342}}
     entity_linker = nlp.add_pipe("entity_linker", config=config)
     assert len(entity_linker.kb) == 0
     with pytest.raises(ValueError):
@@ -150,8 +146,13 @@ def test_kb_empty(nlp):
 def test_candidate_generation(nlp):
     """Test correct candidate generation"""
-    mykb = KnowledgeBase(entity_vector_length=1)
-    mykb.initialize(nlp.vocab)
+    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+    doc = nlp("douglas adam Adam shrubbery")
+    douglas_ent = doc[0:1]
+    adam_ent = doc[1:2]
+    Adam_ent = doc[2:3]
+    shrubbery_ent = doc[3:4]

     # adding entities
     mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
@@ -163,21 +164,76 @@
     mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])

     # test the size of the relevant candidates
-    assert len(mykb.get_candidates("douglas")) == 2
-    assert len(mykb.get_candidates("adam")) == 1
-    assert len(mykb.get_candidates("shrubbery")) == 0
+    assert len(get_candidates(mykb, douglas_ent)) == 2
+    assert len(get_candidates(mykb, adam_ent)) == 1
+    assert len(get_candidates(mykb, Adam_ent)) == 0  # default case sensitive
+    assert len(get_candidates(mykb, shrubbery_ent)) == 0

     # test the content of the candidates
-    assert mykb.get_candidates("adam")[0].entity_ == "Q2"
-    assert mykb.get_candidates("adam")[0].alias_ == "adam"
-    assert_almost_equal(mykb.get_candidates("adam")[0].entity_freq, 12)
-    assert_almost_equal(mykb.get_candidates("adam")[0].prior_prob, 0.9)
+    assert get_candidates(mykb, adam_ent)[0].entity_ == "Q2"
+    assert get_candidates(mykb, adam_ent)[0].alias_ == "adam"
+    assert_almost_equal(get_candidates(mykb, adam_ent)[0].entity_freq, 12)
+    assert_almost_equal(get_candidates(mykb, adam_ent)[0].prior_prob, 0.9)


+def test_el_pipe_configuration(nlp):
+    """Test correct candidate generation as part of the EL pipe"""
+    nlp.add_pipe("sentencizer")
+    pattern = {"label": "PERSON", "pattern": [{"LOWER": "douglas"}]}
+    ruler = nlp.add_pipe("entity_ruler")
+    ruler.add_patterns([pattern])
+
+    @registry.assets.register("myAdamKB.v1")
+    def mykb() -> Callable[["Vocab"], KnowledgeBase]:
+        def create_kb(vocab):
+            kb = KnowledgeBase(vocab, entity_vector_length=1)
+            kb.add_entity(entity="Q2", freq=12, entity_vector=[2])
+            kb.add_entity(entity="Q3", freq=5, entity_vector=[3])
+            kb.add_alias(
+                alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1]
+            )
+            return kb
+        return create_kb
+
+    # run an EL pipe without a trained context encoder, to check the candidate generation step only
+    nlp.add_pipe(
+        "entity_linker",
+        config={"kb_loader": {"@assets": "myAdamKB.v1"}, "incl_context": False},
+    )
+    # With the default get_candidates function, matching is case-sensitive
+    text = "Douglas and douglas are not the same."
+    doc = nlp(text)
+    assert doc[0].ent_kb_id_ == "NIL"
+    assert doc[1].ent_kb_id_ == ""
+    assert doc[2].ent_kb_id_ == "Q2"
+
+    def get_lowercased_candidates(kb, span):
+        return kb.get_alias_candidates(span.text.lower())
+
+    @registry.assets.register("spacy.LowercaseCandidateGenerator.v1")
+    def create_candidates() -> Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]:
+        return get_lowercased_candidates
+
+    # replace the pipe with a new one with a different candidate generator
+    nlp.replace_pipe(
+        "entity_linker",
+        "entity_linker",
+        config={
+            "kb_loader": {"@assets": "myAdamKB.v1"},
+            "incl_context": False,
+            "get_candidates": {"@assets": "spacy.LowercaseCandidateGenerator.v1"},
+        },
+    )
+    doc = nlp(text)
+    assert doc[0].ent_kb_id_ == "Q2"
+    assert doc[1].ent_kb_id_ == ""
+    assert doc[2].ent_kb_id_ == "Q2"


 def test_append_alias(nlp):
     """Test that we can append additional alias-entity pairs"""
-    mykb = KnowledgeBase(entity_vector_length=1)
-    mykb.initialize(nlp.vocab)
+    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)

     # adding entities
     mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
@@ -189,26 +245,25 @@ def test_append_alias(nlp):
     mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])

     # test the size of the relevant candidates
-    assert len(mykb.get_candidates("douglas")) == 2
+    assert len(mykb.get_alias_candidates("douglas")) == 2

     # append an alias
     mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.2)

     # test the size of the relevant candidates has been incremented
-    assert len(mykb.get_candidates("douglas")) == 3
+    assert len(mykb.get_alias_candidates("douglas")) == 3

     # append the same alias-entity pair again should not work (will throw a warning)
     with pytest.warns(UserWarning):
         mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.3)

     # test the size of the relevant candidates remained unchanged
-    assert len(mykb.get_candidates("douglas")) == 3
+    assert len(mykb.get_alias_candidates("douglas")) == 3


 def test_append_invalid_alias(nlp):
     """Test that append an alias will throw an error if prior probs are exceeding 1"""
-    mykb = KnowledgeBase(entity_vector_length=1)
-    mykb.initialize(nlp.vocab)
+    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)

     # adding entities
     mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
@@ -228,16 +283,18 @@ def test_preserving_links_asdoc(nlp):
     """Test that Span.as_doc preserves the existing entity links"""

     @registry.assets.register("myLocationsKB.v1")
-    def dummy_kb() -> KnowledgeBase:
-        mykb = KnowledgeBase(entity_vector_length=1)
-        mykb.initialize(nlp.vocab)
-        # adding entities
-        mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
-        mykb.add_entity(entity="Q2", freq=8, entity_vector=[1])
-        # adding aliases
-        mykb.add_alias(alias="Boston", entities=["Q1"], probabilities=[0.7])
-        mykb.add_alias(alias="Denver", entities=["Q2"], probabilities=[0.6])
-        return mykb
+    def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]:
+        def create_kb(vocab):
+            mykb = KnowledgeBase(vocab, entity_vector_length=1)
+            # adding entities
+            mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
+            mykb.add_entity(entity="Q2", freq=8, entity_vector=[1])
+            # adding aliases
+            mykb.add_alias(alias="Boston", entities=["Q1"], probabilities=[0.7])
+            mykb.add_alias(alias="Denver", entities=["Q2"], probabilities=[0.6])
+            return mykb
+        return create_kb

     # set up pipeline with NER (Entity Ruler) and NEL (prior probability only, model not trained)
     nlp.add_pipe("sentencizer")
@@ -247,7 +304,7 @@
     ]
     ruler = nlp.add_pipe("entity_ruler")
     ruler.add_patterns(patterns)
-    el_config = {"kb": {"@assets": "myLocationsKB.v1"}, "incl_prior": False}
+    el_config = {"kb_loader": {"@assets": "myLocationsKB.v1"}, "incl_prior": False}
     el_pipe = nlp.add_pipe("entity_linker", config=el_config, last=True)
     el_pipe.begin_training(lambda: [])
     el_pipe.incl_context = False
@@ -331,24 +388,28 @@ def test_overfitting_IO():
         train_examples.append(Example.from_dict(doc, annotation))

     @registry.assets.register("myOverfittingKB.v1")
-    def dummy_kb() -> KnowledgeBase:
-        # create artificial KB - assign same prior weight to the two russ cochran's
-        # Q2146908 (Russ Cochran): American golfer
-        # Q7381115 (Russ Cochran): publisher
-        mykb = KnowledgeBase(entity_vector_length=3)
-        mykb.initialize(nlp.vocab)
-        mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
-        mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7])
-        mykb.add_alias(
-            alias="Russ Cochran",
-            entities=["Q2146908", "Q7381115"],
-            probabilities=[0.5, 0.5],
-        )
-        return mykb
+    def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]:
+        def create_kb(vocab):
+            # create artificial KB - assign same prior weight to the two russ cochran's
+            # Q2146908 (Russ Cochran): American golfer
+            # Q7381115 (Russ Cochran): publisher
+            mykb = KnowledgeBase(vocab, entity_vector_length=3)
+            mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
+            mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7])
+            mykb.add_alias(
+                alias="Russ Cochran",
+                entities=["Q2146908", "Q7381115"],
+                probabilities=[0.5, 0.5],
+            )
+            return mykb
+        return create_kb

     # Create the Entity Linker component and add it to the pipeline
     nlp.add_pipe(
-        "entity_linker", config={"kb": {"@assets": "myOverfittingKB.v1"}}, last=True
+        "entity_linker",
+        config={"kb_loader": {"@assets": "myOverfittingKB.v1"}},
+        last=True,
     )

     # train the NEL pipe

View File

@@ -78,6 +78,14 @@ def test_replace_last_pipe(nlp):
     assert nlp.pipe_names == ["sentencizer", "ner"]


+def test_replace_pipe_config(nlp):
+    nlp.add_pipe("entity_linker")
+    nlp.add_pipe("sentencizer")
+    assert nlp.get_pipe("entity_linker").cfg["incl_prior"] == True
+    nlp.replace_pipe("entity_linker", "entity_linker", config={"incl_prior": False})
+    assert nlp.get_pipe("entity_linker").cfg["incl_prior"] == False
+
+
 @pytest.mark.parametrize("old_name,new_name", [("old_pipe", "new_pipe")])
 def test_rename_pipe(nlp, old_name, new_name):
     with pytest.raises(ValueError):

View File

@@ -139,8 +139,7 @@ def test_issue4665():
 def test_issue4674():
     """Test that setting entities with overlapping identifiers does not mess up IO"""
     nlp = English()
-    kb = KnowledgeBase(entity_vector_length=3)
-    kb.initialize(nlp.vocab)
+    kb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
     vector1 = [0.9, 1.1, 1.01]
     vector2 = [1.8, 2.25, 2.01]
     with pytest.warns(UserWarning):
@@ -156,10 +155,9 @@ def test_issue4674():
         if not dir_path.exists():
             dir_path.mkdir()
         file_path = dir_path / "kb"
-        kb.dump(str(file_path))
-        kb2 = KnowledgeBase(entity_vector_length=3)
-        kb2.initialize(nlp.vocab)
-        kb2.load_bulk(str(file_path))
+        kb.to_disk(str(file_path))
+        kb2 = KnowledgeBase(nlp.vocab, entity_vector_length=3)
+        kb2.from_disk(str(file_path))

     assert kb2.get_size_entities() == 1

View File

@@ -1,3 +1,4 @@
+from typing import Callable
 import warnings
 from unittest import TestCase
 import pytest
@@ -70,13 +71,14 @@ def entity_linker():
     nlp = Language()

     @registry.assets.register("TestIssue5230KB.v1")
-    def dummy_kb() -> KnowledgeBase:
-        kb = KnowledgeBase(entity_vector_length=1)
-        kb.initialize(nlp.vocab)
-        kb.add_entity("test", 0.0, zeros((1, 1), dtype="f"))
-        return kb
+    def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]:
+        def create_kb(vocab):
+            kb = KnowledgeBase(vocab, entity_vector_length=1)
+            kb.add_entity("test", 0.0, zeros((1, 1), dtype="f"))
+            return kb
+        return create_kb

-    config = {"kb": {"@assets": "TestIssue5230KB.v1"}}
+    config = {"kb_loader": {"@assets": "TestIssue5230KB.v1"}}
     entity_linker = nlp.add_pipe("entity_linker", config=config)
     # need to add model for two reasons:
     # 1. no model leads to error in serialization,
@@ -121,19 +123,17 @@ def test_writer_with_path_py35():
 def test_save_and_load_knowledge_base():
     nlp = Language()
-    kb = KnowledgeBase(entity_vector_length=1)
-    kb.initialize(nlp.vocab)
+    kb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
     with make_tempdir() as d:
         path = d / "kb"
         try:
-            kb.dump(path)
+            kb.to_disk(path)
         except Exception as e:
             pytest.fail(str(e))

         try:
-            kb_loaded = KnowledgeBase(entity_vector_length=1)
-            kb_loaded.initialize(nlp.vocab)
-            kb_loaded.load_bulk(path)
+            kb_loaded = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+            kb_loaded.from_disk(path)
         except Exception as e:
             pytest.fail(str(e))

View File

@@ -1,4 +1,8 @@
-from spacy.util import ensure_path
+from typing import Callable
+
+from spacy import util
+from spacy.lang.en import English
+from spacy.util import ensure_path, registry
 from spacy.kb import KnowledgeBase

 from ..util import make_tempdir
@@ -15,20 +19,16 @@ def test_serialize_kb_disk(en_vocab):
         if not dir_path.exists():
             dir_path.mkdir()
         file_path = dir_path / "kb"
-        kb1.dump(str(file_path))
-
-        kb2 = KnowledgeBase(entity_vector_length=3)
-        kb2.initialize(en_vocab)
-        kb2.load_bulk(str(file_path))
+        kb1.to_disk(str(file_path))
+        kb2 = KnowledgeBase(vocab=en_vocab, entity_vector_length=3)
+        kb2.from_disk(str(file_path))

     # final assertions
     _check_kb(kb2)


 def _get_dummy_kb(vocab):
-    kb = KnowledgeBase(entity_vector_length=3)
-    kb.initialize(vocab)
+    kb = KnowledgeBase(vocab, entity_vector_length=3)
     kb.add_entity(entity="Q53", freq=33, entity_vector=[0, 5, 3])
     kb.add_entity(entity="Q17", freq=2, entity_vector=[7, 1, 0])
     kb.add_entity(entity="Q007", freq=7, entity_vector=[0, 0, 7])
@@ -61,7 +61,7 @@ def _check_kb(kb):
     assert alias_string not in kb.get_alias_strings()

     # check candidates & probabilities
-    candidates = sorted(kb.get_candidates("double07"), key=lambda x: x.entity_)
+    candidates = sorted(kb.get_alias_candidates("double07"), key=lambda x: x.entity_)
     assert len(candidates) == 2
     assert candidates[0].entity_ == "Q007"
@@ -75,3 +75,47 @@ def _check_kb(kb):
     assert candidates[1].entity_vector == [7, 1, 0]
     assert candidates[1].alias_ == "double07"
     assert 0.099 < candidates[1].prior_prob < 0.101
+
+
+def test_serialize_subclassed_kb():
+    """Check that IO of a custom KB works fine as part of an EL pipe."""
+
+    class SubKnowledgeBase(KnowledgeBase):
+        def __init__(self, vocab, entity_vector_length, custom_field):
+            super().__init__(vocab, entity_vector_length)
+            self.custom_field = custom_field
+
+    @registry.assets.register("spacy.CustomKB.v1")
+    def custom_kb(
+        entity_vector_length: int, custom_field: int
+    ) -> Callable[["Vocab"], KnowledgeBase]:
+        def custom_kb_factory(vocab):
+            return SubKnowledgeBase(
+                vocab=vocab,
+                entity_vector_length=entity_vector_length,
+                custom_field=custom_field,
+            )
+        return custom_kb_factory
+
+    nlp = English()
+    config = {
+        "kb_loader": {
+            "@assets": "spacy.CustomKB.v1",
+            "entity_vector_length": 342,
+            "custom_field": 666,
+        }
+    }
+    entity_linker = nlp.add_pipe("entity_linker", config=config)
+    assert type(entity_linker.kb) == SubKnowledgeBase
+    assert entity_linker.kb.entity_vector_length == 342
+    assert entity_linker.kb.custom_field == 666
+
+    # Make sure the custom KB is serialized correctly
+    with make_tempdir() as tmp_dir:
+        nlp.to_disk(tmp_dir)
+        nlp2 = util.load_model_from_path(tmp_dir)
+        entity_linker2 = nlp2.get_pipe("entity_linker")
+        assert type(entity_linker2.kb) == SubKnowledgeBase
+        assert entity_linker2.kb.entity_vector_length == 342
+        assert entity_linker2.kb.custom_field == 666

View File

@@ -200,21 +200,21 @@ probability of the fact that the mention links to the entity ID.
 | `alias`     | The textual mention or alias. ~~str~~ |
 | **RETURNS** | The prior probability of the `alias` referring to the `entity`. ~~float~~ |

-## KnowledgeBase.dump {#dump tag="method"}
+## KnowledgeBase.to_disk {#to_disk tag="method"}

 Save the current state of the knowledge base to a directory.

 > #### Example
 >
 > ```python
-> kb.dump(loc)
+> kb.to_disk(loc)
 > ```

 | Name  | Description |
 | ----- | ----------- |
 | `loc` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. ~~Union[str, Path]~~ |

-## KnowledgeBase.load_bulk {#load_bulk tag="method"}
+## KnowledgeBase.from_disk {#from_disk tag="method"}

 Restore the state of the knowledge base from a given directory. Note that the
 [`Vocab`](/api/vocab) should also be the same as the one used to create the KB.
@@ -226,7 +226,7 @@ Restore the state of the knowledge base from a given directory. Note that the
 > from spacy.vocab import Vocab
 > vocab = Vocab().from_disk("/path/to/vocab")
 > kb = KnowledgeBase(vocab=vocab, entity_vector_length=64)
-> kb.load_bulk("/path/to/kb")
+> kb.from_disk("/path/to/kb")
 > ```

 | Name | Description |