mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
use nlp's vocab for stringstore
This commit is contained in:
parent
6e2433b95e
commit
4820b43313
|
@ -6,8 +6,8 @@ import spacy
|
||||||
from spacy.kb import KnowledgeBase
|
from spacy.kb import KnowledgeBase
|
||||||
|
|
||||||
|
|
||||||
def create_kb():
|
def create_kb(vocab):
|
||||||
kb = KnowledgeBase()
|
kb = KnowledgeBase(vocab=vocab)
|
||||||
|
|
||||||
# adding entities
|
# adding entities
|
||||||
entity_0 = "Q1004791"
|
entity_0 = "Q1004791"
|
||||||
|
@ -25,11 +25,11 @@ def create_kb():
|
||||||
# adding aliases
|
# adding aliases
|
||||||
print()
|
print()
|
||||||
alias_0 = "Douglas"
|
alias_0 = "Douglas"
|
||||||
print("adding alias", alias_0, "to all three entities")
|
print("adding alias", alias_0)
|
||||||
kb.add_alias(alias=alias_0, entities=["Q1004791", "Q42", "Q5301561"], probabilities=[0.1, 0.6, 0.2])
|
kb.add_alias(alias=alias_0, entities=["Q1004791", "Q42", "Q5301561"], probabilities=[0.1, 0.6, 0.2])
|
||||||
|
|
||||||
alias_1 = "Douglas Adams"
|
alias_1 = "Douglas Adams"
|
||||||
print("adding alias", alias_1, "to just the one entity")
|
print("adding alias", alias_1)
|
||||||
kb.add_alias(alias=alias_1, entities=["Q42"], probabilities=[0.9])
|
kb.add_alias(alias=alias_1, entities=["Q42"], probabilities=[0.9])
|
||||||
|
|
||||||
print()
|
print()
|
||||||
|
@ -38,9 +38,7 @@ def create_kb():
|
||||||
return kb
|
return kb
|
||||||
|
|
||||||
|
|
||||||
def add_el(kb):
|
def add_el(kb, nlp):
|
||||||
nlp = spacy.load('en_core_web_sm')
|
|
||||||
|
|
||||||
el_pipe = nlp.create_pipe(name='el', config={"kb": kb})
|
el_pipe = nlp.create_pipe(name='el', config={"kb": kb})
|
||||||
nlp.add_pipe(el_pipe, last=True)
|
nlp.add_pipe(el_pipe, last=True)
|
||||||
|
|
||||||
|
@ -49,10 +47,11 @@ def add_el(kb):
|
||||||
print()
|
print()
|
||||||
print(len(candidates), "candidate(s) for", alias, ":")
|
print(len(candidates), "candidate(s) for", alias, ":")
|
||||||
for c in candidates:
|
for c in candidates:
|
||||||
print(" ", c.entity_id_, c.entity_name_, c.alias_, c.prior_prob)
|
print(" ", c.entity_id_, c.entity_name_, c.prior_prob)
|
||||||
|
|
||||||
text = "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " \
|
text = "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " \
|
||||||
"Douglas reminds us to always bring our towel."
|
"Douglas reminds us to always bring our towel. " \
|
||||||
|
"The main character in Doug's novel is called Arthur Dent."
|
||||||
doc = nlp(text)
|
doc = nlp(text)
|
||||||
|
|
||||||
print()
|
print()
|
||||||
|
@ -65,5 +64,6 @@ def add_el(kb):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
mykb = create_kb()
|
nlp = spacy.load('en_core_web_sm')
|
||||||
add_el(mykb)
|
my_kb = create_kb(nlp.vocab)
|
||||||
|
add_el(my_kb, nlp)
|
||||||
|
|
10
spacy/kb.pxd
10
spacy/kb.pxd
|
@ -4,7 +4,7 @@ from preshed.maps cimport PreshMap
|
||||||
from libcpp.vector cimport vector
|
from libcpp.vector cimport vector
|
||||||
from libc.stdint cimport int32_t, int64_t
|
from libc.stdint cimport int32_t, int64_t
|
||||||
|
|
||||||
from spacy.strings cimport StringStore
|
from spacy.vocab cimport Vocab
|
||||||
from .typedefs cimport hash_t
|
from .typedefs cimport hash_t
|
||||||
|
|
||||||
|
|
||||||
|
@ -55,7 +55,7 @@ cdef class Candidate:
|
||||||
|
|
||||||
cdef class KnowledgeBase:
|
cdef class KnowledgeBase:
|
||||||
cdef Pool mem
|
cdef Pool mem
|
||||||
cpdef readonly StringStore strings
|
cpdef readonly Vocab vocab
|
||||||
|
|
||||||
# This maps 64bit keys (hash of unique entity string)
|
# This maps 64bit keys (hash of unique entity string)
|
||||||
# to 64bit values (position of the _EntryC struct in the _entries vector).
|
# to 64bit values (position of the _EntryC struct in the _entries vector).
|
||||||
|
@ -133,11 +133,11 @@ cdef class KnowledgeBase:
|
||||||
cf. https://github.com/explosion/preshed/issues/17
|
cf. https://github.com/explosion/preshed/issues/17
|
||||||
"""
|
"""
|
||||||
cdef int32_t dummy_value = 0
|
cdef int32_t dummy_value = 0
|
||||||
self.strings.add("")
|
self.vocab.strings.add("")
|
||||||
self._entries.push_back(
|
self._entries.push_back(
|
||||||
_EntryC(
|
_EntryC(
|
||||||
entity_id_hash=self.strings[""],
|
entity_id_hash=self.vocab.strings[""],
|
||||||
entity_name_hash=self.strings[""],
|
entity_name_hash=self.vocab.strings[""],
|
||||||
vector_rows=&dummy_value,
|
vector_rows=&dummy_value,
|
||||||
feats_row=dummy_value,
|
feats_row=dummy_value,
|
||||||
prob=dummy_value
|
prob=dummy_value
|
||||||
|
|
20
spacy/kb.pyx
20
spacy/kb.pyx
|
@ -19,7 +19,7 @@ cdef class Candidate:
|
||||||
property entity_id_:
|
property entity_id_:
|
||||||
"""RETURNS (unicode): ID of this entity in the KB"""
|
"""RETURNS (unicode): ID of this entity in the KB"""
|
||||||
def __get__(self):
|
def __get__(self):
|
||||||
return self.kb.strings[self.entity_id]
|
return self.kb.vocab.strings[self.entity_id]
|
||||||
|
|
||||||
property entity_name:
|
property entity_name:
|
||||||
"""RETURNS (uint64): hash of the entity's KB name"""
|
"""RETURNS (uint64): hash of the entity's KB name"""
|
||||||
|
@ -30,7 +30,7 @@ cdef class Candidate:
|
||||||
property entity_name_:
|
property entity_name_:
|
||||||
"""RETURNS (unicode): name of this entity in the KB"""
|
"""RETURNS (unicode): name of this entity in the KB"""
|
||||||
def __get__(self):
|
def __get__(self):
|
||||||
return self.kb.strings[self.entity_name]
|
return self.kb.vocab.strings[self.entity_name]
|
||||||
|
|
||||||
property alias:
|
property alias:
|
||||||
"""RETURNS (uint64): hash of the alias"""
|
"""RETURNS (uint64): hash of the alias"""
|
||||||
|
@ -40,7 +40,7 @@ cdef class Candidate:
|
||||||
property alias_:
|
property alias_:
|
||||||
"""RETURNS (unicode): ID of the original alias"""
|
"""RETURNS (unicode): ID of the original alias"""
|
||||||
def __get__(self):
|
def __get__(self):
|
||||||
return self.kb.strings[self.alias]
|
return self.kb.vocab.strings[self.alias]
|
||||||
|
|
||||||
property prior_prob:
|
property prior_prob:
|
||||||
def __get__(self):
|
def __get__(self):
|
||||||
|
@ -49,11 +49,11 @@ cdef class Candidate:
|
||||||
|
|
||||||
cdef class KnowledgeBase:
|
cdef class KnowledgeBase:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self, Vocab vocab):
|
||||||
|
self.vocab = vocab
|
||||||
self._entry_index = PreshMap()
|
self._entry_index = PreshMap()
|
||||||
self._alias_index = PreshMap()
|
self._alias_index = PreshMap()
|
||||||
self.mem = Pool()
|
self.mem = Pool()
|
||||||
self.strings = StringStore()
|
|
||||||
self._create_empty_vectors()
|
self._create_empty_vectors()
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
|
@ -72,8 +72,8 @@ cdef class KnowledgeBase:
|
||||||
"""
|
"""
|
||||||
if not entity_name:
|
if not entity_name:
|
||||||
entity_name = entity_id
|
entity_name = entity_id
|
||||||
cdef hash_t id_hash = self.strings.add(entity_id)
|
cdef hash_t id_hash = self.vocab.strings.add(entity_id)
|
||||||
cdef hash_t name_hash = self.strings.add(entity_name)
|
cdef hash_t name_hash = self.vocab.strings.add(entity_name)
|
||||||
|
|
||||||
# Return if this entity was added before
|
# Return if this entity was added before
|
||||||
if id_hash in self._entry_index:
|
if id_hash in self._entry_index:
|
||||||
|
@ -107,7 +107,7 @@ cdef class KnowledgeBase:
|
||||||
raise ValueError("The sum of prior probabilities for alias '" + alias + "' should not exceed 1, "
|
raise ValueError("The sum of prior probabilities for alias '" + alias + "' should not exceed 1, "
|
||||||
+ "but found " + str(prob_sum))
|
+ "but found " + str(prob_sum))
|
||||||
|
|
||||||
cdef hash_t alias_hash = self.strings.add(alias)
|
cdef hash_t alias_hash = self.vocab.strings.add(alias)
|
||||||
|
|
||||||
# Return if this alias was added before
|
# Return if this alias was added before
|
||||||
if alias_hash in self._alias_index:
|
if alias_hash in self._alias_index:
|
||||||
|
@ -120,7 +120,7 @@ cdef class KnowledgeBase:
|
||||||
cdef vector[float] probs
|
cdef vector[float] probs
|
||||||
|
|
||||||
for entity, prob in zip(entities, probabilities):
|
for entity, prob in zip(entities, probabilities):
|
||||||
entity_id_hash = self.strings[entity]
|
entity_id_hash = self.vocab.strings[entity]
|
||||||
if not entity_id_hash in self._entry_index:
|
if not entity_id_hash in self._entry_index:
|
||||||
raise ValueError("Alias '" + alias + "' defined for unknown entity '" + entity + "'")
|
raise ValueError("Alias '" + alias + "' defined for unknown entity '" + entity + "'")
|
||||||
|
|
||||||
|
@ -135,7 +135,7 @@ cdef class KnowledgeBase:
|
||||||
|
|
||||||
def get_candidates(self, unicode alias):
|
def get_candidates(self, unicode alias):
|
||||||
""" TODO: where to put this functionality ?"""
|
""" TODO: where to put this functionality ?"""
|
||||||
cdef hash_t alias_hash = self.strings[alias]
|
cdef hash_t alias_hash = self.vocab.strings[alias]
|
||||||
alias_index = <int64_t>self._alias_index.get(alias_hash)
|
alias_index = <int64_t>self._alias_index.get(alias_hash)
|
||||||
alias_entry = self._aliases_table[alias_index]
|
alias_entry = self._aliases_table[alias_index]
|
||||||
|
|
||||||
|
|
|
@ -2,11 +2,17 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from spacy.kb import KnowledgeBase
|
from spacy.kb import KnowledgeBase
|
||||||
|
from spacy.lang.en import English
|
||||||
|
|
||||||
|
|
||||||
def test_kb_valid_entities():
|
@pytest.fixture
|
||||||
"""Test the valid construction of a KB with 3 entities and one alias"""
|
def nlp():
|
||||||
mykb = KnowledgeBase()
|
return English()
|
||||||
|
|
||||||
|
|
||||||
|
def test_kb_valid_entities(nlp):
|
||||||
|
"""Test the valid construction of a KB with 3 entities and two aliases"""
|
||||||
|
mykb = KnowledgeBase(nlp.vocab)
|
||||||
|
|
||||||
# adding entities
|
# adding entities
|
||||||
mykb.add_entity(entity_id="Q1", prob=0.9)
|
mykb.add_entity(entity_id="Q1", prob=0.9)
|
||||||
|
@ -22,9 +28,9 @@ def test_kb_valid_entities():
|
||||||
assert(mykb.get_size_aliases() == 2)
|
assert(mykb.get_size_aliases() == 2)
|
||||||
|
|
||||||
|
|
||||||
def test_kb_invalid_entities():
|
def test_kb_invalid_entities(nlp):
|
||||||
"""Test the invalid construction of a KB with an alias linked to a non-existing entity"""
|
"""Test the invalid construction of a KB with an alias linked to a non-existing entity"""
|
||||||
mykb = KnowledgeBase()
|
mykb = KnowledgeBase(nlp.vocab)
|
||||||
|
|
||||||
# adding entities
|
# adding entities
|
||||||
mykb.add_entity(entity_id="Q1", prob=0.9)
|
mykb.add_entity(entity_id="Q1", prob=0.9)
|
||||||
|
@ -36,9 +42,9 @@ def test_kb_invalid_entities():
|
||||||
mykb.add_alias(alias="douglas", entities=["Q2", "Q342"], probabilities=[0.8, 0.2])
|
mykb.add_alias(alias="douglas", entities=["Q2", "Q342"], probabilities=[0.8, 0.2])
|
||||||
|
|
||||||
|
|
||||||
def test_kb_invalid_probabilities():
|
def test_kb_invalid_probabilities(nlp):
|
||||||
"""Test the invalid construction of a KB with wrong prior probabilities"""
|
"""Test the invalid construction of a KB with wrong prior probabilities"""
|
||||||
mykb = KnowledgeBase()
|
mykb = KnowledgeBase(nlp.vocab)
|
||||||
|
|
||||||
# adding entities
|
# adding entities
|
||||||
mykb.add_entity(entity_id="Q1", prob=0.9)
|
mykb.add_entity(entity_id="Q1", prob=0.9)
|
||||||
|
@ -50,9 +56,9 @@ def test_kb_invalid_probabilities():
|
||||||
mykb.add_alias(alias="douglassss", entities=["Q2", "Q3"], probabilities=[0.8, 0.4])
|
mykb.add_alias(alias="douglassss", entities=["Q2", "Q3"], probabilities=[0.8, 0.4])
|
||||||
|
|
||||||
|
|
||||||
def test_kb_invalid_combination():
|
def test_kb_invalid_combination(nlp):
|
||||||
"""Test the invalid construction of a KB with non-matching entity and probability lists"""
|
"""Test the invalid construction of a KB with non-matching entity and probability lists"""
|
||||||
mykb = KnowledgeBase()
|
mykb = KnowledgeBase(nlp.vocab)
|
||||||
|
|
||||||
# adding entities
|
# adding entities
|
||||||
mykb.add_entity(entity_id="Q1", prob=0.9)
|
mykb.add_entity(entity_id="Q1", prob=0.9)
|
||||||
|
@ -64,9 +70,9 @@ def test_kb_invalid_combination():
|
||||||
mykb.add_alias(alias="douglassss", entities=["Q2", "Q3"], probabilities=[0.3, 0.4, 0.1])
|
mykb.add_alias(alias="douglassss", entities=["Q2", "Q3"], probabilities=[0.3, 0.4, 0.1])
|
||||||
|
|
||||||
|
|
||||||
def test_candidate_generation():
|
def test_candidate_generation(nlp):
|
||||||
"""Test correct candidate generation"""
|
"""Test correct candidate generation"""
|
||||||
mykb = KnowledgeBase()
|
mykb = KnowledgeBase(nlp.vocab)
|
||||||
|
|
||||||
# adding entities
|
# adding entities
|
||||||
mykb.add_entity(entity_id="Q1", prob=0.9)
|
mykb.add_entity(entity_id="Q1", prob=0.9)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user