Default empty KB in EL component (#5872)

* EL field documentation

* documentation consistent with docs

* default empty KB, initialize vocab separately

* formatting

* add test for changing the default entity vector length

* update comment
This commit is contained in:
Sofie Van Landeghem 2020-08-04 14:34:09 +02:00 committed by GitHub
parent b7e3018d97
commit 82347110f5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 127 additions and 40 deletions

View File

@ -48,7 +48,8 @@ def main(model, output_dir=None):
# You can change the dimension of vectors in your KB by using an encoder that changes the dimensionality.
# For simplicity, we'll just use the original vector dimension here instead.
vectors_dim = nlp.vocab.vectors.shape[1]
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=vectors_dim)
kb = KnowledgeBase(entity_vector_length=vectors_dim)
kb.initialize(nlp.vocab)
# set up the data
entity_ids = []
@ -95,7 +96,8 @@ def main(model, output_dir=None):
print("Loading vocab from", vocab_path)
print("Loading KB from", kb_path)
vocab2 = Vocab().from_disk(vocab_path)
kb2 = KnowledgeBase(vocab=vocab2)
kb2 = KnowledgeBase(entity_vector_length=1)
kb.initialize(vocab2)
kb2.load_bulk(kb_path)
print()
_print_kb(kb2)

View File

