Fix kb.set_entities (#9463)

* avoid creating _vectors_table when also using c_add_vector

* write to self._vectors_table directly in set_entities
Sofie Van Landeghem 2021-10-19 09:39:17 +02:00 committed by GitHub
parent 068cae7755
commit da578c3d3b
2 changed files with 41 additions and 2 deletions
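
The two bullet points above summarize the change: the vector table is now sized by a separate initialize_vectors step, and set_entities writes each vector straight into self._vectors_table[i] (with entry.vector_index = i) instead of appending through c_add_vector, so the pre-sized table is filled in place rather than growing past its pre-allocated rows. Below is a minimal Python sketch of the behaviour the new tests lock in, mirroring test_kb_serialize_2 further down; it is a hypothetical standalone script that assumes a spaCy install containing this commit, a blank English pipeline, and plain tempfile handling in place of the test suite's make_tempdir helper:

from pathlib import Path
import tempfile

import spacy
from spacy.kb import KnowledgeBase

nlp = spacy.blank("en")

# Build a KB through set_entities only (the code path fixed by this commit).
kb1 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=4)
kb1.set_entities(["E1"], [1], [[5, 6, 7, 8]])
assert kb1.get_vector("E1") == [5, 6, 7, 8]

# Round-trip through to_disk/from_disk (the scenario covered by the new issue-9137 test).
with tempfile.TemporaryDirectory() as tmp_dir:
    kb_path = Path(tmp_dir) / "kb"
    kb1.to_disk(kb_path)
    kb2 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=4)
    kb2.from_disk(kb_path)
    assert kb2.get_vector("E1") == [5, 6, 7, 8]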


@@ -96,6 +96,8 @@ cdef class KnowledgeBase:
     def initialize_entities(self, int64_t nr_entities):
         self._entry_index = PreshMap(nr_entities + 1)
         self._entries = entry_vec(nr_entities + 1)
+
+    def initialize_vectors(self, int64_t nr_entities):
         self._vectors_table = float_matrix(nr_entities + 1)
 
     def initialize_aliases(self, int64_t nr_aliases):
@@ -154,6 +156,7 @@ cdef class KnowledgeBase:
 
         nr_entities = len(set(entity_list))
         self.initialize_entities(nr_entities)
+        self.initialize_vectors(nr_entities)
 
         i = 0
         cdef KBEntryC entry
@@ -172,8 +175,8 @@ cdef class KnowledgeBase:
             entry.entity_hash = entity_hash
             entry.freq = freq_list[i]
 
-            vector_index = self.c_add_vector(entity_vector=vector_list[i])
-            entry.vector_index = vector_index
+            self._vectors_table[i] = entity_vector
+            entry.vector_index = i
 
             entry.feats_row = -1  # Features table currently not implemented
@@ -386,6 +389,7 @@ cdef class KnowledgeBase:
             nr_aliases = header[1]
             entity_vector_length = header[2]
             self.initialize_entities(nr_entities)
+            self.initialize_vectors(nr_entities)
             self.initialize_aliases(nr_aliases)
             self.entity_vector_length = entity_vector_length
@@ -509,6 +513,7 @@ cdef class KnowledgeBase:
         reader.read_header(&nr_entities, &entity_vector_length)
 
         self.initialize_entities(nr_entities)
+        self.initialize_vectors(nr_entities)
         self.entity_vector_length = entity_vector_length
 
         # STEP 1: load entity vectors


@@ -154,6 +154,40 @@ def test_kb_serialize(nlp):
             mykb.from_disk(d / "unknown" / "kb")
 
 
+@pytest.mark.issue(9137)
+def test_kb_serialize_2(nlp):
+    v = [5, 6, 7, 8]
+    kb1 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=4)
+    kb1.set_entities(["E1"], [1], [v])
+    assert kb1.get_vector("E1") == v
+    with make_tempdir() as d:
+        kb1.to_disk(d / "kb")
+        kb2 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=4)
+        kb2.from_disk(d / "kb")
+        assert kb2.get_vector("E1") == v
+
+
+def test_kb_set_entities(nlp):
+    """ Test that set_entities entirely overwrites the previous set of entities """
+    v = [5, 6, 7, 8]
+    v1 = [1, 1, 1, 0]
+    v2 = [2, 2, 2, 3]
+    kb1 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=4)
+    kb1.set_entities(["E0"], [1], [v])
+    assert kb1.get_entity_strings() == ["E0"]
+    kb1.set_entities(["E1", "E2"], [1, 9], [v1, v2])
+    assert set(kb1.get_entity_strings()) == {"E1", "E2"}
+    assert kb1.get_vector("E1") == v1
+    assert kb1.get_vector("E2") == v2
+    with make_tempdir() as d:
+        kb1.to_disk(d / "kb")
+        kb2 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=4)
+        kb2.from_disk(d / "kb")
+        assert set(kb2.get_entity_strings()) == {"E1", "E2"}
+        assert kb2.get_vector("E1") == v1
+        assert kb2.get_vector("E2") == v2
+
+
 def test_kb_serialize_vocab(nlp):
     """Test serialization of the KB and custom strings"""
     entity = "MyFunnyID"