mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-29 17:33:10 +03:00
More robust set entities method in KB (#4794)
* add unit test for setting entities with duplicate identifiers * count the number of actual unique identifiers and throw duplicate warning
This commit is contained in:
parent
a067ded495
commit
f9b541f9ef
|
@ -81,7 +81,8 @@ class Warnings(object):
|
||||||
"Future versions may introduce a `n_process` argument for "
|
"Future versions may introduce a `n_process` argument for "
|
||||||
"parallel inference via multiprocessing.")
|
"parallel inference via multiprocessing.")
|
||||||
W017 = ("Alias '{alias}' already exists in the Knowledge Base.")
|
W017 = ("Alias '{alias}' already exists in the Knowledge Base.")
|
||||||
W018 = ("Entity '{entity}' already exists in the Knowledge Base.")
|
W018 = ("Entity '{entity}' already exists in the Knowledge Base - "
|
||||||
|
"ignoring the duplicate entry.")
|
||||||
W019 = ("Changing vectors name from {old} to {new}, to avoid clash with "
|
W019 = ("Changing vectors name from {old} to {new}, to avoid clash with "
|
||||||
"previously loaded vectors. See Issue #3853.")
|
"previously loaded vectors. See Issue #3853.")
|
||||||
W020 = ("Unnamed vectors. This won't allow multiple vectors models to be "
|
W020 = ("Unnamed vectors. This won't allow multiple vectors models to be "
|
||||||
|
|
11
spacy/kb.pyx
11
spacy/kb.pyx
|
@ -136,19 +136,24 @@ cdef class KnowledgeBase:
|
||||||
if len(entity_list) != len(freq_list) or len(entity_list) != len(vector_list):
|
if len(entity_list) != len(freq_list) or len(entity_list) != len(vector_list):
|
||||||
raise ValueError(Errors.E140)
|
raise ValueError(Errors.E140)
|
||||||
|
|
||||||
nr_entities = len(entity_list)
|
nr_entities = len(set(entity_list))
|
||||||
self._entry_index = PreshMap(nr_entities+1)
|
self._entry_index = PreshMap(nr_entities+1)
|
||||||
self._entries = entry_vec(nr_entities+1)
|
self._entries = entry_vec(nr_entities+1)
|
||||||
|
|
||||||
i = 0
|
i = 0
|
||||||
cdef KBEntryC entry
|
cdef KBEntryC entry
|
||||||
cdef hash_t entity_hash
|
cdef hash_t entity_hash
|
||||||
while i < nr_entities:
|
while i < len(entity_list):
|
||||||
|
# only process this entity if its unique ID hadn't been added before
|
||||||
|
entity_hash = self.vocab.strings.add(entity_list[i])
|
||||||
|
if entity_hash in self._entry_index:
|
||||||
|
user_warning(Warnings.W018.format(entity=entity_list[i]))
|
||||||
|
|
||||||
|
else:
|
||||||
entity_vector = vector_list[i]
|
entity_vector = vector_list[i]
|
||||||
if len(entity_vector) != self.entity_vector_length:
|
if len(entity_vector) != self.entity_vector_length:
|
||||||
raise ValueError(Errors.E141.format(found=len(entity_vector), required=self.entity_vector_length))
|
raise ValueError(Errors.E141.format(found=len(entity_vector), required=self.entity_vector_length))
|
||||||
|
|
||||||
entity_hash = self.vocab.strings.add(entity_list[i])
|
|
||||||
entry.entity_hash = entity_hash
|
entry.entity_hash = entity_hash
|
||||||
entry.freq = freq_list[i]
|
entry.freq = freq_list[i]
|
||||||
|
|
||||||
|
|
34
spacy/tests/regression/test_issue4674.py
Normal file
34
spacy/tests/regression/test_issue4674.py
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
# coding: utf-8
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
from spacy.kb import KnowledgeBase
|
||||||
|
from spacy.util import ensure_path
|
||||||
|
|
||||||
|
from spacy.lang.en import English
|
||||||
|
from spacy.tests.util import make_tempdir
|
||||||
|
|
||||||
|
|
||||||
|
def test_issue4674():
|
||||||
|
"""Test that setting entities with overlapping identifiers does not mess up IO"""
|
||||||
|
nlp = English()
|
||||||
|
kb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
|
||||||
|
|
||||||
|
vector1 = [0.9, 1.1, 1.01]
|
||||||
|
vector2 = [1.8, 2.25, 2.01]
|
||||||
|
kb.set_entities(entity_list=["Q1", "Q1"], freq_list=[32, 111], vector_list=[vector1, vector2])
|
||||||
|
|
||||||
|
assert kb.get_size_entities() == 1
|
||||||
|
|
||||||
|
# dumping to file & loading back in
|
||||||
|
with make_tempdir() as d:
|
||||||
|
dir_path = ensure_path(d)
|
||||||
|
if not dir_path.exists():
|
||||||
|
dir_path.mkdir()
|
||||||
|
file_path = dir_path / "kb"
|
||||||
|
kb.dump(str(file_path))
|
||||||
|
|
||||||
|
kb2 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=3)
|
||||||
|
kb2.load_bulk(str(file_path))
|
||||||
|
|
||||||
|
assert kb2.get_size_entities() == 1
|
||||||
|
|
Loading…
Reference in New Issue
Block a user