# Patch summary (commentary only; patch tools skip text before the first
# "diff --git" header).
#
# spacy/kb.pyx
#   * initialize_entities() keeps allocating the entry index and the entry
#     vector, but no longer allocates self._vectors_table. That allocation
#     moves into the new initialize_vectors() method.
#   * At the three call sites shown below, initialize_vectors(nr_entities) is
#     now called immediately after initialize_entities(nr_entities), so the
#     vector table is re-created together with the entry structures.
#     NOTE(review): the enclosing method names of these hunks are outside
#     the diff context -- confirm which public methods they belong to.
#   * In the entity-filling loop, the vector is written to row i of
#     self._vectors_table and entry.vector_index is set to i, replacing the
#     index previously returned by self.c_add_vector(...).
#
# spacy/tests/pipeline/test_entity_linker.py
#   * test_kb_serialize_2 (marked as issue 9137): sets one entity, checks its
#     vector, writes the KB to disk, loads it into a second KB and checks the
#     vector again.
#   * test_kb_set_entities: calls set_entities() twice and checks that the
#     entity strings and vectors are those of the second call, both in memory
#     and after a to_disk()/from_disk() round trip.
diff --git a/spacy/kb.pyx b/spacy/kb.pyx
index 16d63a4d3..fed3009da 100644
--- a/spacy/kb.pyx
+++ b/spacy/kb.pyx
@@ -96,6 +96,8 @@ cdef class KnowledgeBase:
     def initialize_entities(self, int64_t nr_entities):
         self._entry_index = PreshMap(nr_entities + 1)
         self._entries = entry_vec(nr_entities + 1)
+
+    def initialize_vectors(self, int64_t nr_entities):
         self._vectors_table = float_matrix(nr_entities + 1)
 
     def initialize_aliases(self, int64_t nr_aliases):
@@ -154,6 +156,7 @@ cdef class KnowledgeBase:
 
         nr_entities = len(set(entity_list))
         self.initialize_entities(nr_entities)
+        self.initialize_vectors(nr_entities)
 
         i = 0
         cdef KBEntryC entry
@@ -172,8 +175,8 @@ cdef class KnowledgeBase:
                 entry.entity_hash = entity_hash
                 entry.freq = freq_list[i]
 
-                vector_index = self.c_add_vector(entity_vector=vector_list[i])
-                entry.vector_index = vector_index
+                self._vectors_table[i] = entity_vector
+                entry.vector_index = i
 
                 entry.feats_row = -1   # Features table currently not implemented
 
@@ -386,6 +389,7 @@ cdef class KnowledgeBase:
         nr_aliases = header[1]
         entity_vector_length = header[2]
         self.initialize_entities(nr_entities)
+        self.initialize_vectors(nr_entities)
         self.initialize_aliases(nr_aliases)
         self.entity_vector_length = entity_vector_length
 
@@ -509,6 +513,7 @@ cdef class KnowledgeBase:
         reader.read_header(&nr_entities, &entity_vector_length)
 
         self.initialize_entities(nr_entities)
+        self.initialize_vectors(nr_entities)
         self.entity_vector_length = entity_vector_length
 
         # STEP 1: load entity vectors
diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py
index b97795344..247443489 100644
--- a/spacy/tests/pipeline/test_entity_linker.py
+++ b/spacy/tests/pipeline/test_entity_linker.py
@@ -154,6 +154,40 @@ def test_kb_serialize(nlp):
         mykb.from_disk(d / "unknown" / "kb")
 
 
+@pytest.mark.issue(9137)
+def test_kb_serialize_2(nlp):
+    v = [5, 6, 7, 8]
+    kb1 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=4)
+    kb1.set_entities(["E1"], [1], [v])
+    assert kb1.get_vector("E1") == v
+    with make_tempdir() as d:
+        kb1.to_disk(d / "kb")
+        kb2 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=4)
+        kb2.from_disk(d / "kb")
+        assert kb2.get_vector("E1") == v
+
+
+def test_kb_set_entities(nlp):
+    """ Test that set_entities entirely overwrites the previous set of entities """
+    v = [5, 6, 7, 8]
+    v1 = [1, 1, 1, 0]
+    v2 = [2, 2, 2, 3]
+    kb1 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=4)
+    kb1.set_entities(["E0"], [1], [v])
+    assert kb1.get_entity_strings() == ["E0"]
+    kb1.set_entities(["E1", "E2"], [1, 9], [v1, v2])
+    assert set(kb1.get_entity_strings()) == {"E1", "E2"}
+    assert kb1.get_vector("E1") == v1
+    assert kb1.get_vector("E2") == v2
+    with make_tempdir() as d:
+        kb1.to_disk(d / "kb")
+        kb2 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=4)
+        kb2.from_disk(d / "kb")
+        assert set(kb2.get_entity_strings()) == {"E1", "E2"}
+        assert kb2.get_vector("E1") == v1
+        assert kb2.get_vector("E2") == v2
+
+
 def test_kb_serialize_vocab(nlp):
     """Test serialization of the KB and custom strings"""
     entity = "MyFunnyID"