use nlp's vocab for stringstore

This commit is contained in:
svlandeg 2019-03-21 23:17:25 +01:00
parent 6e2433b95e
commit 4820b43313
4 changed files with 43 additions and 37 deletions

View File

@ -6,8 +6,8 @@ import spacy
from spacy.kb import KnowledgeBase from spacy.kb import KnowledgeBase
def create_kb(): def create_kb(vocab):
kb = KnowledgeBase() kb = KnowledgeBase(vocab=vocab)
# adding entities # adding entities
entity_0 = "Q1004791" entity_0 = "Q1004791"
@ -25,11 +25,11 @@ def create_kb():
# adding aliases # adding aliases
print() print()
alias_0 = "Douglas" alias_0 = "Douglas"
print("adding alias", alias_0, "to all three entities") print("adding alias", alias_0)
kb.add_alias(alias=alias_0, entities=["Q1004791", "Q42", "Q5301561"], probabilities=[0.1, 0.6, 0.2]) kb.add_alias(alias=alias_0, entities=["Q1004791", "Q42", "Q5301561"], probabilities=[0.1, 0.6, 0.2])
alias_1 = "Douglas Adams" alias_1 = "Douglas Adams"
print("adding alias", alias_1, "to just the one entity") print("adding alias", alias_1)
kb.add_alias(alias=alias_1, entities=["Q42"], probabilities=[0.9]) kb.add_alias(alias=alias_1, entities=["Q42"], probabilities=[0.9])
print() print()
@ -38,9 +38,7 @@ def create_kb():
return kb return kb
def add_el(kb): def add_el(kb, nlp):
nlp = spacy.load('en_core_web_sm')
el_pipe = nlp.create_pipe(name='el', config={"kb": kb}) el_pipe = nlp.create_pipe(name='el', config={"kb": kb})
nlp.add_pipe(el_pipe, last=True) nlp.add_pipe(el_pipe, last=True)
@ -49,10 +47,11 @@ def add_el(kb):
print() print()
print(len(candidates), "candidate(s) for", alias, ":") print(len(candidates), "candidate(s) for", alias, ":")
for c in candidates: for c in candidates:
print(" ", c.entity_id_, c.entity_name_, c.alias_, c.prior_prob) print(" ", c.entity_id_, c.entity_name_, c.prior_prob)
text = "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " \ text = "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " \
"Douglas reminds us to always bring our towel." "Douglas reminds us to always bring our towel. " \
"The main character in Doug's novel is called Arthur Dent."
doc = nlp(text) doc = nlp(text)
print() print()
@ -65,5 +64,6 @@ def add_el(kb):
if __name__ == "__main__": if __name__ == "__main__":
mykb = create_kb() nlp = spacy.load('en_core_web_sm')
add_el(mykb) my_kb = create_kb(nlp.vocab)
add_el(my_kb, nlp)

View File

