Define candidate generator in EL config (#5876)

* candidate generator as separate part of EL config

* update comment

* ent instead of str as input for candidate generation

* Span instead of str: correct type indication

* fix types

* unit test to create new candidate generator

* fix replace_pipe argument passing

* move error message, general cleanup

* add vocab back to KB constructor

* provide KB as callable from Vocab arg

* rename to kb_loader, fix KB serialization as part of the EL pipe

* fix typo

* reformatting

* cleanup

* fix comment

* fix wrongly duplicated code from merge conflict

* rename dump to to_disk

* from_disk instead of load_bulk

* update test after recent removal of set_morphology in tagger

* remove old doc
This commit is contained in:
Sofie Van Landeghem 2020-08-18 16:10:36 +02:00 committed by GitHub
parent 688e77562b
commit 358cbb21e3
17 changed files with 272 additions and 180 deletions
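Taken together, the changes give the entity linker a fully config-driven setup: the KB is produced by a registered `kb_loader` callable that receives the pipe's `Vocab`, and candidate generation is a swappable `get_candidates` callable. A minimal sketch distilled from the tests in this diff (`myKB.v1` is a placeholder registry name, not part of the commit):

```python
from typing import Callable
import spacy
from spacy import registry
from spacy.kb import KnowledgeBase
from spacy.vocab import Vocab

@registry.assets.register("myKB.v1")
def my_kb() -> Callable[[Vocab], KnowledgeBase]:
    def create_kb(vocab):
        # the KB is now constructed from the pipe's own Vocab
        kb = KnowledgeBase(vocab, entity_vector_length=1)
        kb.add_entity(entity="Q2", freq=12, entity_vector=[2])
        kb.add_alias(alias="douglas", entities=["Q2"], probabilities=[0.8])
        return kb
    return create_kb

nlp = spacy.blank("en")
nlp.add_pipe(
    "entity_linker",
    config={
        "kb_loader": {"@assets": "myKB.v1"},
        # the default, case-sensitive generator; any callable taking
        # (KnowledgeBase, Span) and returning candidates can be swapped in
        "get_candidates": {"@assets": "spacy.CandidateGenerator.v1"},
    },
)
```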

View File

@@ -15,7 +15,8 @@ import spacy.util
from bin.ud import conll17_ud_eval
from spacy.tokens import Token, Doc
from spacy.gold import Example
from spacy.util import compounding, minibatch, minibatch_by_words
from spacy.util import compounding, minibatch
from spacy.gold.batchers import minibatch_by_words
from spacy.pipeline._parser_internals.nonproj import projectivize
from spacy.matcher import Matcher
from spacy import displacy

View File

@@ -48,8 +48,7 @@ def main(model, output_dir=None):
# You can change the dimension of vectors in your KB by using an encoder that changes the dimensionality.
# For simplicity, we'll just use the original vector dimension here instead.
vectors_dim = nlp.vocab.vectors.shape[1]
kb = KnowledgeBase(entity_vector_length=vectors_dim)
kb.initialize(nlp.vocab)
kb = KnowledgeBase(nlp.vocab, entity_vector_length=vectors_dim)
# set up the data
entity_ids = []
@@ -81,7 +80,7 @@
if not output_dir.exists():
output_dir.mkdir()
kb_path = str(output_dir / "kb")
kb.dump(kb_path)
kb.to_disk(kb_path)
print()
print("Saved KB to", kb_path)
@@ -96,9 +95,8 @@
print("Loading vocab from", vocab_path)
print("Loading KB from", kb_path)
vocab2 = Vocab().from_disk(vocab_path)
kb2 = KnowledgeBase(entity_vector_length=1)
kb.initialize(vocab2)
kb2.load_bulk(kb_path)
kb2 = KnowledgeBase(vocab2, entity_vector_length=1)
kb2.from_disk(kb_path)
print()
_print_kb(kb2)
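For reference, the KB lifecycle in this example now reads end to end as follows; a sketch under the new API (file name is a placeholder):

```python
from spacy.kb import KnowledgeBase
from spacy.vocab import Vocab

vocab = Vocab()
kb = KnowledgeBase(vocab, entity_vector_length=3)  # Vocab is a constructor argument again
kb.add_entity(entity="Q42", freq=10, entity_vector=[1.0, 2.0, 3.0])
kb.to_disk("kb_file")     # formerly kb.dump(...)

kb2 = KnowledgeBase(vocab, entity_vector_length=3)
kb2.from_disk("kb_file")  # formerly kb.load_bulk(...)
assert kb2.get_size_entities() == 1
```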

View File

@@ -83,7 +83,7 @@ def main(kb_path, vocab_path, output_dir=None, n_iter=50):
if "entity_linker" not in nlp.pipe_names:
print("Loading Knowledge Base from '%s'" % kb_path)
cfg = {
"kb": {
"kb_loader": {
"@assets": "spacy.KBFromFile.v1",
"vocab_path": vocab_path,
"kb_path": kb_path,

View File

@@ -477,6 +477,10 @@ class Errors:
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
# TODO: fix numbering after merging develop into master
E928 = ("A 'KnowledgeBase' should be written to / read from a file, but the "
"provided argument {loc} is an existing directory.")
E929 = ("A 'KnowledgeBase' could not be read from {loc} - the path does "
"not seem to exist.")
E930 = ("Received invalid get_examples callback in {name}.begin_training. "
"Expected function that returns an iterable of Example objects but "
"got: {obj}")
@@ -504,8 +508,6 @@ class Errors:
"not found in pipeline. Available components: {opts}")
E945 = ("Can't copy pipeline component '{name}' from source. Expected loaded "
"nlp object, but got: {source}")
E946 = ("The Vocab for the knowledge base is not initialized. Did you forget to "
"call kb.initialize()?")
E947 = ("Matcher.add received invalid 'greedy' argument: expected "
"a string value from {expected} but got: '{arg}'")
E948 = ("Matcher.add received invalid 'patterns' argument: expected "
@@ -612,8 +614,6 @@ class Errors:
"of the training data in spaCy 3.0 onwards. The 'update' "
"function should now be called with a batch of 'Example' "
"objects, instead of (text, annotation) tuples. ")
E990 = ("An entity linking component needs to be initialized with a "
"KnowledgeBase object, but found {type} instead.")
E991 = ("The function 'select_pipes' should be called with either a "
"'disable' argument to list the names of the pipe components "
"that should be disabled, or with an 'enable' argument that "

View File

@@ -140,7 +140,7 @@ cdef class KnowledgeBase:
self._entries.push_back(entry)
self._aliases_table.push_back(alias)
cpdef load_bulk(self, loc)
cpdef from_disk(self, loc)
cpdef set_entities(self, entity_list, freq_list, vector_list)

View File

@@ -1,4 +1,5 @@
# cython: infer_types=True, profile=True
from typing import Iterator
from cymem.cymem cimport Pool
from preshed.maps cimport PreshMap
from cpython.exc cimport PyErr_SetFromErrno
@@ -64,6 +65,16 @@ cdef class Candidate:
return self.prior_prob
def get_candidates(KnowledgeBase kb, span) -> Iterator[Candidate]:
"""
Return candidate entities for a given span by using the text of the span as the alias
and fetching appropriate entries from the index.
This particular function is optimized to work with the built-in KB functionality,
but any other custom candidate generation method can be used in combination with the KB as well.
"""
return kb.get_alias_candidates(span.text)
cdef class KnowledgeBase:
"""A `KnowledgeBase` instance stores unique identifiers for entities and their textual aliases,
to support entity linking of named entities to real-world concepts.
@@ -71,25 +82,16 @@
DOCS: https://spacy.io/api/kb
"""
def __init__(self, entity_vector_length):
"""Create a KnowledgeBase. Make sure to call kb.initialize() before using it."""
def __init__(self, Vocab vocab, entity_vector_length):
"""Create a KnowledgeBase."""
self.mem = Pool()
self.entity_vector_length = entity_vector_length
self._entry_index = PreshMap()
self._alias_index = PreshMap()
self.vocab = None
def initialize(self, Vocab vocab):
self.vocab = vocab
self.vocab.strings.add("")
self._create_empty_vectors(dummy_hash=self.vocab.strings[""])
def require_vocab(self):
if self.vocab is None:
raise ValueError(Errors.E946)
@property
def entity_vector_length(self):
"""RETURNS (uint64): length of the entity vectors"""
@@ -102,14 +104,12 @@
return len(self._entry_index)
def get_entity_strings(self):
self.require_vocab()
return [self.vocab.strings[x] for x in self._entry_index]
def get_size_aliases(self):
return len(self._alias_index)
def get_alias_strings(self):
self.require_vocab()
return [self.vocab.strings[x] for x in self._alias_index]
def add_entity(self, unicode entity, float freq, vector[float] entity_vector):
@@ -117,7 +117,6 @@
Add an entity to the KB, optionally specifying its log probability based on corpus frequency
Return the hash of the entity ID/name at the end.
"""
self.require_vocab()
cdef hash_t entity_hash = self.vocab.strings.add(entity)
# Return if this entity was added before
@@ -140,7 +139,6 @@
return entity_hash
cpdef set_entities(self, entity_list, freq_list, vector_list):
self.require_vocab()
if len(entity_list) != len(freq_list) or len(entity_list) != len(vector_list):
raise ValueError(Errors.E140)
@@ -176,12 +174,10 @@
i += 1
def contains_entity(self, unicode entity):
self.require_vocab()
cdef hash_t entity_hash = self.vocab.strings.add(entity)
return entity_hash in self._entry_index
def contains_alias(self, unicode alias):
self.require_vocab()
cdef hash_t alias_hash = self.vocab.strings.add(alias)
return alias_hash in self._alias_index
@@ -190,7 +186,6 @@
For a given alias, add its potential entities and prior probabilities to the KB.
Return the alias_hash at the end
"""
self.require_vocab()
# Throw an error if the length of entities and probabilities are not the same
if not len(entities) == len(probabilities):
raise ValueError(Errors.E132.format(alias=alias,
@@ -234,7 +229,6 @@
Throw an error if this entity+prior prob would exceed the sum of 1.
For efficiency, it's best to use the method `add_alias` as much as possible instead of this one.
"""
self.require_vocab()
# Check if the alias exists in the KB
cdef hash_t alias_hash = self.vocab.strings[alias]
if not alias_hash in self._alias_index:
@@ -274,14 +268,12 @@
alias_entry.probs = probs
self._aliases_table[alias_index] = alias_entry
def get_candidates(self, unicode alias):
def get_alias_candidates(self, unicode alias) -> Iterator[Candidate]:
"""
Return candidate entities for an alias. Each candidate defines the entity, the original alias,
and the prior probability of that alias resolving to that entity.
If the alias is not known in the KB, an empty list is returned.
"""
self.require_vocab()
cdef hash_t alias_hash = self.vocab.strings[alias]
if not alias_hash in self._alias_index:
return []
@@ -298,7 +290,6 @@
if entry_index != 0]
def get_vector(self, unicode entity):
self.require_vocab()
cdef hash_t entity_hash = self.vocab.strings[entity]
# Return an empty list if this entity is unknown in this KB
@@ -311,7 +302,6 @@
def get_prior_prob(self, unicode entity, unicode alias):
""" Return the prior probability of a given alias being linked to a given entity,
or return 0.0 when this combination is not known in the knowledge base"""
self.require_vocab()
cdef hash_t alias_hash = self.vocab.strings[alias]
cdef hash_t entity_hash = self.vocab.strings[entity]
@@ -329,8 +319,7 @@
return 0.0
def dump(self, loc):
self.require_vocab()
def to_disk(self, loc):
cdef Writer writer = Writer(loc)
writer.write_header(self.get_size_entities(), self.entity_vector_length)
@@ -370,7 +359,7 @@
writer.close()
cpdef load_bulk(self, loc):
cpdef from_disk(self, loc):
cdef hash_t entity_hash
cdef hash_t alias_hash
cdef int64_t entry_index
@@ -462,12 +451,11 @@
cdef class Writer:
def __init__(self, object loc):
if path.exists(loc):
assert not path.isdir(loc), f"{loc} is directory"
if isinstance(loc, Path):
loc = bytes(loc)
if path.exists(loc):
assert not path.isdir(loc), "%s is directory." % loc
if path.isdir(loc):
raise ValueError(Errors.E928.format(loc=loc))
cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
self._fp = fopen(<char*>bytes_loc, 'wb')
if not self._fp:
@@ -511,8 +499,10 @@ cdef class Reader:
def __init__(self, object loc):
if isinstance(loc, Path):
loc = bytes(loc)
assert path.exists(loc)
assert not path.isdir(loc)
if not path.exists(loc):
raise ValueError(Errors.E929.format(loc=loc))
if path.isdir(loc):
raise ValueError(Errors.E928.format(loc=loc))
cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
self._fp = fopen(<char*>bytes_loc, 'rb')
if not self._fp:

View File

@@ -772,9 +772,9 @@ class Language:
self.remove_pipe(name)
if not len(self.pipeline) or pipe_index == len(self.pipeline):
# we have no components to insert before/after, or we're replacing the last component
self.add_pipe(factory_name, name=name)
self.add_pipe(factory_name, name=name, config=config, validate=validate)
else:
self.add_pipe(factory_name, name=name, before=pipe_index)
self.add_pipe(factory_name, name=name, before=pipe_index, config=config, validate=validate)
def rename_pipe(self, old_name: str, new_name: str) -> None:
"""Rename a pipeline component.

View File

@@ -1,9 +1,9 @@
from typing import Optional
from typing import Optional, Callable, Iterable
from thinc.api import chain, clone, list2ragged, reduce_mean, residual
from thinc.api import Model, Maxout, Linear
from ...util import registry
from ...kb import KnowledgeBase
from ...kb import KnowledgeBase, Candidate, get_candidates
from ...vocab import Vocab
@@ -25,15 +25,21 @@ def build_nel_encoder(tok2vec: Model, nO: Optional[int] = None) -> Model:
@registry.assets.register("spacy.KBFromFile.v1")
def load_kb(vocab_path: str, kb_path: str) -> KnowledgeBase:
vocab = Vocab().from_disk(vocab_path)
kb = KnowledgeBase(entity_vector_length=1)
kb.initialize(vocab)
kb.load_bulk(kb_path)
def load_kb(kb_path: str) -> Callable[[Vocab], KnowledgeBase]:
def kb_from_file(vocab):
kb = KnowledgeBase(vocab, entity_vector_length=1)
kb.from_disk(kb_path)
return kb
return kb_from_file
@registry.assets.register("spacy.EmptyKB.v1")
def empty_kb(entity_vector_length: int) -> KnowledgeBase:
kb = KnowledgeBase(entity_vector_length=entity_vector_length)
return kb
def empty_kb(entity_vector_length: int) -> Callable[[Vocab], KnowledgeBase]:
def empty_kb_factory(vocab):
return KnowledgeBase(vocab=vocab, entity_vector_length=entity_vector_length)
return empty_kb_factory
@registry.assets.register("spacy.CandidateGenerator.v1")
def create_candidates() -> Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]:
return get_candidates
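These registered factories are referenced from the EL config by name; a hedged sketch of the wiring (the KB path is a placeholder, and `spacy.KBFromFile.v1` now takes only `kb_path`, deriving the `Vocab` from the pipe):

```python
import spacy

nlp = spacy.blank("en")
nlp.add_pipe(
    "entity_linker",
    config={
        # kb_loader builds the KB from the pipe's Vocab at construction time
        "kb_loader": {"@assets": "spacy.KBFromFile.v1", "kb_path": "/path/to/kb"},
        "get_candidates": {"@assets": "spacy.CandidateGenerator.v1"},
    },
)
```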

View File

@@ -6,7 +6,7 @@ from thinc.api import CosineDistance, get_array_module, Model, Optimizer, Config
from thinc.api import set_dropout_rate
import warnings
from ..kb import KnowledgeBase
from ..kb import KnowledgeBase, Candidate
from ..tokens import Doc
from .pipe import Pipe, deserialize_config
from ..language import Language
@@ -32,35 +32,30 @@ subword_features = true
"""
DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
default_kb_config = """
[kb]
@assets = "spacy.EmptyKB.v1"
entity_vector_length = 64
"""
DEFAULT_NEL_KB = Config().from_str(default_kb_config)["kb"]
@Language.factory(
"entity_linker",
requires=["doc.ents", "doc.sents", "token.ent_iob", "token.ent_type"],
assigns=["token.ent_kb_id"],
default_config={
"kb": DEFAULT_NEL_KB,
"kb_loader": {"@assets": "spacy.EmptyKB.v1", "entity_vector_length": 64},
"model": DEFAULT_NEL_MODEL,
"labels_discard": [],
"incl_prior": True,
"incl_context": True,
"get_candidates": {"@assets": "spacy.CandidateGenerator.v1"},
},
)
def make_entity_linker(
nlp: Language,
name: str,
model: Model,
kb: KnowledgeBase,
kb_loader: Callable[[Vocab], KnowledgeBase],
*,
labels_discard: Iterable[str],
incl_prior: bool,
incl_context: bool,
get_candidates: Callable[[KnowledgeBase, "Span"], Iterable[Candidate]],
):
"""Construct an EntityLinker component.
@@ -76,10 +71,11 @@ def make_entity_linker(
nlp.vocab,
model,
name,
kb=kb,
kb_loader=kb_loader,
labels_discard=labels_discard,
incl_prior=incl_prior,
incl_context=incl_context,
get_candidates=get_candidates,
)
@@ -97,10 +93,11 @@ class EntityLinker(Pipe):
model: Model,
name: str = "entity_linker",
*,
kb: KnowledgeBase,
kb_loader: Callable[[Vocab], KnowledgeBase],
labels_discard: Iterable[str],
incl_prior: bool,
incl_context: bool,
get_candidates: Callable[[KnowledgeBase, "Span"], Iterable[Candidate]],
) -> None:
"""Initialize an entity linker.
@@ -108,7 +105,7 @@ class EntityLinker(Pipe):
model (thinc.api.Model): The Thinc Model powering the pipeline component.
name (str): The component instance name, used to add entries to the
losses during training.
kb (KnowledgeBase): The KnowledgeBase holding all entities and their aliases.
kb_loader (Callable[[Vocab], KnowledgeBase]): A function that creates a KnowledgeBase from a Vocab instance.
labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction.
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
incl_context (bool): Whether or not to include the local context in the model.
@@ -119,17 +116,12 @@ class EntityLinker(Pipe):
self.model = model
self.name = name
cfg = {
"kb": kb,
"labels_discard": list(labels_discard),
"incl_prior": incl_prior,
"incl_context": incl_context,
}
if not isinstance(kb, KnowledgeBase):
raise ValueError(Errors.E990.format(type=type(self.kb)))
kb.initialize(vocab)
self.kb = kb
if "kb" in cfg:
del cfg["kb"] # we don't want to duplicate its serialization
self.kb = kb_loader(self.vocab)
self.get_candidates = get_candidates
self.cfg = dict(cfg)
self.distance = CosineDistance(normalize=False)
# how many neighbour sentences to take into account
@@ -326,8 +318,9 @@
end_token = sentences[end_sentence].end
sent_doc = doc[start_token:end_token].as_doc()
# currently, the context is the same for each entity in a sentence (should be refined)
xp = self.model.ops.xp
if self.cfg.get("incl_context"):
sentence_encoding = self.model.predict([sent_doc])[0]
xp = get_array_module(sentence_encoding)
sentence_encoding_t = sentence_encoding.T
sentence_norm = xp.linalg.norm(sentence_encoding_t)
for ent in sent.ents:
@@ -337,7 +330,7 @@
# ignoring this entity - setting to NIL
final_kb_ids.append(self.NIL)
else:
candidates = self.kb.get_candidates(ent.text)
candidates = self.get_candidates(self.kb, ent)
if not candidates:
# no prediction possible for this entity - setting to NIL
final_kb_ids.append(self.NIL)
@@ -421,10 +414,9 @@
DOCS: https://spacy.io/api/entitylinker#to_disk
"""
serialize = {}
self.cfg["entity_width"] = self.kb.entity_vector_length
serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
serialize["vocab"] = lambda p: self.vocab.to_disk(p)
serialize["kb"] = lambda p: self.kb.dump(p)
serialize["kb"] = lambda p: self.kb.to_disk(p)
serialize["model"] = lambda p: self.model.to_disk(p)
util.to_disk(path, serialize, exclude)
@@ -446,15 +438,10 @@
except AttributeError:
raise ValueError(Errors.E149) from None
def load_kb(p):
self.kb = KnowledgeBase(entity_vector_length=self.cfg["entity_width"])
self.kb.initialize(self.vocab)
self.kb.load_bulk(p)
deserialize = {}
deserialize["vocab"] = lambda p: self.vocab.from_disk(p)
deserialize["cfg"] = lambda p: self.cfg.update(deserialize_config(p))
deserialize["kb"] = load_kb
deserialize["kb"] = lambda p: self.kb.from_disk(p)
deserialize["model"] = load_model
util.from_disk(path, deserialize, exclude)
return self
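Since the KB now round-trips through the pipe's own `to_disk`/`from_disk`, a whole pipeline containing an entity linker serializes in one call. A sketch, assuming `nlp` is a pipeline whose entity_linker is fully initialized (as in the subclassed-KB test near the end of this diff):

```python
import tempfile
from spacy import util

with tempfile.TemporaryDirectory() as tmp_dir:
    nlp.to_disk(tmp_dir)                       # writes the KB via kb.to_disk
    nlp2 = util.load_model_from_path(tmp_dir)  # restores it via kb.from_disk
    assert (
        nlp2.get_pipe("entity_linker").kb.entity_vector_length
        == nlp.get_pipe("entity_linker").kb.entity_vector_length
    )
```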

View File

@@ -68,7 +68,6 @@ class Tagger(Pipe):
name (str): The component instance name, used to add entries to the
losses during training.
labels (List): The set of labels. Defaults to None.
set_morphology (bool): Whether to set morphological features.
DOCS: https://spacy.io/api/tagger#init
"""

View File

@@ -1,6 +1,7 @@
from typing import Callable, Iterable
import pytest
from spacy.kb import KnowledgeBase
from spacy.kb import KnowledgeBase, get_candidates, Candidate
from spacy import util, registry
from spacy.gold import Example
@@ -21,8 +22,7 @@ def assert_almost_equal(a, b):
def test_kb_valid_entities(nlp):
"""Test the valid construction of a KB with 3 entities and two aliases"""
mykb = KnowledgeBase(entity_vector_length=3)
mykb.initialize(nlp.vocab)
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
# adding entities
mykb.add_entity(entity="Q1", freq=19, entity_vector=[8, 4, 3])
@@ -51,8 +51,7 @@
def test_kb_invalid_entities(nlp):
"""Test the invalid construction of a KB with an alias linked to a non-existing entity"""
mykb = KnowledgeBase(entity_vector_length=1)
mykb.initialize(nlp.vocab)
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities
mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
@@ -68,8 +67,7 @@
def test_kb_invalid_probabilities(nlp):
"""Test the invalid construction of a KB with wrong prior probabilities"""
mykb = KnowledgeBase(entity_vector_length=1)
mykb.initialize(nlp.vocab)
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities
mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
@@ -83,8 +81,7 @@
def test_kb_invalid_combination(nlp):
"""Test the invalid construction of a KB with non-matching entity and probability lists"""
mykb = KnowledgeBase(entity_vector_length=1)
mykb.initialize(nlp.vocab)
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities
mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
@@ -100,8 +97,7 @@
def test_kb_invalid_entity_vector(nlp):
"""Test the invalid construction of a KB with non-matching entity vector lengths"""
mykb = KnowledgeBase(entity_vector_length=3)
mykb.initialize(nlp.vocab)
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
# adding entities
mykb.add_entity(entity="Q1", freq=19, entity_vector=[1, 2, 3])
@@ -117,14 +113,14 @@ def test_kb_default(nlp):
assert len(entity_linker.kb) == 0
assert entity_linker.kb.get_size_entities() == 0
assert entity_linker.kb.get_size_aliases() == 0
# default value from pipeline.entity_linker
# 64 is the default value from pipeline.entity_linker
assert entity_linker.kb.entity_vector_length == 64
def test_kb_custom_length(nlp):
"""Test that the default (empty) KB can be configured with a custom entity length"""
entity_linker = nlp.add_pipe(
"entity_linker", config={"kb": {"entity_vector_length": 35}}
"entity_linker", config={"kb_loader": {"entity_vector_length": 35}}
)
assert len(entity_linker.kb) == 0
assert entity_linker.kb.get_size_entities() == 0
@@ -141,7 +137,7 @@ def test_kb_undefined(nlp):
def test_kb_empty(nlp):
"""Test that the EL can't train with an empty KB"""
config = {"kb": {"@assets": "spacy.EmptyKB.v1", "entity_vector_length": 342}}
config = {"kb_loader": {"@assets": "spacy.EmptyKB.v1", "entity_vector_length": 342}}
entity_linker = nlp.add_pipe("entity_linker", config=config)
assert len(entity_linker.kb) == 0
with pytest.raises(ValueError):
@@ -150,8 +146,13 @@
def test_candidate_generation(nlp):
"""Test correct candidate generation"""
mykb = KnowledgeBase(entity_vector_length=1)
mykb.initialize(nlp.vocab)
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
doc = nlp("douglas adam Adam shrubbery")
douglas_ent = doc[0:1]
adam_ent = doc[1:2]
Adam_ent = doc[2:3]
shrubbery_ent = doc[3:4]
# adding entities
mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
@@ -163,21 +164,76 @@
mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
# test the size of the relevant candidates
assert len(mykb.get_candidates("douglas")) == 2
assert len(mykb.get_candidates("adam")) == 1
assert len(mykb.get_candidates("shrubbery")) == 0
assert len(get_candidates(mykb, douglas_ent)) == 2
assert len(get_candidates(mykb, adam_ent)) == 1
assert len(get_candidates(mykb, Adam_ent)) == 0 # default case sensitive
assert len(get_candidates(mykb, shrubbery_ent)) == 0
# test the content of the candidates
assert mykb.get_candidates("adam")[0].entity_ == "Q2"
assert mykb.get_candidates("adam")[0].alias_ == "adam"
assert_almost_equal(mykb.get_candidates("adam")[0].entity_freq, 12)
assert_almost_equal(mykb.get_candidates("adam")[0].prior_prob, 0.9)
assert get_candidates(mykb, adam_ent)[0].entity_ == "Q2"
assert get_candidates(mykb, adam_ent)[0].alias_ == "adam"
assert_almost_equal(get_candidates(mykb, adam_ent)[0].entity_freq, 12)
assert_almost_equal(get_candidates(mykb, adam_ent)[0].prior_prob, 0.9)
def test_el_pipe_configuration(nlp):
"""Test correct candidate generation as part of the EL pipe"""
nlp.add_pipe("sentencizer")
pattern = {"label": "PERSON", "pattern": [{"LOWER": "douglas"}]}
ruler = nlp.add_pipe("entity_ruler")
ruler.add_patterns([pattern])
@registry.assets.register("myAdamKB.v1")
def mykb() -> Callable[["Vocab"], KnowledgeBase]:
def create_kb(vocab):
kb = KnowledgeBase(vocab, entity_vector_length=1)
kb.add_entity(entity="Q2", freq=12, entity_vector=[2])
kb.add_entity(entity="Q3", freq=5, entity_vector=[3])
kb.add_alias(
alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1]
)
return kb
return create_kb
# run an EL pipe without a trained context encoder, to check the candidate generation step only
nlp.add_pipe(
"entity_linker",
config={"kb_loader": {"@assets": "myAdamKB.v1"}, "incl_context": False},
)
# With the default get_candidates function, matching is case-sensitive
text = "Douglas and douglas are not the same."
doc = nlp(text)
assert doc[0].ent_kb_id_ == "NIL"
assert doc[1].ent_kb_id_ == ""
assert doc[2].ent_kb_id_ == "Q2"
def get_lowercased_candidates(kb, span):
return kb.get_alias_candidates(span.text.lower())
@registry.assets.register("spacy.LowercaseCandidateGenerator.v1")
def create_candidates() -> Callable[[KnowledgeBase, "Span"], Iterable[Candidate]]:
return get_lowercased_candidates
# replace the pipe with a new one with a different candidate generator
nlp.replace_pipe(
"entity_linker",
"entity_linker",
config={
"kb_loader": {"@assets": "myAdamKB.v1"},
"incl_context": False,
"get_candidates": {"@assets": "spacy.LowercaseCandidateGenerator.v1"},
},
)
doc = nlp(text)
assert doc[0].ent_kb_id_ == "Q2"
assert doc[1].ent_kb_id_ == ""
assert doc[2].ent_kb_id_ == "Q2"
def test_append_alias(nlp):
"""Test that we can append additional alias-entity pairs"""
mykb = KnowledgeBase(entity_vector_length=1)
mykb.initialize(nlp.vocab)
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities
mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
@@ -189,26 +245,25 @@
mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
# test the size of the relevant candidates
assert len(mykb.get_candidates("douglas")) == 2
assert len(mykb.get_alias_candidates("douglas")) == 2
# append an alias
mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.2)
# test the size of the relevant candidates has been incremented
assert len(mykb.get_candidates("douglas")) == 3
assert len(mykb.get_alias_candidates("douglas")) == 3
# append the same alias-entity pair again should not work (will throw a warning)
with pytest.warns(UserWarning):
mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.3)
# test the size of the relevant candidates remained unchanged
assert len(mykb.get_candidates("douglas")) == 3
assert len(mykb.get_alias_candidates("douglas")) == 3
def test_append_invalid_alias(nlp):
"""Test that append an alias will throw an error if prior probs are exceeding 1"""
mykb = KnowledgeBase(entity_vector_length=1)
mykb.initialize(nlp.vocab)
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities
mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
@@ -228,9 +283,9 @@ def test_preserving_links_asdoc(nlp):
"""Test that Span.as_doc preserves the existing entity links"""
@registry.assets.register("myLocationsKB.v1")
def dummy_kb() -> KnowledgeBase:
mykb = KnowledgeBase(entity_vector_length=1)
mykb.initialize(nlp.vocab)
def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]:
def create_kb(vocab):
mykb = KnowledgeBase(vocab, entity_vector_length=1)
# adding entities
mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
mykb.add_entity(entity="Q2", freq=8, entity_vector=[1])
@@ -239,6 +294,8 @@ def test_preserving_links_asdoc(nlp):
mykb.add_alias(alias="Denver", entities=["Q2"], probabilities=[0.6])
return mykb
return create_kb
# set up pipeline with NER (Entity Ruler) and NEL (prior probability only, model not trained)
nlp.add_pipe("sentencizer")
patterns = [
@@ -247,7 +304,7 @@
]
ruler = nlp.add_pipe("entity_ruler")
ruler.add_patterns(patterns)
el_config = {"kb": {"@assets": "myLocationsKB.v1"}, "incl_prior": False}
el_config = {"kb_loader": {"@assets": "myLocationsKB.v1"}, "incl_prior": False}
el_pipe = nlp.add_pipe("entity_linker", config=el_config, last=True)
el_pipe.begin_training(lambda: [])
el_pipe.incl_context = False
@@ -331,12 +388,12 @@ def test_overfitting_IO():
train_examples.append(Example.from_dict(doc, annotation))
@registry.assets.register("myOverfittingKB.v1")
def dummy_kb() -> KnowledgeBase:
def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]:
def create_kb(vocab):
# create artificial KB - assign the same prior weight to the two Russ Cochrans
# Q2146908 (Russ Cochran): American golfer
# Q7381115 (Russ Cochran): publisher
mykb = KnowledgeBase(entity_vector_length=3)
mykb.initialize(nlp.vocab)
mykb = KnowledgeBase(vocab, entity_vector_length=3)
mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7])
mykb.add_alias(
@@ -346,9 +403,13 @@
)
return mykb
return create_kb
# Create the Entity Linker component and add it to the pipeline
nlp.add_pipe(
"entity_linker", config={"kb": {"@assets": "myOverfittingKB.v1"}}, last=True
"entity_linker",
config={"kb_loader": {"@assets": "myOverfittingKB.v1"}},
last=True,
)
# train the NEL pipe

View File

@@ -78,6 +78,14 @@ def test_replace_last_pipe(nlp):
assert nlp.pipe_names == ["sentencizer", "ner"]
def test_replace_pipe_config(nlp):
nlp.add_pipe("entity_linker")
nlp.add_pipe("sentencizer")
assert nlp.get_pipe("entity_linker").cfg["incl_prior"] == True
nlp.replace_pipe("entity_linker", "entity_linker", config={"incl_prior": False})
assert nlp.get_pipe("entity_linker").cfg["incl_prior"] == False
@pytest.mark.parametrize("old_name,new_name", [("old_pipe", "new_pipe")])
def test_rename_pipe(nlp, old_name, new_name):
with pytest.raises(ValueError):

View File

@@ -139,8 +139,7 @@ def test_issue4665():
def test_issue4674():
"""Test that setting entities with overlapping identifiers does not mess up IO"""
nlp = English()
kb = KnowledgeBase(entity_vector_length=3)
kb.initialize(nlp.vocab)
kb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
vector1 = [0.9, 1.1, 1.01]
vector2 = [1.8, 2.25, 2.01]
with pytest.warns(UserWarning):
@@ -156,10 +155,9 @@ def test_issue4674():
if not dir_path.exists():
dir_path.mkdir()
file_path = dir_path / "kb"
kb.dump(str(file_path))
kb2 = KnowledgeBase(entity_vector_length=3)
kb2.initialize(nlp.vocab)
kb2.load_bulk(str(file_path))
kb.to_disk(str(file_path))
kb2 = KnowledgeBase(nlp.vocab, entity_vector_length=3)
kb2.from_disk(str(file_path))
assert kb2.get_size_entities() == 1

View File

@@ -1,3 +1,4 @@
from typing import Callable
import warnings
from unittest import TestCase
import pytest
@@ -70,13 +71,14 @@
nlp = Language()
@registry.assets.register("TestIssue5230KB.v1")
def dummy_kb() -> KnowledgeBase:
kb = KnowledgeBase(entity_vector_length=1)
kb.initialize(nlp.vocab)
def dummy_kb() -> Callable[["Vocab"], KnowledgeBase]:
def create_kb(vocab):
kb = KnowledgeBase(vocab, entity_vector_length=1)
kb.add_entity("test", 0.0, zeros((1, 1), dtype="f"))
return kb
return create_kb
config = {"kb": {"@assets": "TestIssue5230KB.v1"}}
config = {"kb_loader": {"@assets": "TestIssue5230KB.v1"}}
entity_linker = nlp.add_pipe("entity_linker", config=config)
# need to add model for two reasons:
# 1. no model leads to error in serialization,
@@ -121,19 +123,17 @@ def test_writer_with_path_py35():
def test_save_and_load_knowledge_base():
nlp = Language()
kb = KnowledgeBase(entity_vector_length=1)
kb.initialize(nlp.vocab)
kb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
with make_tempdir() as d:
path = d / "kb"
try:
kb.dump(path)
kb.to_disk(path)
except Exception as e:
pytest.fail(str(e))
try:
kb_loaded = KnowledgeBase(entity_vector_length=1)
kb_loaded.initialize(nlp.vocab)
kb_loaded.load_bulk(path)
kb_loaded = KnowledgeBase(nlp.vocab, entity_vector_length=1)
kb_loaded.from_disk(path)
except Exception as e:
pytest.fail(str(e))

View File

@@ -1,4 +1,8 @@
from spacy.util import ensure_path
from typing import Callable
from spacy import util
from spacy.lang.en import English
from spacy.util import ensure_path, registry
from spacy.kb import KnowledgeBase
from ..util import make_tempdir
@@ -15,20 +19,16 @@ def test_serialize_kb_disk(en_vocab):
if not dir_path.exists():
dir_path.mkdir()
file_path = dir_path / "kb"
kb1.dump(str(file_path))
kb2 = KnowledgeBase(entity_vector_length=3)
kb2.initialize(en_vocab)
kb2.load_bulk(str(file_path))
kb1.to_disk(str(file_path))
kb2 = KnowledgeBase(vocab=en_vocab, entity_vector_length=3)
kb2.from_disk(str(file_path))
# final assertions
_check_kb(kb2)
def _get_dummy_kb(vocab):
kb = KnowledgeBase(entity_vector_length=3)
kb.initialize(vocab)
kb = KnowledgeBase(vocab, entity_vector_length=3)
kb.add_entity(entity="Q53", freq=33, entity_vector=[0, 5, 3])
kb.add_entity(entity="Q17", freq=2, entity_vector=[7, 1, 0])
kb.add_entity(entity="Q007", freq=7, entity_vector=[0, 0, 7])
@@ -61,7 +61,7 @@ def _check_kb(kb):
assert alias_string not in kb.get_alias_strings()
# check candidates & probabilities
candidates = sorted(kb.get_candidates("double07"), key=lambda x: x.entity_)
candidates = sorted(kb.get_alias_candidates("double07"), key=lambda x: x.entity_)
assert len(candidates) == 2
assert candidates[0].entity_ == "Q007"
@@ -75,3 +75,47 @@
assert candidates[1].entity_vector == [7, 1, 0]
assert candidates[1].alias_ == "double07"
assert 0.099 < candidates[1].prior_prob < 0.101
def test_serialize_subclassed_kb():
"""Check that IO of a custom KB works fine as part of an EL pipe."""
class SubKnowledgeBase(KnowledgeBase):
def __init__(self, vocab, entity_vector_length, custom_field):
super().__init__(vocab, entity_vector_length)
self.custom_field = custom_field
@registry.assets.register("spacy.CustomKB.v1")
def custom_kb(
entity_vector_length: int, custom_field: int
) -> Callable[["Vocab"], KnowledgeBase]:
def custom_kb_factory(vocab):
return SubKnowledgeBase(
vocab=vocab,
entity_vector_length=entity_vector_length,
custom_field=custom_field,
)
return custom_kb_factory
nlp = English()
config = {
"kb_loader": {
"@assets": "spacy.CustomKB.v1",
"entity_vector_length": 342,
"custom_field": 666,
}
}
entity_linker = nlp.add_pipe("entity_linker", config=config)
assert type(entity_linker.kb) == SubKnowledgeBase
assert entity_linker.kb.entity_vector_length == 342
assert entity_linker.kb.custom_field == 666
# Make sure the custom KB is serialized correctly
with make_tempdir() as tmp_dir:
nlp.to_disk(tmp_dir)
nlp2 = util.load_model_from_path(tmp_dir)
entity_linker2 = nlp2.get_pipe("entity_linker")
assert type(entity_linker2.kb) == SubKnowledgeBase
assert entity_linker2.kb.entity_vector_length == 342
assert entity_linker2.kb.custom_field == 666

View File

@@ -200,21 +200,21 @@ probability of the fact that the mention links to the entity ID.
| `alias` | The textual mention or alias. ~~str~~ |
| **RETURNS** | The prior probability of the `alias` referring to the `entity`. ~~float~~ |
## KnowledgeBase.dump {#dump tag="method"}
## KnowledgeBase.to_disk {#to_disk tag="method"}
Save the current state of the knowledge base to a directory.
> #### Example
>
> ```python
> kb.dump(loc)
> kb.to_disk(loc)
> ```
| Name | Description |
| ----- | ------------------------------------------------------------------------------------------------------------------------------------------ |
| `loc` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. ~~Union[str, Path]~~ |
## KnowledgeBase.load_bulk {#load_bulk tag="method"}
## KnowledgeBase.from_disk {#from_disk tag="method"}
Restore the state of the knowledge base from a given directory. Note that the
[`Vocab`](/api/vocab) should also be the same as the one used to create the KB.
@@ -226,7 +226,7 @@ Restore the state of the knowledge base from a given directory. Note that the
> from spacy.vocab import Vocab
> vocab = Vocab().from_disk("/path/to/vocab")
> kb = KnowledgeBase(vocab=vocab, entity_vector_length=64)
> kb.load_bulk("/path/to/kb")
> kb.from_disk("/path/to/kb")
> ```
| Name | Description |