From 82347110f5ce5905b5acbde851c4b712b79130c1 Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Tue, 4 Aug 2020 14:34:09 +0200 Subject: [PATCH] Default empty KB in EL component (#5872) * EL field documentation * documentation consistent with docs * default empty KB, initialize vocab separately * formatting * add test for changing the default entity vector length * update comment --- examples/training/create_kb.py | 6 +- spacy/errors.py | 5 +- spacy/kb.pyx | 24 ++++++- spacy/ml/models/entity_linker.py | 9 ++- spacy/pipeline/entity_linker.py | 38 ++++++----- spacy/tests/pipeline/test_entity_linker.py | 64 ++++++++++++++++--- spacy/tests/regression/test_issue4501-5000.py | 6 +- spacy/tests/regression/test_issue5230.py | 9 ++- spacy/tests/serialize/test_serialize_kb.py | 6 +- 9 files changed, 127 insertions(+), 40 deletions(-) diff --git a/examples/training/create_kb.py b/examples/training/create_kb.py index 5b17bb59e..0c6e29226 100644 --- a/examples/training/create_kb.py +++ b/examples/training/create_kb.py @@ -48,7 +48,8 @@ def main(model, output_dir=None): # You can change the dimension of vectors in your KB by using an encoder that changes the dimensionality. # For simplicity, we'll just use the original vector dimension here instead. vectors_dim = nlp.vocab.vectors.shape[1] - kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=vectors_dim) + kb = KnowledgeBase(entity_vector_length=vectors_dim) + kb.initialize(nlp.vocab) # set up the data entity_ids = [] @@ -95,7 +96,8 @@ def main(model, output_dir=None): print("Loading vocab from", vocab_path) print("Loading KB from", kb_path) vocab2 = Vocab().from_disk(vocab_path) - kb2 = KnowledgeBase(vocab=vocab2) + kb2 = KnowledgeBase(entity_vector_length=1) + kb.initialize(vocab2) kb2.load_bulk(kb_path) print() _print_kb(kb2) diff --git a/spacy/errors.py b/spacy/errors.py index 124572b0b..6e595fe33 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -374,7 +374,8 @@ class Errors: E138 = ("Invalid JSONL format for raw text '{text}'. Make sure the input " "includes either the `text` or `tokens` key. For more info, see " "the docs:\nhttps://spacy.io/api/cli#pretrain-jsonl") - E139 = ("Knowledge Base for component '{name}' is empty.") + E139 = ("Knowledge Base for component '{name}' is empty. Use the methods " + "kb.add_entity and kb.add_alias to add entries.") E140 = ("The list of entities, prior probabilities and entity vectors " "should be of equal length.") E141 = ("Entity vectors should be of length {required} instead of the " @@ -481,6 +482,8 @@ class Errors: E199 = ("Unable to merge 0-length span at doc[{start}:{end}].") # TODO: fix numbering after merging develop into master + E946 = ("The Vocab for the knowledge base is not initialized. Did you forget to " + "call kb.initialize()?") E947 = ("Matcher.add received invalid 'greedy' argument: expected " "a string value from {expected} but got: '{arg}'") E948 = ("Matcher.add received invalid 'patterns' argument: expected " diff --git a/spacy/kb.pyx b/spacy/kb.pyx index 3f226596c..9035f7e6a 100644 --- a/spacy/kb.pyx +++ b/spacy/kb.pyx @@ -71,17 +71,25 @@ cdef class KnowledgeBase: DOCS: https://spacy.io/api/kb """ - def __init__(self, Vocab vocab, entity_vector_length=64): - self.vocab = vocab + def __init__(self, entity_vector_length): + """Create a KnowledgeBase. Make sure to call kb.initialize() before using it.""" self.mem = Pool() self.entity_vector_length = entity_vector_length self._entry_index = PreshMap() self._alias_index = PreshMap() + self.vocab = None + + def initialize(self, Vocab vocab): + self.vocab = vocab self.vocab.strings.add("") self._create_empty_vectors(dummy_hash=self.vocab.strings[""]) + def require_vocab(self): + if self.vocab is None: + raise ValueError(Errors.E946) + @property def entity_vector_length(self): """RETURNS (uint64): length of the entity vectors""" @@ -94,12 +102,14 @@ cdef class KnowledgeBase: return len(self._entry_index) def get_entity_strings(self): + self.require_vocab() return [self.vocab.strings[x] for x in self._entry_index] def get_size_aliases(self): return len(self._alias_index) def get_alias_strings(self): + self.require_vocab() return [self.vocab.strings[x] for x in self._alias_index] def add_entity(self, unicode entity, float freq, vector[float] entity_vector): @@ -107,6 +117,7 @@ cdef class KnowledgeBase: Add an entity to the KB, optionally specifying its log probability based on corpus frequency Return the hash of the entity ID/name at the end. """ + self.require_vocab() cdef hash_t entity_hash = self.vocab.strings.add(entity) # Return if this entity was added before @@ -129,6 +140,7 @@ cdef class KnowledgeBase: return entity_hash cpdef set_entities(self, entity_list, freq_list, vector_list): + self.require_vocab() if len(entity_list) != len(freq_list) or len(entity_list) != len(vector_list): raise ValueError(Errors.E140) @@ -164,10 +176,12 @@ cdef class KnowledgeBase: i += 1 def contains_entity(self, unicode entity): + self.require_vocab() cdef hash_t entity_hash = self.vocab.strings.add(entity) return entity_hash in self._entry_index def contains_alias(self, unicode alias): + self.require_vocab() cdef hash_t alias_hash = self.vocab.strings.add(alias) return alias_hash in self._alias_index @@ -176,6 +190,7 @@ cdef class KnowledgeBase: For a given alias, add its potential entities and prior probabilies to the KB. Return the alias_hash at the end """ + self.require_vocab() # Throw an error if the length of entities and probabilities are not the same if not len(entities) == len(probabilities): raise ValueError(Errors.E132.format(alias=alias, @@ -219,6 +234,7 @@ cdef class KnowledgeBase: Throw an error if this entity+prior prob would exceed the sum of 1. For efficiency, it's best to use the method `add_alias` as much as possible instead of this one. """ + self.require_vocab() # Check if the alias exists in the KB cdef hash_t alias_hash = self.vocab.strings[alias] if not alias_hash in self._alias_index: @@ -265,6 +281,7 @@ cdef class KnowledgeBase: and the prior probability of that alias resolving to that entity. If the alias is not known in the KB, and empty list is returned. """ + self.require_vocab() cdef hash_t alias_hash = self.vocab.strings[alias] if not alias_hash in self._alias_index: return [] @@ -281,6 +298,7 @@ cdef class KnowledgeBase: if entry_index != 0] def get_vector(self, unicode entity): + self.require_vocab() cdef hash_t entity_hash = self.vocab.strings[entity] # Return an empty list if this entity is unknown in this KB @@ -293,6 +311,7 @@ cdef class KnowledgeBase: def get_prior_prob(self, unicode entity, unicode alias): """ Return the prior probability of a given alias being linked to a given entity, or return 0.0 when this combination is not known in the knowledge base""" + self.require_vocab() cdef hash_t alias_hash = self.vocab.strings[alias] cdef hash_t entity_hash = self.vocab.strings[entity] @@ -311,6 +330,7 @@ cdef class KnowledgeBase: def dump(self, loc): + self.require_vocab() cdef Writer writer = Writer(loc) writer.write_header(self.get_size_entities(), self.entity_vector_length) diff --git a/spacy/ml/models/entity_linker.py b/spacy/ml/models/entity_linker.py index f61fe2d5f..f96d50a7b 100644 --- a/spacy/ml/models/entity_linker.py +++ b/spacy/ml/models/entity_linker.py @@ -27,6 +27,13 @@ def build_nel_encoder(tok2vec: Model, nO: Optional[int] = None) -> Model: @registry.assets.register("spacy.KBFromFile.v1") def load_kb(vocab_path: str, kb_path: str) -> KnowledgeBase: vocab = Vocab().from_disk(vocab_path) - kb = KnowledgeBase(vocab=vocab) + kb = KnowledgeBase(entity_vector_length=1) + kb.initialize(vocab) kb.load_bulk(kb_path) return kb + + +@registry.assets.register("spacy.EmptyKB.v1") +def empty_kb(entity_vector_length: int) -> KnowledgeBase: + kb = KnowledgeBase(entity_vector_length=entity_vector_length) + return kb diff --git a/spacy/pipeline/entity_linker.py b/spacy/pipeline/entity_linker.py index 742b349e5..923d925dc 100644 --- a/spacy/pipeline/entity_linker.py +++ b/spacy/pipeline/entity_linker.py @@ -33,24 +33,31 @@ dropout = null """ DEFAULT_NEL_MODEL = Config().from_str(default_model_config)["model"] +default_kb_config = """ +[kb] +@assets = "spacy.EmptyKB.v1" +entity_vector_length = 64 +""" +DEFAULT_NEL_KB = Config().from_str(default_kb_config)["kb"] + @Language.factory( "entity_linker", requires=["doc.ents", "doc.sents", "token.ent_iob", "token.ent_type"], assigns=["token.ent_kb_id"], default_config={ - "kb": None, # TODO - what kind of default makes sense here? + "kb": DEFAULT_NEL_KB, + "model": DEFAULT_NEL_MODEL, "labels_discard": [], "incl_prior": True, "incl_context": True, - "model": DEFAULT_NEL_MODEL, }, ) def make_entity_linker( nlp: Language, name: str, model: Model, - kb: Optional[KnowledgeBase], + kb: KnowledgeBase, *, labels_discard: Iterable[str], incl_prior: bool, @@ -92,10 +99,10 @@ class EntityLinker(Pipe): model (thinc.api.Model): The Thinc Model powering the pipeline component. name (str): The component instance name, used to add entries to the losses during training. - kb (KnowledgeBase): TODO: - labels_discard (Iterable[str]): TODO: - incl_prior (bool): TODO: - incl_context (bool): TODO: + kb (KnowledgeBase): The KnowledgeBase holding all entities and their aliases. + labels_discard (Iterable[str]): NER labels that will automatically get a "NIL" prediction. + incl_prior (bool): Whether or not to include prior probabilities from the KB in the model. + incl_context (bool): Whether or not to include the local context in the model. DOCS: https://spacy.io/api/entitylinker#init """ @@ -108,14 +115,12 @@ class EntityLinker(Pipe): "incl_prior": incl_prior, "incl_context": incl_context, } - self.kb = kb - if self.kb is None: - # create an empty KB that should be filled by calling from_disk - self.kb = KnowledgeBase(vocab=vocab) - else: - del cfg["kb"] # we don't want to duplicate its serialization - if not isinstance(self.kb, KnowledgeBase): + if not isinstance(kb, KnowledgeBase): raise ValueError(Errors.E990.format(type=type(self.kb))) + kb.initialize(vocab) + self.kb = kb + if "kb" in cfg: + del cfg["kb"] # we don't want to duplicate its serialization self.cfg = dict(cfg) self.distance = CosineDistance(normalize=False) # how many neightbour sentences to take into account @@ -437,9 +442,8 @@ class EntityLinker(Pipe): raise ValueError(Errors.E149) def load_kb(p): - self.kb = KnowledgeBase( - vocab=self.vocab, entity_vector_length=self.cfg["entity_width"] - ) + self.kb = KnowledgeBase(entity_vector_length=self.cfg["entity_width"]) + self.kb.initialize(self.vocab) self.kb.load_bulk(p) deserialize = {} diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py index 4002eafe3..236d0e0d5 100644 --- a/spacy/tests/pipeline/test_entity_linker.py +++ b/spacy/tests/pipeline/test_entity_linker.py @@ -21,7 +21,8 @@ def assert_almost_equal(a, b): def test_kb_valid_entities(nlp): """Test the valid construction of a KB with 3 entities and two aliases""" - mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3) + mykb = KnowledgeBase(entity_vector_length=3) + mykb.initialize(nlp.vocab) # adding entities mykb.add_entity(entity="Q1", freq=19, entity_vector=[8, 4, 3]) @@ -50,7 +51,8 @@ def test_kb_valid_entities(nlp): def test_kb_invalid_entities(nlp): """Test the invalid construction of a KB with an alias linked to a non-existing entity""" - mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) + mykb = KnowledgeBase(entity_vector_length=1) + mykb.initialize(nlp.vocab) # adding entities mykb.add_entity(entity="Q1", freq=19, entity_vector=[1]) @@ -66,7 +68,8 @@ def test_kb_invalid_entities(nlp): def test_kb_invalid_probabilities(nlp): """Test the invalid construction of a KB with wrong prior probabilities""" - mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) + mykb = KnowledgeBase(entity_vector_length=1) + mykb.initialize(nlp.vocab) # adding entities mykb.add_entity(entity="Q1", freq=19, entity_vector=[1]) @@ -80,7 +83,8 @@ def test_kb_invalid_probabilities(nlp): def test_kb_invalid_combination(nlp): """Test the invalid construction of a KB with non-matching entity and probability lists""" - mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) + mykb = KnowledgeBase(entity_vector_length=1) + mykb.initialize(nlp.vocab) # adding entities mykb.add_entity(entity="Q1", freq=19, entity_vector=[1]) @@ -96,7 +100,8 @@ def test_kb_invalid_combination(nlp): def test_kb_invalid_entity_vector(nlp): """Test the invalid construction of a KB with non-matching entity vector lengths""" - mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3) + mykb = KnowledgeBase(entity_vector_length=3) + mykb.initialize(nlp.vocab) # adding entities mykb.add_entity(entity="Q1", freq=19, entity_vector=[1, 2, 3]) @@ -106,9 +111,44 @@ def test_kb_invalid_entity_vector(nlp): mykb.add_entity(entity="Q2", freq=5, entity_vector=[2]) +def test_kb_default(nlp): + """Test that the default (empty) KB is loaded when not providing a config""" + entity_linker = nlp.add_pipe("entity_linker", config={}) + assert len(entity_linker.kb) == 0 + assert entity_linker.kb.get_size_entities() == 0 + assert entity_linker.kb.get_size_aliases() == 0 + assert entity_linker.kb.entity_vector_length == 64 # default value from pipeline.entity_linker + + +def test_kb_custom_length(nlp): + """Test that the default (empty) KB can be configured with a custom entity length""" + entity_linker = nlp.add_pipe("entity_linker", config={"kb": {"entity_vector_length": 35}}) + assert len(entity_linker.kb) == 0 + assert entity_linker.kb.get_size_entities() == 0 + assert entity_linker.kb.get_size_aliases() == 0 + assert entity_linker.kb.entity_vector_length == 35 + + +def test_kb_undefined(nlp): + """Test that the EL can't train without defining a KB""" + entity_linker = nlp.add_pipe("entity_linker", config={}) + with pytest.raises(ValueError): + entity_linker.begin_training() + + +def test_kb_empty(nlp): + """Test that the EL can't train with an empty KB""" + config = {"kb": {"@assets": "spacy.EmptyKB.v1", "entity_vector_length": 342}} + entity_linker = nlp.add_pipe("entity_linker", config=config) + assert len(entity_linker.kb) == 0 + with pytest.raises(ValueError): + entity_linker.begin_training() + + def test_candidate_generation(nlp): """Test correct candidate generation""" - mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) + mykb = KnowledgeBase(entity_vector_length=1) + mykb.initialize(nlp.vocab) # adding entities mykb.add_entity(entity="Q1", freq=27, entity_vector=[1]) @@ -133,7 +173,8 @@ def test_candidate_generation(nlp): def test_append_alias(nlp): """Test that we can append additional alias-entity pairs""" - mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) + mykb = KnowledgeBase(entity_vector_length=1) + mykb.initialize(nlp.vocab) # adding entities mykb.add_entity(entity="Q1", freq=27, entity_vector=[1]) @@ -163,7 +204,8 @@ def test_append_alias(nlp): def test_append_invalid_alias(nlp): """Test that append an alias will throw an error if prior probs are exceeding 1""" - mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) + mykb = KnowledgeBase(entity_vector_length=1) + mykb.initialize(nlp.vocab) # adding entities mykb.add_entity(entity="Q1", freq=27, entity_vector=[1]) @@ -184,7 +226,8 @@ def test_preserving_links_asdoc(nlp): @registry.assets.register("myLocationsKB.v1") def dummy_kb() -> KnowledgeBase: - mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) + mykb = KnowledgeBase(entity_vector_length=1) + mykb.initialize(nlp.vocab) # adding entities mykb.add_entity(entity="Q1", freq=19, entity_vector=[1]) mykb.add_entity(entity="Q2", freq=8, entity_vector=[1]) @@ -289,7 +332,8 @@ def test_overfitting_IO(): # create artificial KB - assign same prior weight to the two russ cochran's # Q2146908 (Russ Cochran): American golfer # Q7381115 (Russ Cochran): publisher - mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3) + mykb = KnowledgeBase(entity_vector_length=3) + mykb.initialize(nlp.vocab) mykb.add_entity(entity="Q2146908", freq=12, entity_vector=[6, -4, 3]) mykb.add_entity(entity="Q7381115", freq=12, entity_vector=[9, 1, -7]) mykb.add_alias( diff --git a/spacy/tests/regression/test_issue4501-5000.py b/spacy/tests/regression/test_issue4501-5000.py index 08a21e690..0b3b4a9fc 100644 --- a/spacy/tests/regression/test_issue4501-5000.py +++ b/spacy/tests/regression/test_issue4501-5000.py @@ -139,7 +139,8 @@ def test_issue4665(): def test_issue4674(): """Test that setting entities with overlapping identifiers does not mess up IO""" nlp = English() - kb = KnowledgeBase(nlp.vocab, entity_vector_length=3) + kb = KnowledgeBase(entity_vector_length=3) + kb.initialize(nlp.vocab) vector1 = [0.9, 1.1, 1.01] vector2 = [1.8, 2.25, 2.01] with pytest.warns(UserWarning): @@ -156,7 +157,8 @@ def test_issue4674(): dir_path.mkdir() file_path = dir_path / "kb" kb.dump(str(file_path)) - kb2 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=3) + kb2 = KnowledgeBase(entity_vector_length=3) + kb2.initialize(nlp.vocab) kb2.load_bulk(str(file_path)) assert kb2.get_size_entities() == 1 diff --git a/spacy/tests/regression/test_issue5230.py b/spacy/tests/regression/test_issue5230.py index ae9ed1844..31292b700 100644 --- a/spacy/tests/regression/test_issue5230.py +++ b/spacy/tests/regression/test_issue5230.py @@ -72,7 +72,8 @@ def entity_linker(): @registry.assets.register("TestIssue5230KB.v1") def dummy_kb() -> KnowledgeBase: - kb = KnowledgeBase(nlp.vocab, entity_vector_length=1) + kb = KnowledgeBase(entity_vector_length=1) + kb.initialize(nlp.vocab) kb.add_entity("test", 0.0, zeros((1, 1), dtype="f")) return kb @@ -121,7 +122,8 @@ def test_writer_with_path_py35(): def test_save_and_load_knowledge_base(): nlp = Language() - kb = KnowledgeBase(nlp.vocab, entity_vector_length=1) + kb = KnowledgeBase(entity_vector_length=1) + kb.initialize(nlp.vocab) with make_tempdir() as d: path = d / "kb" try: @@ -130,7 +132,8 @@ def test_save_and_load_knowledge_base(): pytest.fail(str(e)) try: - kb_loaded = KnowledgeBase(nlp.vocab, entity_vector_length=1) + kb_loaded = KnowledgeBase(entity_vector_length=1) + kb_loaded.initialize(nlp.vocab) kb_loaded.load_bulk(path) except Exception as e: pytest.fail(str(e)) diff --git a/spacy/tests/serialize/test_serialize_kb.py b/spacy/tests/serialize/test_serialize_kb.py index 91036a496..3f33c6f06 100644 --- a/spacy/tests/serialize/test_serialize_kb.py +++ b/spacy/tests/serialize/test_serialize_kb.py @@ -17,7 +17,8 @@ def test_serialize_kb_disk(en_vocab): file_path = dir_path / "kb" kb1.dump(str(file_path)) - kb2 = KnowledgeBase(vocab=en_vocab, entity_vector_length=3) + kb2 = KnowledgeBase(entity_vector_length=3) + kb2.initialize(en_vocab) kb2.load_bulk(str(file_path)) # final assertions @@ -25,7 +26,8 @@ def test_serialize_kb_disk(en_vocab): def _get_dummy_kb(vocab): - kb = KnowledgeBase(vocab=vocab, entity_vector_length=3) + kb = KnowledgeBase(entity_vector_length=3) + kb.initialize(vocab) kb.add_entity(entity="Q53", freq=33, entity_vector=[0, 5, 3]) kb.add_entity(entity="Q17", freq=2, entity_vector=[7, 1, 0])