More robust set entities method in KB (#4794)

* add unit test for setting entities with duplicate identifiers

* count the number of actual unique identifiers and throw duplicate warning
This commit is contained in:
Sofie Van Landeghem 2019-12-13 10:45:29 +01:00 committed by Matthew Honnibal
parent a067ded495
commit f9b541f9ef
3 changed files with 54 additions and 14 deletions

View File

@ -81,7 +81,8 @@ class Warnings(object):
"Future versions may introduce a `n_process` argument for "
"parallel inference via multiprocessing.")
W017 = ("Alias '{alias}' already exists in the Knowledge Base.")
W018 = ("Entity '{entity}' already exists in the Knowledge Base.")
W018 = ("Entity '{entity}' already exists in the Knowledge Base - "
"ignoring the duplicate entry.")
W019 = ("Changing vectors name from {old} to {new}, to avoid clash with "
"previously loaded vectors. See Issue #3853.")
W020 = ("Unnamed vectors. This won't allow multiple vectors models to be "

View File

@ -136,29 +136,34 @@ cdef class KnowledgeBase:
if len(entity_list) != len(freq_list) or len(entity_list) != len(vector_list):
raise ValueError(Errors.E140)
nr_entities = len(entity_list)
nr_entities = len(set(entity_list))
self._entry_index = PreshMap(nr_entities+1)
self._entries = entry_vec(nr_entities+1)
i = 0
cdef KBEntryC entry
cdef hash_t entity_hash
while i < nr_entities:
entity_vector = vector_list[i]
if len(entity_vector) != self.entity_vector_length:
raise ValueError(Errors.E141.format(found=len(entity_vector), required=self.entity_vector_length))
while i < len(entity_list):
# only process this entity if its unique ID hadn't been added before
entity_hash = self.vocab.strings.add(entity_list[i])
entry.entity_hash = entity_hash
entry.freq = freq_list[i]
if entity_hash in self._entry_index:
user_warning(Warnings.W018.format(entity=entity_list[i]))
vector_index = self.c_add_vector(entity_vector=vector_list[i])
entry.vector_index = vector_index
else:
entity_vector = vector_list[i]
if len(entity_vector) != self.entity_vector_length:
raise ValueError(Errors.E141.format(found=len(entity_vector), required=self.entity_vector_length))
entry.feats_row = -1 # Features table currently not implemented
entry.entity_hash = entity_hash
entry.freq = freq_list[i]
self._entries[i+1] = entry
self._entry_index[entity_hash] = i+1
vector_index = self.c_add_vector(entity_vector=vector_list[i])
entry.vector_index = vector_index
entry.feats_row = -1 # Features table currently not implemented
self._entries[i+1] = entry
self._entry_index[entity_hash] = i+1
i += 1

View File

@ -0,0 +1,34 @@
# coding: utf-8
from __future__ import unicode_literals
from spacy.kb import KnowledgeBase
from spacy.util import ensure_path
from spacy.lang.en import English
from spacy.tests.util import make_tempdir
def test_issue4674():
"""Test that setting entities with overlapping identifiers does not mess up IO"""
nlp = English()
kb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
vector1 = [0.9, 1.1, 1.01]
vector2 = [1.8, 2.25, 2.01]
kb.set_entities(entity_list=["Q1", "Q1"], freq_list=[32, 111], vector_list=[vector1, vector2])
assert kb.get_size_entities() == 1
# dumping to file & loading back in
with make_tempdir() as d:
dir_path = ensure_path(d)
if not dir_path.exists():
dir_path.mkdir()
file_path = dir_path / "kb"
kb.dump(str(file_path))
kb2 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=3)
kb2.load_bulk(str(file_path))
assert kb2.get_size_entities() == 1