mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
unit test for KB serialization
This commit is contained in:
parent
3e0cb69065
commit
54d0cea062
|
@ -442,11 +442,11 @@ if __name__ == "__main__":
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print("dumping kb1")
|
print("dumping kb1")
|
||||||
|
print(KB_FILE, type(KB_FILE))
|
||||||
kb1.dump(KB_FILE)
|
kb1.dump(KB_FILE)
|
||||||
|
|
||||||
# STEP 4 : read KB back in from file
|
# STEP 4 : read KB back in from file
|
||||||
|
|
||||||
nlp3 = spacy.load('en_core_web_sm')
|
|
||||||
kb3 = KnowledgeBase(vocab=my_vocab)
|
kb3 = KnowledgeBase(vocab=my_vocab)
|
||||||
|
|
||||||
print("loading kb3")
|
print("loading kb3")
|
||||||
|
|
|
@ -19,6 +19,7 @@ cdef class Candidate:
|
||||||
|
|
||||||
cdef readonly KnowledgeBase kb
|
cdef readonly KnowledgeBase kb
|
||||||
cdef hash_t entity_hash
|
cdef hash_t entity_hash
|
||||||
|
cdef float entity_freq
|
||||||
cdef hash_t alias_hash
|
cdef hash_t alias_hash
|
||||||
cdef float prior_prob
|
cdef float prior_prob
|
||||||
|
|
||||||
|
|
|
@ -26,9 +26,10 @@ from libcpp.vector cimport vector
|
||||||
|
|
||||||
cdef class Candidate:
|
cdef class Candidate:
|
||||||
|
|
||||||
def __init__(self, KnowledgeBase kb, entity_hash, alias_hash, prior_prob):
|
def __init__(self, KnowledgeBase kb, entity_hash, entity_freq, alias_hash, prior_prob):
|
||||||
self.kb = kb
|
self.kb = kb
|
||||||
self.entity_hash = entity_hash
|
self.entity_hash = entity_hash
|
||||||
|
self.entity_freq = entity_freq
|
||||||
self.alias_hash = alias_hash
|
self.alias_hash = alias_hash
|
||||||
self.prior_prob = prior_prob
|
self.prior_prob = prior_prob
|
||||||
|
|
||||||
|
@ -52,6 +53,10 @@ cdef class Candidate:
|
||||||
"""RETURNS (unicode): ID of the original alias"""
|
"""RETURNS (unicode): ID of the original alias"""
|
||||||
return self.kb.vocab.strings[self.alias_hash]
|
return self.kb.vocab.strings[self.alias_hash]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def entity_freq(self):
|
||||||
|
return self.entity_freq
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def prior_prob(self):
|
def prior_prob(self):
|
||||||
return self.prior_prob
|
return self.prior_prob
|
||||||
|
@ -156,6 +161,7 @@ cdef class KnowledgeBase:
|
||||||
|
|
||||||
return [Candidate(kb=self,
|
return [Candidate(kb=self,
|
||||||
entity_hash=self._entries[entry_index].entity_hash,
|
entity_hash=self._entries[entry_index].entity_hash,
|
||||||
|
entity_freq=self._entries[entry_index].prob,
|
||||||
alias_hash=alias_hash,
|
alias_hash=alias_hash,
|
||||||
prior_prob=prob)
|
prior_prob=prob)
|
||||||
for (entry_index, prob) in zip(alias_entry.entry_indices, alias_entry.probs)
|
for (entry_index, prob) in zip(alias_entry.entry_indices, alias_entry.probs)
|
||||||
|
|
64
spacy/tests/serialize/test_serialize_kb.py
Normal file
64
spacy/tests/serialize/test_serialize_kb.py
Normal file
|
@ -0,0 +1,64 @@
|
||||||
|
from ..util import make_tempdir
|
||||||
|
from ...util import ensure_path
|
||||||
|
|
||||||
|
from spacy.kb import KnowledgeBase
|
||||||
|
|
||||||
|
|
||||||
|
def test_serialize_kb_disk(en_vocab):
|
||||||
|
kb1 = KnowledgeBase(vocab=en_vocab)
|
||||||
|
|
||||||
|
kb1.add_entity(entity="Q53", prob=0.33)
|
||||||
|
kb1.add_entity(entity="Q17", prob=0.2)
|
||||||
|
kb1.add_entity(entity="Q007", prob=0.7)
|
||||||
|
kb1.add_entity(entity="Q44", prob=0.4)
|
||||||
|
kb1.add_alias(alias="double07", entities=["Q17", "Q007"], probabilities=[0.1, 0.9])
|
||||||
|
kb1.add_alias(alias="guy", entities=["Q53", "Q007", "Q17", "Q44"], probabilities=[0.3, 0.3, 0.2, 0.1])
|
||||||
|
kb1.add_alias(alias="random", entities=["Q007"], probabilities=[1.0])
|
||||||
|
|
||||||
|
# baseline assertions
|
||||||
|
_check_kb(kb1)
|
||||||
|
|
||||||
|
# dumping to file & loading back in
|
||||||
|
with make_tempdir() as d:
|
||||||
|
dir_path = ensure_path(d)
|
||||||
|
if not dir_path.exists():
|
||||||
|
dir_path.mkdir()
|
||||||
|
file_path = dir_path / "kb"
|
||||||
|
print(file_path, type(file_path))
|
||||||
|
kb1.dump(str(file_path))
|
||||||
|
|
||||||
|
kb2 = KnowledgeBase(vocab=en_vocab)
|
||||||
|
kb2.load_bulk(str(file_path))
|
||||||
|
|
||||||
|
# final assertions
|
||||||
|
_check_kb(kb2)
|
||||||
|
|
||||||
|
|
||||||
|
def _check_kb(kb):
|
||||||
|
# check entities
|
||||||
|
assert kb.get_size_entities() == 4
|
||||||
|
for entity_string in ["Q53", "Q17", "Q007", "Q44"]:
|
||||||
|
assert entity_string in kb.get_entity_strings()
|
||||||
|
for entity_string in ["", "Q0"]:
|
||||||
|
assert entity_string not in kb.get_entity_strings()
|
||||||
|
|
||||||
|
# check aliases
|
||||||
|
assert kb.get_size_aliases() == 3
|
||||||
|
for alias_string in ["double07", "guy", "random"]:
|
||||||
|
assert alias_string in kb.get_alias_strings()
|
||||||
|
for alias_string in ["nothingness", "", "randomnoise"]:
|
||||||
|
assert alias_string not in kb.get_alias_strings()
|
||||||
|
|
||||||
|
# check candidates & probabilities
|
||||||
|
candidates = sorted(kb.get_candidates("double07"), key=lambda x: x.entity_)
|
||||||
|
assert len(candidates) == 2
|
||||||
|
|
||||||
|
assert candidates[0].entity_ == "Q007"
|
||||||
|
assert candidates[0].entity_freq < 0.701 and candidates[0].entity_freq > 0.699
|
||||||
|
assert candidates[0].alias_ == "double07"
|
||||||
|
assert candidates[0].prior_prob < 0.901 and candidates[0].prior_prob > 0.899
|
||||||
|
|
||||||
|
assert candidates[1].entity_ == "Q17"
|
||||||
|
assert candidates[1].entity_freq < 0.201 and candidates[1].entity_freq > 0.199
|
||||||
|
assert candidates[1].alias_ == "double07"
|
||||||
|
assert candidates[1].prior_prob < 0.101 and candidates[1].prior_prob > 0.099
|
Loading…
Reference in New Issue
Block a user