mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
Fix kb.set_entities (#9463)
* avoid creating _vectors_table when also using c_add_vector * write to self._vectors_table directly in set_entities
This commit is contained in:
parent
068cae7755
commit
da578c3d3b
|
@ -96,6 +96,8 @@ cdef class KnowledgeBase:
|
||||||
def initialize_entities(self, int64_t nr_entities):
|
def initialize_entities(self, int64_t nr_entities):
|
||||||
self._entry_index = PreshMap(nr_entities + 1)
|
self._entry_index = PreshMap(nr_entities + 1)
|
||||||
self._entries = entry_vec(nr_entities + 1)
|
self._entries = entry_vec(nr_entities + 1)
|
||||||
|
|
||||||
|
def initialize_vectors(self, int64_t nr_entities):
|
||||||
self._vectors_table = float_matrix(nr_entities + 1)
|
self._vectors_table = float_matrix(nr_entities + 1)
|
||||||
|
|
||||||
def initialize_aliases(self, int64_t nr_aliases):
|
def initialize_aliases(self, int64_t nr_aliases):
|
||||||
|
@ -154,6 +156,7 @@ cdef class KnowledgeBase:
|
||||||
|
|
||||||
nr_entities = len(set(entity_list))
|
nr_entities = len(set(entity_list))
|
||||||
self.initialize_entities(nr_entities)
|
self.initialize_entities(nr_entities)
|
||||||
|
self.initialize_vectors(nr_entities)
|
||||||
|
|
||||||
i = 0
|
i = 0
|
||||||
cdef KBEntryC entry
|
cdef KBEntryC entry
|
||||||
|
@ -172,8 +175,8 @@ cdef class KnowledgeBase:
|
||||||
entry.entity_hash = entity_hash
|
entry.entity_hash = entity_hash
|
||||||
entry.freq = freq_list[i]
|
entry.freq = freq_list[i]
|
||||||
|
|
||||||
vector_index = self.c_add_vector(entity_vector=vector_list[i])
|
self._vectors_table[i] = entity_vector
|
||||||
entry.vector_index = vector_index
|
entry.vector_index = i
|
||||||
|
|
||||||
entry.feats_row = -1 # Features table currently not implemented
|
entry.feats_row = -1 # Features table currently not implemented
|
||||||
|
|
||||||
|
@ -386,6 +389,7 @@ cdef class KnowledgeBase:
|
||||||
nr_aliases = header[1]
|
nr_aliases = header[1]
|
||||||
entity_vector_length = header[2]
|
entity_vector_length = header[2]
|
||||||
self.initialize_entities(nr_entities)
|
self.initialize_entities(nr_entities)
|
||||||
|
self.initialize_vectors(nr_entities)
|
||||||
self.initialize_aliases(nr_aliases)
|
self.initialize_aliases(nr_aliases)
|
||||||
self.entity_vector_length = entity_vector_length
|
self.entity_vector_length = entity_vector_length
|
||||||
|
|
||||||
|
@ -509,6 +513,7 @@ cdef class KnowledgeBase:
|
||||||
reader.read_header(&nr_entities, &entity_vector_length)
|
reader.read_header(&nr_entities, &entity_vector_length)
|
||||||
|
|
||||||
self.initialize_entities(nr_entities)
|
self.initialize_entities(nr_entities)
|
||||||
|
self.initialize_vectors(nr_entities)
|
||||||
self.entity_vector_length = entity_vector_length
|
self.entity_vector_length = entity_vector_length
|
||||||
|
|
||||||
# STEP 1: load entity vectors
|
# STEP 1: load entity vectors
|
||||||
|
|
|
@ -154,6 +154,40 @@ def test_kb_serialize(nlp):
|
||||||
mykb.from_disk(d / "unknown" / "kb")
|
mykb.from_disk(d / "unknown" / "kb")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.issue(9137)
|
||||||
|
def test_kb_serialize_2(nlp):
|
||||||
|
v = [5, 6, 7, 8]
|
||||||
|
kb1 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=4)
|
||||||
|
kb1.set_entities(["E1"], [1], [v])
|
||||||
|
assert kb1.get_vector("E1") == v
|
||||||
|
with make_tempdir() as d:
|
||||||
|
kb1.to_disk(d / "kb")
|
||||||
|
kb2 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=4)
|
||||||
|
kb2.from_disk(d / "kb")
|
||||||
|
assert kb2.get_vector("E1") == v
|
||||||
|
|
||||||
|
|
||||||
|
def test_kb_set_entities(nlp):
|
||||||
|
""" Test that set_entities entirely overwrites the previous set of entities """
|
||||||
|
v = [5, 6, 7, 8]
|
||||||
|
v1 = [1, 1, 1, 0]
|
||||||
|
v2 = [2, 2, 2, 3]
|
||||||
|
kb1 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=4)
|
||||||
|
kb1.set_entities(["E0"], [1], [v])
|
||||||
|
assert kb1.get_entity_strings() == ["E0"]
|
||||||
|
kb1.set_entities(["E1", "E2"], [1, 9], [v1, v2])
|
||||||
|
assert set(kb1.get_entity_strings()) == {"E1", "E2"}
|
||||||
|
assert kb1.get_vector("E1") == v1
|
||||||
|
assert kb1.get_vector("E2") == v2
|
||||||
|
with make_tempdir() as d:
|
||||||
|
kb1.to_disk(d / "kb")
|
||||||
|
kb2 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=4)
|
||||||
|
kb2.from_disk(d / "kb")
|
||||||
|
assert set(kb2.get_entity_strings()) == {"E1", "E2"}
|
||||||
|
assert kb2.get_vector("E1") == v1
|
||||||
|
assert kb2.get_vector("E2") == v2
|
||||||
|
|
||||||
|
|
||||||
def test_kb_serialize_vocab(nlp):
|
def test_kb_serialize_vocab(nlp):
|
||||||
"""Test serialization of the KB and custom strings"""
|
"""Test serialization of the KB and custom strings"""
|
||||||
entity = "MyFunnyID"
|
entity = "MyFunnyID"
|
||||||
|
|
Loading…
Reference in New Issue
Block a user