@ -374,7 +374,8 @@ class Errors:
E138 = ("Invalid JSONL format for raw text '{text}'. Make sure the input "
"includes either the `text` or `tokens` key. For more info, see "
"the docs:\nhttps://spacy.io/api/cli#pretrain-jsonl")
E139 = ("Knowledge Base for component '{name}' is empty.")
E139 = ("Knowledge Base for component '{name}' is empty. Use the methods "
"kb.add_entity and kb.add_alias to add entries.")
E140 = ("The list of entities, prior probabilities and entity vectors "
"should be of equal length.")
E141 = ("Entity vectors should be of length {required} instead of the "
@ -481,6 +482,8 @@ class Errors:
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
# TODO: fix numbering after merging develop into master
E946 = ("The Vocab for the knowledge base is not initialized. Did you forget to "
"call kb.initialize()?")
E947 = ("Matcher.add received invalid 'greedy' argument: expected "
"a string value from {expected} but got: '{arg}'")
E948 = ("Matcher.add received invalid 'patterns' argument: expected "

View File

@ -71,17 +71,25 @@ cdef class KnowledgeBase:
DOCS: https://spacy.io/api/kb
"""
def __init__(self, Vocab vocab, entity_vector_length=64):
self.vocab = vocab
def __init__(self, entity_vector_length):
"""Create a KnowledgeBase. Make sure to call kb.initialize() before using it."""
self.mem = Pool()
self.entity_vector_length = entity_vector_length
self._entry_index = PreshMap()
self._alias_index = PreshMap()
self.vocab = None
def initialize(self, Vocab vocab):
self.vocab = vocab
self.vocab.strings.add("")
self._create_empty_vectors(dummy_hash=self.vocab.strings[""])
def require_vocab(self):
if self.vocab is None:
raise ValueError(Errors.E946)
@property
def entity_vector_length(self):
"""RETURNS (uint64): length of the entity vectors"""
@ -94,12 +102,14 @@ cdef class KnowledgeBase:
return len(self._entry_index)
def get_entity_strings(self):
self.require_vocab()
return [self.vocab.strings[x] for x in self._entry_index]
def get_size_aliases(self):
return len(self._alias_index)
def get_alias_strings(self):
self.require_vocab()
return [self.vocab.strings[x] for x in self._alias_index]
def add_entity(self, unicode entity, float freq, vector[float] entity_vector):
@ -107,6 +117,7 @@ cdef class KnowledgeBase:
Add an entity to the KB, optionally specifying its log probability based on corpus frequency
Return the hash of the entity ID/name at the end.
"""
self.require_vocab()
cdef hash_t entity_hash = self.vocab.strings.add(entity)
# Return if this entity was added before
@ -129,6 +140,7 @@ cdef class KnowledgeBase:
return entity_hash
cpdef set_entities(self, entity_list, freq_list, vector_list):
self.require_vocab()
if len(entity_list) != len(freq_list) or len(entity_list) != len(vector_list):
raise ValueError(Errors.E140)
@ -164,10 +176,12 @@ cdef class KnowledgeBase:
i += 1
def contains_entity(self, unicode entity):
self.require_vocab()
cdef hash_t entity_hash = self.vocab.strings.add(entity)
return entity_hash in self._entry_index
def contains_alias(self, unicode alias):
self.require_vocab()
cdef hash_t alias_hash = self.vocab.strings.add(alias)
return alias_hash in self._alias_index
@ -176,6 +190,7 @@ cdef class KnowledgeBase:
For a given alias, add its potential entities and prior probabilies to the KB.
Return the alias_hash at the end
"""
self.require_vocab()
# Throw an error if the length of entities and probabilities are not the same
if not len(entities) == len(probabilities):
raise ValueError(Errors.E132.format(alias=alias,
@ -219,6 +234,7 @@ cdef class KnowledgeBase:
Throw an error if this entity+prior prob would exceed the sum of 1.
For efficiency, it's best to use the method `add_alias` as much as possible instead of this one.
"""
self.require_vocab()
# Check if the alias exists in the KB
cdef hash_t alias_hash = self.vocab.strings[alias]
if not alias_hash in self._alias_index:
@ -265,6 +281,7 @@ cdef class KnowledgeBase:
and the prior probability of that alias resolving to that entity.
If the alias is not known in the KB, and empty list is returned.
"""
self.require_vocab()
cdef hash_t alias_hash = self.vocab.strings[alias]
if not alias_hash in self._alias_index:
return []
@ -281,6 +298,7 @@ cdef class KnowledgeBase:
if entry_index != 0]
def get_vector(self, unicode entity):
self.require_vocab()
cdef hash_t entity_hash = self.vocab.strings[entity]
# Return an empty list if this entity is unknown in this KB
@ -293,6 +311,7 @@ cdef class KnowledgeBase:
def get_prior_prob(self, unicode entity, unicode alias):
""" Return the prior probability of a given alias being linked to a given entity,
or return 0.0 when this combination is not known in the knowledge base"""
self.require_vocab()
cdef hash_t alias_hash = self.vocab.strings[alias]
cdef hash_t entity_hash = self.vocab.strings[entity]
@ -311,6 +330,7 @@ cdef class KnowledgeBase:
def dump(self, loc):
self.require_vocab()
cdef Writer writer = Writer(loc)
writer.write_header(self.get_size_entities(), self.entity_vector_length)

View File

@ -27,6 +27,13 @@ def build_nel_encoder(tok2vec: Model, nO: Optional[int] = None) -> Model:
@registry.assets.register("spacy.KBFromFile.v1")
def load_kb(vocab_path: str, kb_path: str) -> KnowledgeBase:
vocab = Vocab().from_disk(vocab_path)
kb = KnowledgeBase(vocab=vocab)
kb = KnowledgeBase(entity_vector_length=1)
kb.initialize(vocab)
kb.load_bulk(kb_path)
return kb
@registry.assets.register("spacy.EmptyKB.v1")
def empty_kb(entity_vector_length: int) -> KnowledgeBase:
kb = KnowledgeBase(entity_vector_length=entity_vector_length)
return kb

View File

@ -33,24 +33,31 @@ dropout = null
"""
DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"]
default_kb_config = """
[kb]
@assets = "spacy.EmptyKB.v1"
entity_vector_length = 64
"""
DEFAULT_NEL_KB = Config().from_str(default_kb_config)["kb"]
@Language.factory(
"entity_linker",
requires=["doc.ents", "doc.sents", "token.ent_iob", "token.ent_type"],
assigns=["token.ent_kb_id"],
default_config={
"kb": None, # TODO - what kind of default makes sense here?
"kb": DEFAULT_NEL_KB,
"model": DEFAULT_NEL_MODEL,
"labels_discard": [],
"incl_prior": True,
"incl_context": True,
"model": DEFAULT_NEL_MODEL,
},
)
def make_entity_linker(
nlp: Language,
name: str,
model: Model,
kb: Optional[KnowledgeBase],
kb: KnowledgeBase,
*,
labels_discard: Iterable[str],
incl_prior: bool,
@ -92,10 +99,10 @@ class EntityLinker(Pipe):
model (thinc.api.Model): The Thinc Model powering the pipeline component.
name (str): The component instance name, used to add entries to the
losses during training.
kb (KnowledgeBase): TODO:
labels_discard (Iterable[str]): TODO:
incl_prior (bool): TODO:
incl_context (bool): TODO:
kb (KnowledgeBase): The KnowledgeBase holding all entities and their aliases.
labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction.
incl_prior (bool): Whether or not to include prior probabilities from the KB in the model.
incl_context (bool): Whether or not to include the local context in the model.
DOCS: https://spacy.io/api/entitylinker#init
"""
@ -108,14 +115,12 @@ class EntityLinker(Pipe):
"incl_prior": incl_prior,
"incl_context": incl_context,
}
self.kb = kb
if self.kb is None:
# create an empty KB that should be filled by calling from_disk
self.kb = KnowledgeBase(vocab=vocab)
else:
del cfg["kb"] # we don't want to duplicate its serialization
if not isinstance(self.kb, KnowledgeBase):
if not isinstance(kb, KnowledgeBase):
raise ValueError(Errors.E990.format(type=type(self.kb)))
kb.initialize(vocab)
self.kb = kb
if "kb" in cfg:
del cfg["kb"] # we don't want to duplicate its serialization
self.cfg = dict(cfg)
self.distance = CosineDistance(normalize=False)
# how many neightbour sentences to take into account
@ -437,9 +442,8 @@ class EntityLinker(Pipe):
raise ValueError(Errors.E149)
def load_kb(p):
self.kb = KnowledgeBase(
vocab=self.vocab, entity_vector_length=self.cfg["entity_width"]
)
self.kb = KnowledgeBase(entity_vector_length=self.cfg["entity_width"])
self.kb.initialize(self.vocab)
self.kb.load_bulk(p)
deserialize = {}

View File

@ -21,7 +21,8 @@ def assert_almost_equal(a, b):
def test_kb_valid_entities(nlp):
"""Test the valid construction of a KB with 3 entities and two aliases"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
mykb = KnowledgeBase(entity_vector_length=3)
mykb.initialize(nlp.vocab)
# adding entities
mykb.add_entity(entity="Q1", freq=19, entity_vector=[8, 4, 3])
@ -50,7 +51,8 @@ def test_kb_valid_entities(nlp):
def test_kb_invalid_entities(nlp):
"""Test the invalid construction of a KB with an alias linked to a non-existing entity"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
mykb = KnowledgeBase(entity_vector_length=1)
mykb.initialize(nlp.vocab)
# adding entities
mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
@ -66,7 +68,8 @@ def test_kb_invalid_entities(nlp):
def test_kb_invalid_probabilities(nlp):
"""Test the invalid construction of a KB with wrong prior probabilities"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
mykb = KnowledgeBase(entity_vector_length=1)
mykb.initialize(nlp.vocab)
# adding entities
mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
@ -80,7 +83,8 @@ def test_kb_invalid_probabilities(nlp):
def test_kb_invalid_combination(nlp):
"""Test the invalid construction of a KB with non-matching entity and probability lists"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
mykb = KnowledgeBase(entity_vector_length=1)
mykb.initialize(nlp.vocab)
# adding entities
mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
@ -96,7 +100,8 @@ def test_kb_invalid_combination(nlp):
def test_kb_invalid_entity_vector(nlp):
"""Test the invalid construction of a KB with non-matching entity vector lengths"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
mykb = KnowledgeBase(entity_vector_length=3)
mykb.initialize(nlp.vocab)
# adding entities
mykb.add_entity(entity="Q1", freq=19, entity_vector=[1, 2, 3])
@ -106,9 +111,44 @@ def test_kb_invalid_entity_vector(nlp):
mykb.add_entity(entity="Q2", freq=5, entity_vector=[2])
def test_kb_default(nlp):
"""Test that the default (empty) KB is loaded when not providing a config"""
entity_linker = nlp.add_pipe("entity_linker", config={})
assert len(entity_linker.kb) == 0
assert entity_linker.kb.get_size_entities() == 0
assert entity_linker.kb.get_size_aliases() == 0
assert entity_linker.kb.entity_vector_length == 64 # default value from pipeline.entity_linker
def test_kb_custom_length(nlp):
"""Test that the default (empty) KB can be configured with a custom entity length"""
entity_linker = nlp.add_pipe("entity_linker", config={"kb": {"entity_vector_length": 35}})
assert len(entity_linker.kb) == 0
assert entity_linker.kb.get_size_entities() == 0
assert entity_linker.kb.get_size_aliases() == 0
assert entity_linker.kb.entity_vector_length == 35
def test_kb_undefined(nlp):
"""Test that the EL can't train without defining a KB"""
entity_linker = nlp.add_pipe("entity_linker", config={})
with pytest.raises(ValueError):
entity_linker.begin_training()
def test_kb_empty(nlp):
"""Test that the EL can't train with an empty KB"""
config = {"kb": {"@assets": "spacy.EmptyKB.v1", "entity_vector_length": 342}}
entity_linker = nlp.add_pipe("entity_linker", config=config)
assert len(entity_linker.kb) == 0
with pytest.raises(ValueError):
entity_linker.begin_training()
def test_candidate_generation(nlp):
"""Test correct candidate generation"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
mykb = KnowledgeBase(entity_vector_length=1)
mykb.initialize(nlp.vocab)
# adding entities
mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
@ -133,7 +173,8 @@ def test_candidate_generation(nlp):
def test_append_alias(nlp):
"""Test that we can append additional alias-entity pairs"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
mykb = KnowledgeBase(entity_vector_length=1)
mykb.initialize(nlp.vocab)
# adding entities
mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
@ -163,7 +204,8 @@ def test_append_alias(nlp):
def test_append_invalid_alias(nlp):
"""Test that append an alias will throw an error if prior probs are exceeding 1"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
mykb = KnowledgeBase(entity_vector_length=1)
mykb.initialize(nlp.vocab)
# adding entities
mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
@ -184,7 +226,8 @@ def test_preserving_links_asdoc(nlp):
@registry.assets.register("myLocationsKB.v1")
def dummy_kb() -> KnowledgeBase:
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
mykb = KnowledgeBase(entity_vector_length=1)
mykb.initialize(nlp.vocab)
# adding entities
mykb.add_entity(entity="Q1", freq=19, entity_vector=[1])
mykb.add_entity(entity="Q2", freq=8, entity_vector=[1])
@ -289,7 +332,8 @@ def test_overfitting_IO():
# create artificial KB - assign same prior weight to the two russ cochran's
# Q2146908 (Russ Cochran): American golfer
# Q7381115 (Russ Cochran): publisher
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
mykb = KnowledgeBase(entity_vector_length=3)
mykb.initialize(nlp.vocab)
mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3])
mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7])
mykb.add_alias(

View File

@ -139,7 +139,8 @@ def test_issue4665():
def test_issue4674():
"""Test that setting entities with overlapping identifiers does not mess up IO"""
nlp = English()
kb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
kb = KnowledgeBase(entity_vector_length=3)
kb.initialize(nlp.vocab)
vector1 = [0.9, 1.1, 1.01]
vector2 = [1.8, 2.25, 2.01]
with pytest.warns(UserWarning):
@ -156,7 +157,8 @@ def test_issue4674():
dir_path.mkdir()
file_path = dir_path / "kb"
kb.dump(str(file_path))
kb2 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=3)
kb2 = KnowledgeBase(entity_vector_length=3)
kb2.initialize(nlp.vocab)
kb2.load_bulk(str(file_path))
assert kb2.get_size_entities() == 1

View File

@ -72,7 +72,8 @@ def entity_linker():
@registry.assets.register("TestIssue5230KB.v1")
def dummy_kb() -> KnowledgeBase:
kb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
kb = KnowledgeBase(entity_vector_length=1)
kb.initialize(nlp.vocab)
kb.add_entity("test", 0.0, zeros((1, 1), dtype="f"))
return kb
@ -121,7 +122,8 @@ def test_writer_with_path_py35():
def test_save_and_load_knowledge_base():
nlp = Language()
kb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
kb = KnowledgeBase(entity_vector_length=1)
kb.initialize(nlp.vocab)
with make_tempdir() as d:
path = d / "kb"
try:
@ -130,7 +132,8 @@ def test_save_and_load_knowledge_base():
pytest.fail(str(e))
try:
kb_loaded = KnowledgeBase(nlp.vocab, entity_vector_length=1)
kb_loaded = KnowledgeBase(entity_vector_length=1)
kb_loaded.initialize(nlp.vocab)
kb_loaded.load_bulk(path)
except Exception as e:
pytest.fail(str(e))

View File

@ -17,7 +17,8 @@ def test_serialize_kb_disk(en_vocab):
file_path = dir_path / "kb"
kb1.dump(str(file_path))
kb2 = KnowledgeBase(vocab=en_vocab, entity_vector_length=3)
kb2 = KnowledgeBase(entity_vector_length=3)
kb2.initialize(en_vocab)
kb2.load_bulk(str(file_path))
# final assertions
@ -25,7 +26,8 @@ def test_serialize_kb_disk(en_vocab):
def _get_dummy_kb(vocab):
kb = KnowledgeBase(vocab=vocab, entity_vector_length=3)
kb = KnowledgeBase(entity_vector_length=3)
kb.initialize(vocab)
kb.add_entity(entity="Q53", freq=33, entity_vector=[0, 5, 3])
kb.add_entity(entity="Q17", freq=2, entity_vector=[7, 1, 0])