@ -4,7 +4,7 @@ from preshed.maps cimport PreshMap
from libcpp.vector cimport vector from libcpp.vector cimport vector
from libc.stdint cimport int32_t, int64_t from libc.stdint cimport int32_t, int64_t
from spacy.strings cimport StringStore from spacy.vocab cimport Vocab
from .typedefs cimport hash_t from .typedefs cimport hash_t
@ -55,7 +55,7 @@ cdef class Candidate:
cdef class KnowledgeBase: cdef class KnowledgeBase:
cdef Pool mem cdef Pool mem
cpdef readonly StringStore strings cpdef readonly Vocab vocab
# This maps 64bit keys (hash of unique entity string) # This maps 64bit keys (hash of unique entity string)
# to 64bit values (position of the _EntryC struct in the _entries vector). # to 64bit values (position of the _EntryC struct in the _entries vector).
@ -133,11 +133,11 @@ cdef class KnowledgeBase:
cf. https://github.com/explosion/preshed/issues/17 cf. https://github.com/explosion/preshed/issues/17
""" """
cdef int32_t dummy_value = 0 cdef int32_t dummy_value = 0
self.strings.add("") self.vocab.strings.add("")
self._entries.push_back( self._entries.push_back(
_EntryC( _EntryC(
entity_id_hash=self.strings[""], entity_id_hash=self.vocab.strings[""],
entity_name_hash=self.strings[""], entity_name_hash=self.vocab.strings[""],
vector_rows=&dummy_value, vector_rows=&dummy_value,
feats_row=dummy_value, feats_row=dummy_value,
prob=dummy_value prob=dummy_value

View File

@ -19,7 +19,7 @@ cdef class Candidate:
property entity_id_: property entity_id_:
"""RETURNS (unicode): ID of this entity in the KB""" """RETURNS (unicode): ID of this entity in the KB"""
def __get__(self): def __get__(self):
return self.kb.strings[self.entity_id] return self.kb.vocab.strings[self.entity_id]
property entity_name: property entity_name:
"""RETURNS (uint64): hash of the entity's KB name""" """RETURNS (uint64): hash of the entity's KB name"""
@ -30,7 +30,7 @@ cdef class Candidate:
property entity_name_: property entity_name_:
"""RETURNS (unicode): name of this entity in the KB""" """RETURNS (unicode): name of this entity in the KB"""
def __get__(self): def __get__(self):
return self.kb.strings[self.entity_name] return self.kb.vocab.strings[self.entity_name]
property alias: property alias:
"""RETURNS (uint64): hash of the alias""" """RETURNS (uint64): hash of the alias"""
@ -40,7 +40,7 @@ cdef class Candidate:
property alias_: property alias_:
"""RETURNS (unicode): ID of the original alias""" """RETURNS (unicode): ID of the original alias"""
def __get__(self): def __get__(self):
return self.kb.strings[self.alias] return self.kb.vocab.strings[self.alias]
property prior_prob: property prior_prob:
def __get__(self): def __get__(self):
@ -49,11 +49,11 @@ cdef class Candidate:
cdef class KnowledgeBase: cdef class KnowledgeBase:
def __init__(self): def __init__(self, Vocab vocab):
self.vocab = vocab
self._entry_index = PreshMap() self._entry_index = PreshMap()
self._alias_index = PreshMap() self._alias_index = PreshMap()
self.mem = Pool() self.mem = Pool()
self.strings = StringStore()
self._create_empty_vectors() self._create_empty_vectors()
def __len__(self): def __len__(self):
@ -72,8 +72,8 @@ cdef class KnowledgeBase:
""" """
if not entity_name: if not entity_name:
entity_name = entity_id entity_name = entity_id
cdef hash_t id_hash = self.strings.add(entity_id) cdef hash_t id_hash = self.vocab.strings.add(entity_id)
cdef hash_t name_hash = self.strings.add(entity_name) cdef hash_t name_hash = self.vocab.strings.add(entity_name)
# Return if this entity was added before # Return if this entity was added before
if id_hash in self._entry_index: if id_hash in self._entry_index:
@ -107,7 +107,7 @@ cdef class KnowledgeBase:
raise ValueError("The sum of prior probabilities for alias '" + alias + "' should not exceed 1, " raise ValueError("The sum of prior probabilities for alias '" + alias + "' should not exceed 1, "
+ "but found " + str(prob_sum)) + "but found " + str(prob_sum))
cdef hash_t alias_hash = self.strings.add(alias) cdef hash_t alias_hash = self.vocab.strings.add(alias)
# Return if this alias was added before # Return if this alias was added before
if alias_hash in self._alias_index: if alias_hash in self._alias_index:
@ -120,7 +120,7 @@ cdef class KnowledgeBase:
cdef vector[float] probs cdef vector[float] probs
for entity, prob in zip(entities, probabilities): for entity, prob in zip(entities, probabilities):
entity_id_hash = self.strings[entity] entity_id_hash = self.vocab.strings[entity]
if not entity_id_hash in self._entry_index: if not entity_id_hash in self._entry_index:
raise ValueError("Alias '" + alias + "' defined for unknown entity '" + entity + "'") raise ValueError("Alias '" + alias + "' defined for unknown entity '" + entity + "'")
@ -135,7 +135,7 @@ cdef class KnowledgeBase:
def get_candidates(self, unicode alias): def get_candidates(self, unicode alias):
""" TODO: where to put this functionality ?""" """ TODO: where to put this functionality ?"""
cdef hash_t alias_hash = self.strings[alias] cdef hash_t alias_hash = self.vocab.strings[alias]
alias_index = <int64_t>self._alias_index.get(alias_hash) alias_index = <int64_t>self._alias_index.get(alias_hash)
alias_entry = self._aliases_table[alias_index] alias_entry = self._aliases_table[alias_index]

View File

@ -2,11 +2,17 @@
import pytest import pytest
from spacy.kb import KnowledgeBase from spacy.kb import KnowledgeBase
from spacy.lang.en import English
def test_kb_valid_entities(): @pytest.fixture
"""Test the valid construction of a KB with 3 entities and one alias""" def nlp():
mykb = KnowledgeBase() return English()
def test_kb_valid_entities(nlp):
"""Test the valid construction of a KB with 3 entities and two aliases"""
mykb = KnowledgeBase(nlp.vocab)
# adding entities # adding entities
mykb.add_entity(entity_id="Q1", prob=0.9) mykb.add_entity(entity_id="Q1", prob=0.9)
@ -22,9 +28,9 @@ def test_kb_valid_entities():
assert(mykb.get_size_aliases() == 2) assert(mykb.get_size_aliases() == 2)
def test_kb_invalid_entities(): def test_kb_invalid_entities(nlp):
"""Test the invalid construction of a KB with an alias linked to a non-existing entity""" """Test the invalid construction of a KB with an alias linked to a non-existing entity"""
mykb = KnowledgeBase() mykb = KnowledgeBase(nlp.vocab)
# adding entities # adding entities
mykb.add_entity(entity_id="Q1", prob=0.9) mykb.add_entity(entity_id="Q1", prob=0.9)
@ -36,9 +42,9 @@ def test_kb_invalid_entities():
mykb.add_alias(alias="douglas", entities=["Q2", "Q342"], probabilities=[0.8, 0.2]) mykb.add_alias(alias="douglas", entities=["Q2", "Q342"], probabilities=[0.8, 0.2])
def test_kb_invalid_probabilities(): def test_kb_invalid_probabilities(nlp):
"""Test the invalid construction of a KB with wrong prior probabilities""" """Test the invalid construction of a KB with wrong prior probabilities"""
mykb = KnowledgeBase() mykb = KnowledgeBase(nlp.vocab)
# adding entities # adding entities
mykb.add_entity(entity_id="Q1", prob=0.9) mykb.add_entity(entity_id="Q1", prob=0.9)
@ -50,9 +56,9 @@ def test_kb_invalid_probabilities():
mykb.add_alias(alias="douglassss", entities=["Q2", "Q3"], probabilities=[0.8, 0.4]) mykb.add_alias(alias="douglassss", entities=["Q2", "Q3"], probabilities=[0.8, 0.4])
def test_kb_invalid_combination(): def test_kb_invalid_combination(nlp):
"""Test the invalid construction of a KB with non-matching entity and probability lists""" """Test the invalid construction of a KB with non-matching entity and probability lists"""
mykb = KnowledgeBase() mykb = KnowledgeBase(nlp.vocab)
# adding entities # adding entities
mykb.add_entity(entity_id="Q1", prob=0.9) mykb.add_entity(entity_id="Q1", prob=0.9)
@ -64,9 +70,9 @@ def test_kb_invalid_combination():
mykb.add_alias(alias="douglassss", entities=["Q2", "Q3"], probabilities=[0.3, 0.4, 0.1]) mykb.add_alias(alias="douglassss", entities=["Q2", "Q3"], probabilities=[0.3, 0.4, 0.1])
def test_candidate_generation(): def test_candidate_generation(nlp):
"""Test correct candidate generation""" """Test correct candidate generation"""
mykb = KnowledgeBase() mykb = KnowledgeBase(nlp.vocab)
# adding entities # adding entities
mykb.add_entity(entity_id="Q1", prob=0.9) mykb.add_entity(entity_id="Q1", prob=0.9)