mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
More robust set entities method in KB (#4794)
* add unit test for setting entities with duplicate identifiers
* count the number of actual unique identifiers and throw duplicate warning
This commit is contained in:
parent
a067ded495
commit
f9b541f9ef
|
@ -81,7 +81,8 @@ class Warnings(object):
|
|||
"Future versions may introduce a `n_process` argument for "
|
||||
"parallel inference via multiprocessing.")
|
||||
W017 = ("Alias '{alias}' already exists in the Knowledge Base.")
|
||||
W018 = ("Entity '{entity}' already exists in the Knowledge Base.")
|
||||
W018 = ("Entity '{entity}' already exists in the Knowledge Base - "
|
||||
"ignoring the duplicate entry.")
|
||||
W019 = ("Changing vectors name from {old} to {new}, to avoid clash with "
|
||||
"previously loaded vectors. See Issue #3853.")
|
||||
W020 = ("Unnamed vectors. This won't allow multiple vectors models to be "
|
||||
|
|
31
spacy/kb.pyx
31
spacy/kb.pyx
|
@ -136,29 +136,34 @@ cdef class KnowledgeBase:
|
|||
if len(entity_list) != len(freq_list) or len(entity_list) != len(vector_list):
|
||||
raise ValueError(Errors.E140)
|
||||
|
||||
nr_entities = len(entity_list)
|
||||
nr_entities = len(set(entity_list))
|
||||
self._entry_index = PreshMap(nr_entities+1)
|
||||
self._entries = entry_vec(nr_entities+1)
|
||||
|
||||
i = 0
|
||||
cdef KBEntryC entry
|
||||
cdef hash_t entity_hash
|
||||
while i < nr_entities:
|
||||
entity_vector = vector_list[i]
|
||||
if len(entity_vector) != self.entity_vector_length:
|
||||
raise ValueError(Errors.E141.format(found=len(entity_vector), required=self.entity_vector_length))
|
||||
|
||||
while i < len(entity_list):
|
||||
# only process this entity if its unique ID hadn't been added before
|
||||
entity_hash = self.vocab.strings.add(entity_list[i])
|
||||
entry.entity_hash = entity_hash
|
||||
entry.freq = freq_list[i]
|
||||
if entity_hash in self._entry_index:
|
||||
user_warning(Warnings.W018.format(entity=entity_list[i]))
|
||||
|
||||
vector_index = self.c_add_vector(entity_vector=vector_list[i])
|
||||
entry.vector_index = vector_index
|
||||
else:
|
||||
entity_vector = vector_list[i]
|
||||
if len(entity_vector) != self.entity_vector_length:
|
||||
raise ValueError(Errors.E141.format(found=len(entity_vector), required=self.entity_vector_length))
|
||||
|
||||
entry.feats_row = -1 # Features table currently not implemented
|
||||
entry.entity_hash = entity_hash
|
||||
entry.freq = freq_list[i]
|
||||
|
||||
self._entries[i+1] = entry
|
||||
self._entry_index[entity_hash] = i+1
|
||||
vector_index = self.c_add_vector(entity_vector=vector_list[i])
|
||||
entry.vector_index = vector_index
|
||||
|
||||
entry.feats_row = -1 # Features table currently not implemented
|
||||
|
||||
self._entries[i+1] = entry
|
||||
self._entry_index[entity_hash] = i+1
|
||||
|
||||
i += 1
|
||||
|
||||
|
|
34
spacy/tests/regression/test_issue4674.py
Normal file
34
spacy/tests/regression/test_issue4674.py
Normal file
|
@ -0,0 +1,34 @@
|
|||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from spacy.kb import KnowledgeBase
|
||||
from spacy.util import ensure_path
|
||||
|
||||
from spacy.lang.en import English
|
||||
from spacy.tests.util import make_tempdir
|
||||
|
||||
|
||||
def test_issue4674():
    """Regression test for issue 4674: duplicate entity IDs must not break KB IO.

    Passes the same identifier twice to `set_entities`, checks that the
    knowledge base reports a single entity, then dumps the KB to disk, loads
    it into a fresh instance and checks that the reloaded KB also reports a
    single entity.
    """
    nlp = English()
    kb = KnowledgeBase(nlp.vocab, entity_vector_length=3)

    # The identifier "Q1" is supplied twice, each time with its own
    # frequency and vector.
    entity_ids = ["Q1", "Q1"]
    frequencies = [32, 111]
    vectors = [[0.9, 1.1, 1.01], [1.8, 2.25, 2.01]]
    kb.set_entities(entity_list=entity_ids, freq_list=frequencies, vector_list=vectors)

    assert kb.get_size_entities() == 1

    # Round trip: write the KB to a temporary location and read it back.
    with make_tempdir() as tmp:
        tmp_dir = ensure_path(tmp)
        if not tmp_dir.exists():
            tmp_dir.mkdir()
        kb_location = tmp_dir / "kb"
        kb.dump(str(kb_location))

        reloaded_kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=3)
        reloaded_kb.load_bulk(str(kb_location))

    assert reloaded_kb.get_size_entities() == 1
|
||||
|
Loading…
Reference in New Issue
Block a user