use nlp's vocab for stringstore

This commit is contained in:
svlandeg 2019-03-21 23:17:25 +01:00
parent 1ee0e78fd7
commit a48241e9a2
4 changed files with 43 additions and 37 deletions

View File

@ -6,8 +6,8 @@ import spacy
from spacy.kb import KnowledgeBase
def create_kb():
kb = KnowledgeBase()
def create_kb(vocab):
kb = KnowledgeBase(vocab=vocab)
# adding entities
entity_0 = "Q1004791"
@ -25,11 +25,11 @@ def create_kb():
# adding aliases
print()
alias_0 = "Douglas"
print("adding alias", alias_0, "to all three entities")
print("adding alias", alias_0)
kb.add_alias(alias=alias_0, entities=["Q1004791", "Q42", "Q5301561"], probabilities=[0.1, 0.6, 0.2])
alias_1 = "Douglas Adams"
print("adding alias", alias_1, "to just the one entity")
print("adding alias", alias_1)
kb.add_alias(alias=alias_1, entities=["Q42"], probabilities=[0.9])
print()
@ -38,9 +38,7 @@ def create_kb():
return kb
def add_el(kb):
nlp = spacy.load('en_core_web_sm')
def add_el(kb, nlp):
el_pipe = nlp.create_pipe(name='el', config={"kb": kb})
nlp.add_pipe(el_pipe, last=True)
@ -49,10 +47,11 @@ def add_el(kb):
print()
print(len(candidates), "candidate(s) for", alias, ":")
for c in candidates:
print(" ", c.entity_id_, c.entity_name_, c.alias_, c.prior_prob)
print(" ", c.entity_id_, c.entity_name_, c.prior_prob)
text = "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " \
"Douglas reminds us to always bring our towel."
"Douglas reminds us to always bring our towel. " \
"The main character in Doug's novel is called Arthur Dent."
doc = nlp(text)
print()
@ -65,5 +64,6 @@ def add_el(kb):
if __name__ == "__main__":
mykb = create_kb()
add_el(mykb)
nlp = spacy.load('en_core_web_sm')
my_kb = create_kb(nlp.vocab)
add_el(my_kb, nlp)

View File

