mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 09:14:32 +03:00
Fix kb.set_entities (#9463)
* avoid creating _vectors_table when also using c_add_vector * write to self._vectors_table directly in set_entities
This commit is contained in:
parent
068cae7755
commit
da578c3d3b
|
@ -96,6 +96,8 @@ cdef class KnowledgeBase:
|
|||
def initialize_entities(self, int64_t nr_entities):
|
||||
self._entry_index = PreshMap(nr_entities + 1)
|
||||
self._entries = entry_vec(nr_entities + 1)
|
||||
|
||||
def initialize_vectors(self, int64_t nr_entities):
|
||||
self._vectors_table = float_matrix(nr_entities + 1)
|
||||
|
||||
def initialize_aliases(self, int64_t nr_aliases):
|
||||
|
@ -154,6 +156,7 @@ cdef class KnowledgeBase:
|
|||
|
||||
nr_entities = len(set(entity_list))
|
||||
self.initialize_entities(nr_entities)
|
||||
self.initialize_vectors(nr_entities)
|
||||
|
||||
i = 0
|
||||
cdef KBEntryC entry
|
||||
|
@ -172,8 +175,8 @@ cdef class KnowledgeBase:
|
|||
entry.entity_hash = entity_hash
|
||||
entry.freq = freq_list[i]
|
||||
|
||||
vector_index = self.c_add_vector(entity_vector=vector_list[i])
|
||||
entry.vector_index = vector_index
|
||||
self._vectors_table[i] = entity_vector
|
||||
entry.vector_index = i
|
||||
|
||||
entry.feats_row = -1 # Features table currently not implemented
|
||||
|
||||
|
@ -386,6 +389,7 @@ cdef class KnowledgeBase:
|
|||
nr_aliases = header[1]
|
||||
entity_vector_length = header[2]
|
||||
self.initialize_entities(nr_entities)
|
||||
self.initialize_vectors(nr_entities)
|
||||
self.initialize_aliases(nr_aliases)
|
||||
self.entity_vector_length = entity_vector_length
|
||||
|
||||
|
@ -509,6 +513,7 @@ cdef class KnowledgeBase:
|
|||
reader.read_header(&nr_entities, &entity_vector_length)
|
||||
|
||||
self.initialize_entities(nr_entities)
|
||||
self.initialize_vectors(nr_entities)
|
||||
self.entity_vector_length = entity_vector_length
|
||||
|
||||
# STEP 1: load entity vectors
|
||||
|
|
|
@ -154,6 +154,40 @@ def test_kb_serialize(nlp):
|
|||
mykb.from_disk(d / "unknown" / "kb")
|
||||
|
||||
|
||||
@pytest.mark.issue(9137)
|
||||
def test_kb_serialize_2(nlp):
|
||||
v = [5, 6, 7, 8]
|
||||
kb1 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=4)
|
||||
kb1.set_entities(["E1"], [1], [v])
|
||||
assert kb1.get_vector("E1") == v
|
||||
with make_tempdir() as d:
|
||||
kb1.to_disk(d / "kb")
|
||||
kb2 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=4)
|
||||
kb2.from_disk(d / "kb")
|
||||
assert kb2.get_vector("E1") == v
|
||||
|
||||
|
||||
def test_kb_set_entities(nlp):
|
||||
""" Test that set_entities entirely overwrites the previous set of entities """
|
||||
v = [5, 6, 7, 8]
|
||||
v1 = [1, 1, 1, 0]
|
||||
v2 = [2, 2, 2, 3]
|
||||
kb1 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=4)
|
||||
kb1.set_entities(["E0"], [1], [v])
|
||||
assert kb1.get_entity_strings() == ["E0"]
|
||||
kb1.set_entities(["E1", "E2"], [1, 9], [v1, v2])
|
||||
assert set(kb1.get_entity_strings()) == {"E1", "E2"}
|
||||
assert kb1.get_vector("E1") == v1
|
||||
assert kb1.get_vector("E2") == v2
|
||||
with make_tempdir() as d:
|
||||
kb1.to_disk(d / "kb")
|
||||
kb2 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=4)
|
||||
kb2.from_disk(d / "kb")
|
||||
assert set(kb2.get_entity_strings()) == {"E1", "E2"}
|
||||
assert kb2.get_vector("E1") == v1
|
||||
assert kb2.get_vector("E2") == v2
|
||||
|
||||
|
||||
def test_kb_serialize_vocab(nlp):
|
||||
"""Test serialization of the KB and custom strings"""
|
||||
entity = "MyFunnyID"
|
||||
|
|
Loading…
Reference in New Issue
Block a user