diff --git a/examples/pipeline/wikidata_entity_linking.py b/examples/pipeline/wikidata_entity_linking.py index a8a3eec1e..3b0943167 100644 --- a/examples/pipeline/wikidata_entity_linking.py +++ b/examples/pipeline/wikidata_entity_linking.py @@ -442,11 +442,11 @@ if __name__ == "__main__": print() print("dumping kb1") + print(KB_FILE, type(KB_FILE)) kb1.dump(KB_FILE) # STEP 4 : read KB back in from file - nlp3 = spacy.load('en_core_web_sm') kb3 = KnowledgeBase(vocab=my_vocab) print("loading kb3") diff --git a/spacy/kb.pxd b/spacy/kb.pxd index 5f7bfa46c..82b06d192 100644 --- a/spacy/kb.pxd +++ b/spacy/kb.pxd @@ -19,6 +19,7 @@ cdef class Candidate: cdef readonly KnowledgeBase kb cdef hash_t entity_hash + cdef float entity_freq cdef hash_t alias_hash cdef float prior_prob diff --git a/spacy/kb.pyx b/spacy/kb.pyx index f3d5ecaa9..ad2e13b5e 100644 --- a/spacy/kb.pyx +++ b/spacy/kb.pyx @@ -26,9 +26,10 @@ from libcpp.vector cimport vector cdef class Candidate: - def __init__(self, KnowledgeBase kb, entity_hash, alias_hash, prior_prob): + def __init__(self, KnowledgeBase kb, entity_hash, entity_freq, alias_hash, prior_prob): self.kb = kb self.entity_hash = entity_hash + self.entity_freq = entity_freq self.alias_hash = alias_hash self.prior_prob = prior_prob @@ -52,6 +53,10 @@ cdef class Candidate: """RETURNS (unicode): ID of the original alias""" return self.kb.vocab.strings[self.alias_hash] + @property + def entity_freq(self): + return self.entity_freq + @property def prior_prob(self): return self.prior_prob @@ -156,6 +161,7 @@ cdef class KnowledgeBase: return [Candidate(kb=self, entity_hash=self._entries[entry_index].entity_hash, + entity_freq=self._entries[entry_index].prob, alias_hash=alias_hash, prior_prob=prob) for (entry_index, prob) in zip(alias_entry.entry_indices, alias_entry.probs) diff --git a/spacy/tests/serialize/test_serialize_kb.py b/spacy/tests/serialize/test_serialize_kb.py new file mode 100644 index 000000000..ae0eedeeb --- /dev/null +++ b/spacy/tests/serialize/test_serialize_kb.py @@ -0,0 +1,64 @@ +from ..util import make_tempdir +from ...util import ensure_path + +from spacy.kb import KnowledgeBase + + +def test_serialize_kb_disk(en_vocab): + kb1 = KnowledgeBase(vocab=en_vocab) + + kb1.add_entity(entity="Q53", prob=0.33) + kb1.add_entity(entity="Q17", prob=0.2) + kb1.add_entity(entity="Q007", prob=0.7) + kb1.add_entity(entity="Q44", prob=0.4) + kb1.add_alias(alias="double07", entities=["Q17", "Q007"], probabilities=[0.1, 0.9]) + kb1.add_alias(alias="guy", entities=["Q53", "Q007", "Q17", "Q44"], probabilities=[0.3, 0.3, 0.2, 0.1]) + kb1.add_alias(alias="random", entities=["Q007"], probabilities=[1.0]) + + # baseline assertions + _check_kb(kb1) + + # dumping to file & loading back in + with make_tempdir() as d: + dir_path = ensure_path(d) + if not dir_path.exists(): + dir_path.mkdir() + file_path = dir_path / "kb" + print(file_path, type(file_path)) + kb1.dump(str(file_path)) + + kb2 = KnowledgeBase(vocab=en_vocab) + kb2.load_bulk(str(file_path)) + + # final assertions + _check_kb(kb2) + + +def _check_kb(kb): + # check entities + assert kb.get_size_entities() == 4 + for entity_string in ["Q53", "Q17", "Q007", "Q44"]: + assert entity_string in kb.get_entity_strings() + for entity_string in ["", "Q0"]: + assert entity_string not in kb.get_entity_strings() + + # check aliases + assert kb.get_size_aliases() == 3 + for alias_string in ["double07", "guy", "random"]: + assert alias_string in kb.get_alias_strings() + for alias_string in ["nothingness", "", "randomnoise"]: + assert alias_string not in kb.get_alias_strings() + + # check candidates & probabilities + candidates = sorted(kb.get_candidates("double07"), key=lambda x: x.entity_) + assert len(candidates) == 2 + + assert candidates[0].entity_ == "Q007" + assert candidates[0].entity_freq < 0.701 and candidates[0].entity_freq > 0.699 + assert candidates[0].alias_ == "double07" + assert candidates[0].prior_prob < 0.901 and candidates[0].prior_prob > 0.899 + + assert candidates[1].entity_ == "Q17" + assert candidates[1].entity_freq < 0.201 and candidates[1].entity_freq > 0.199 + assert candidates[1].alias_ == "double07" + assert candidates[1].prior_prob < 0.101 and candidates[1].prior_prob > 0.099