@ -4,7 +4,7 @@ from preshed.maps cimport PreshMap
from libcpp.vector cimport vector
from libc.stdint cimport int32_t, int64_t
from spacy.strings cimport StringStore
from spacy.vocab cimport Vocab
from .typedefs cimport hash_t
@ -55,7 +55,7 @@ cdef class Candidate:
cdef class KnowledgeBase:
cdef Pool mem
cpdef readonly StringStore strings
cpdef readonly Vocab vocab
# This maps 64bit keys (hash of unique entity string)
# to 64bit values (position of the _EntryC struct in the _entries vector).
@ -133,11 +133,11 @@ cdef class KnowledgeBase:
cf. https://github.com/explosion/preshed/issues/17
"""
cdef int32_t dummy_value = 0
self.strings.add("")
self.vocab.strings.add("")
self._entries.push_back(
_EntryC(
entity_id_hash=self.strings[""],
entity_name_hash=self.strings[""],
entity_id_hash=self.vocab.strings[""],
entity_name_hash=self.vocab.strings[""],
vector_rows=&dummy_value,
feats_row=dummy_value,
prob=dummy_value

View File

@ -19,7 +19,7 @@ cdef class Candidate:
property entity_id_:
"""RETURNS (unicode): ID of this entity in the KB"""
def __get__(self):
return self.kb.strings[self.entity_id]
return self.kb.vocab.strings[self.entity_id]
property entity_name:
"""RETURNS (uint64): hash of the entity's KB name"""
@ -30,7 +30,7 @@ cdef class Candidate:
property entity_name_:
"""RETURNS (unicode): name of this entity in the KB"""
def __get__(self):
return self.kb.strings[self.entity_name]
return self.kb.vocab.strings[self.entity_name]
property alias:
"""RETURNS (uint64): hash of the alias"""
@ -40,7 +40,7 @@ cdef class Candidate:
property alias_:
"""RETURNS (unicode): ID of the original alias"""
def __get__(self):
return self.kb.strings[self.alias]
return self.kb.vocab.strings[self.alias]
property prior_prob:
def __get__(self):
@ -49,11 +49,11 @@ cdef class Candidate:
cdef class KnowledgeBase:
def __init__(self):
def __init__(self, Vocab vocab):
self.vocab = vocab
self._entry_index = PreshMap()
self._alias_index = PreshMap()
self.mem = Pool()
self.strings = StringStore()
self._create_empty_vectors()
def __len__(self):
@ -72,8 +72,8 @@ cdef class KnowledgeBase:
"""
if not entity_name:
entity_name = entity_id
cdef hash_t id_hash = self.strings.add(entity_id)
cdef hash_t name_hash = self.strings.add(entity_name)
cdef hash_t id_hash = self.vocab.strings.add(entity_id)
cdef hash_t name_hash = self.vocab.strings.add(entity_name)
# Return if this entity was added before
if id_hash in self._entry_index:
@ -107,7 +107,7 @@ cdef class KnowledgeBase:
raise ValueError("The sum of prior probabilities for alias '" + alias + "' should not exceed 1, "
+ "but found " + str(prob_sum))
cdef hash_t alias_hash = self.strings.add(alias)
cdef hash_t alias_hash = self.vocab.strings.add(alias)
# Return if this alias was added before
if alias_hash in self._alias_index:
@ -120,7 +120,7 @@ cdef class KnowledgeBase:
cdef vector[float] probs
for entity, prob in zip(entities, probabilities):
entity_id_hash = self.strings[entity]
entity_id_hash = self.vocab.strings[entity]
if not entity_id_hash in self._entry_index:
raise ValueError("Alias '" + alias + "' defined for unknown entity '" + entity + "'")
@ -135,7 +135,7 @@ cdef class KnowledgeBase:
def get_candidates(self, unicode alias):
""" TODO: where to put this functionality ?"""
cdef hash_t alias_hash = self.strings[alias]
cdef hash_t alias_hash = self.vocab.strings[alias]
alias_index = <int64_t>self._alias_index.get(alias_hash)
alias_entry = self._aliases_table[alias_index]

View File

@ -2,11 +2,17 @@
import pytest
from spacy.kb import KnowledgeBase
from spacy.lang.en import English
def test_kb_valid_entities():
"""Test the valid construction of a KB with 3 entities and one alias"""
mykb = KnowledgeBase()
@pytest.fixture
def nlp():
return English()
def test_kb_valid_entities(nlp):
"""Test the valid construction of a KB with 3 entities and two aliases"""
mykb = KnowledgeBase(nlp.vocab)
# adding entities
mykb.add_entity(entity_id="Q1", prob=0.9)
@ -22,9 +28,9 @@ def test_kb_valid_entities():
assert(mykb.get_size_aliases() == 2)
def test_kb_invalid_entities():
def test_kb_invalid_entities(nlp):
"""Test the invalid construction of a KB with an alias linked to a non-existing entity"""
mykb = KnowledgeBase()
mykb = KnowledgeBase(nlp.vocab)
# adding entities
mykb.add_entity(entity_id="Q1", prob=0.9)
@ -36,9 +42,9 @@ def test_kb_invalid_entities():
mykb.add_alias(alias="douglas", entities=["Q2", "Q342"], probabilities=[0.8, 0.2])
def test_kb_invalid_probabilities():
def test_kb_invalid_probabilities(nlp):
"""Test the invalid construction of a KB with wrong prior probabilities"""
mykb = KnowledgeBase()
mykb = KnowledgeBase(nlp.vocab)
# adding entities
mykb.add_entity(entity_id="Q1", prob=0.9)
@ -50,9 +56,9 @@ def test_kb_invalid_probabilities():
mykb.add_alias(alias="douglassss", entities=["Q2", "Q3"], probabilities=[0.8, 0.4])
def test_kb_invalid_combination():
def test_kb_invalid_combination(nlp):
"""Test the invalid construction of a KB with non-matching entity and probability lists"""
mykb = KnowledgeBase()
mykb = KnowledgeBase(nlp.vocab)
# adding entities
mykb.add_entity(entity_id="Q1", prob=0.9)
@ -64,9 +70,9 @@ def test_kb_invalid_combination():
mykb.add_alias(alias="douglassss", entities=["Q2", "Q3"], probabilities=[0.3, 0.4, 0.1])
def test_candidate_generation():
def test_candidate_generation(nlp):
"""Test correct candidate generation"""
mykb = KnowledgeBase()
mykb = KnowledgeBase(nlp.vocab)
# adding entities
mykb.add_entity(entity_id="Q1", prob=0.9)