mirror of https://github.com/explosion/spaCy.git

use nlp's vocab for stringstore

commit 4820b43313
parent 6e2433b95e
@@ -6,8 +6,8 @@ import spacy
 from spacy.kb import KnowledgeBase


-def create_kb():
-    kb = KnowledgeBase()
+def create_kb(vocab):
+    kb = KnowledgeBase(vocab=vocab)

     # adding entities
     entity_0 = "Q1004791"
@@ -25,11 +25,11 @@ def create_kb():
     # adding aliases
     print()
     alias_0 = "Douglas"
-    print("adding alias", alias_0, "to all three entities")
+    print("adding alias", alias_0)
     kb.add_alias(alias=alias_0, entities=["Q1004791", "Q42", "Q5301561"], probabilities=[0.1, 0.6, 0.2])

     alias_1 = "Douglas Adams"
-    print("adding alias", alias_1, "to just the one entity")
+    print("adding alias", alias_1)
     kb.add_alias(alias=alias_1, entities=["Q42"], probabilities=[0.9])

     print()
@@ -38,9 +38,7 @@ def create_kb():
     return kb


-def add_el(kb):
-    nlp = spacy.load('en_core_web_sm')
-
+def add_el(kb, nlp):
     el_pipe = nlp.create_pipe(name='el', config={"kb": kb})
     nlp.add_pipe(el_pipe, last=True)

@@ -49,10 +47,11 @@ def add_el(kb):
     print()
     print(len(candidates), "candidate(s) for", alias, ":")
     for c in candidates:
-        print(" ", c.entity_id_, c.entity_name_, c.alias_, c.prior_prob)
+        print(" ", c.entity_id_, c.entity_name_, c.prior_prob)

     text = "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " \
-           "Douglas reminds us to always bring our towel."
+           "Douglas reminds us to always bring our towel. " \
+           "The main character in Doug's novel is called Arthur Dent."
     doc = nlp(text)

     print()
@@ -65,5 +64,6 @@ def add_el(kb):


 if __name__ == "__main__":
-    mykb = create_kb()
-    add_el(mykb)
+    nlp = spacy.load('en_core_web_sm')
+    my_kb = create_kb(nlp.vocab)
+    add_el(my_kb, nlp)
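Condensed from the example script above, a minimal sketch of the new calling convention: the KB is now built from the pipeline's vocab and handed to the experimental 'el' pipe. This reflects the development-branch API at this commit (the 'el' pipe and this KnowledgeBase constructor are not part of a released spaCy); entity IDs and probabilities below are illustrative values only.

import spacy
from spacy.kb import KnowledgeBase

nlp = spacy.load('en_core_web_sm')

# the KB borrows the pipeline's vocab (and therefore its StringStore)
kb = KnowledgeBase(vocab=nlp.vocab)
kb.add_entity(entity_id="Q42", prob=0.9)                                  # illustrative entity
kb.add_alias(alias="Douglas Adams", entities=["Q42"], probabilities=[0.9])

# wire the experimental entity-linking pipe up with this KB
el_pipe = nlp.create_pipe(name='el', config={"kb": kb})
nlp.add_pipe(el_pipe, last=True)

doc = nlp("Douglas Adams reminds us to always bring our towel.")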

spacy/kb.pxd (10 changed lines)
@@ -4,7 +4,7 @@ from preshed.maps cimport PreshMap
 from libcpp.vector cimport vector
 from libc.stdint cimport int32_t, int64_t

-from spacy.strings cimport StringStore
+from spacy.vocab cimport Vocab
 from .typedefs cimport hash_t


@@ -55,7 +55,7 @@ cdef class Candidate:

 cdef class KnowledgeBase:
     cdef Pool mem
-    cpdef readonly StringStore strings
+    cpdef readonly Vocab vocab

     # This maps 64bit keys (hash of unique entity string)
     # to 64bit values (position of the _EntryC struct in the _entries vector).
@@ -133,11 +133,11 @@ cdef class KnowledgeBase:
         cf. https://github.com/explosion/preshed/issues/17
         """
         cdef int32_t dummy_value = 0
-        self.strings.add("")
+        self.vocab.strings.add("")
         self._entries.push_back(
             _EntryC(
-                entity_id_hash=self.strings[""],
-                entity_name_hash=self.strings[""],
+                entity_id_hash=self.vocab.strings[""],
+                entity_name_hash=self.vocab.strings[""],
                 vector_rows=&dummy_value,
                 feats_row=dummy_value,
                 prob=dummy_value
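The declaration swap above (a shared Vocab instead of a KB-private StringStore) means strings interned by the KB and strings interned by the rest of the pipeline live in one table, so the 64-bit hashes kept in the _EntryC entries can be resolved from either side. A small illustration, assuming the constructor introduced in this commit; the entity ID is just an example value:

from spacy.lang.en import English
from spacy.kb import KnowledgeBase

nlp = English()
kb = KnowledgeBase(nlp.vocab)
kb.add_entity(entity_id="Q1004791", prob=0.9)   # illustrative entity, interned via nlp.vocab.strings

# the entity ID now lives in the *shared* store, so the pipeline can
# round-trip the hash that the KB stores in its _EntryC entries
entity_hash = nlp.vocab.strings["Q1004791"]
assert nlp.vocab.strings[entity_hash] == "Q1004791"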

spacy/kb.pyx (20 changed lines)
@@ -19,7 +19,7 @@ cdef class Candidate:
     property entity_id_:
         """RETURNS (unicode): ID of this entity in the KB"""
         def __get__(self):
-            return self.kb.strings[self.entity_id]
+            return self.kb.vocab.strings[self.entity_id]

     property entity_name:
         """RETURNS (uint64): hash of the entity's KB name"""
@@ -30,7 +30,7 @@ cdef class Candidate:
     property entity_name_:
         """RETURNS (unicode): name of this entity in the KB"""
         def __get__(self):
-            return self.kb.strings[self.entity_name]
+            return self.kb.vocab.strings[self.entity_name]

     property alias:
         """RETURNS (uint64): hash of the alias"""
@@ -40,7 +40,7 @@ cdef class Candidate:
     property alias_:
         """RETURNS (unicode): ID of the original alias"""
         def __get__(self):
-            return self.kb.strings[self.alias]
+            return self.kb.vocab.strings[self.alias]

     property prior_prob:
         def __get__(self):
@@ -49,11 +49,11 @@ cdef class Candidate:

 cdef class KnowledgeBase:

-    def __init__(self):
+    def __init__(self, Vocab vocab):
+        self.vocab = vocab
         self._entry_index = PreshMap()
         self._alias_index = PreshMap()
         self.mem = Pool()
-        self.strings = StringStore()
         self._create_empty_vectors()

     def __len__(self):
@@ -72,8 +72,8 @@ cdef class KnowledgeBase:
         """
         if not entity_name:
             entity_name = entity_id
-        cdef hash_t id_hash = self.strings.add(entity_id)
-        cdef hash_t name_hash = self.strings.add(entity_name)
+        cdef hash_t id_hash = self.vocab.strings.add(entity_id)
+        cdef hash_t name_hash = self.vocab.strings.add(entity_name)

         # Return if this entity was added before
         if id_hash in self._entry_index:
@@ -107,7 +107,7 @@ cdef class KnowledgeBase:
             raise ValueError("The sum of prior probabilities for alias '" + alias + "' should not exceed 1, "
                              + "but found " + str(prob_sum))

-        cdef hash_t alias_hash = self.strings.add(alias)
+        cdef hash_t alias_hash = self.vocab.strings.add(alias)

         # Return if this alias was added before
         if alias_hash in self._alias_index:
@@ -120,7 +120,7 @@ cdef class KnowledgeBase:
         cdef vector[float] probs

         for entity, prob in zip(entities, probabilities):
-            entity_id_hash = self.strings[entity]
+            entity_id_hash = self.vocab.strings[entity]
             if not entity_id_hash in self._entry_index:
                 raise ValueError("Alias '" + alias + "' defined for unknown entity '" + entity + "'")

@@ -135,7 +135,7 @@ cdef class KnowledgeBase:

     def get_candidates(self, unicode alias):
         """ TODO: where to put this functionality ?"""
-        cdef hash_t alias_hash = self.strings[alias]
+        cdef hash_t alias_hash = self.vocab.strings[alias]
         alias_index = <int64_t>self._alias_index.get(alias_hash)
         alias_entry = self._aliases_table[alias_index]

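Because Candidate now resolves its stored hashes through kb.vocab.strings, candidates returned by get_candidates() can be read back as plain text without a separate store. A short sketch using only calls that appear in this diff; the alias, entity ID, and probabilities are made-up values:

from spacy.lang.en import English
from spacy.kb import KnowledgeBase

nlp = English()
kb = KnowledgeBase(nlp.vocab)
kb.add_entity(entity_id="Q42", prob=0.9)          # name defaults to the ID when omitted
kb.add_alias(alias="Douglas", entities=["Q42"], probabilities=[0.6])

for c in kb.get_candidates("Douglas"):
    # entity_id_/entity_name_/alias_ resolve the stored hashes through
    # kb.vocab.strings (see the Candidate properties above)
    print(c.entity_id_, c.entity_name_, c.alias_, c.prior_prob)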
@@ -2,11 +2,17 @@
 import pytest

 from spacy.kb import KnowledgeBase
+from spacy.lang.en import English


-def test_kb_valid_entities():
-    """Test the valid construction of a KB with 3 entities and one alias"""
-    mykb = KnowledgeBase()
+@pytest.fixture
+def nlp():
+    return English()
+
+
+def test_kb_valid_entities(nlp):
+    """Test the valid construction of a KB with 3 entities and two aliases"""
+    mykb = KnowledgeBase(nlp.vocab)

     # adding entities
     mykb.add_entity(entity_id="Q1", prob=0.9)
@@ -22,9 +28,9 @@ def test_kb_valid_entities():
     assert(mykb.get_size_aliases() == 2)


-def test_kb_invalid_entities():
+def test_kb_invalid_entities(nlp):
     """Test the invalid construction of a KB with an alias linked to a non-existing entity"""
-    mykb = KnowledgeBase()
+    mykb = KnowledgeBase(nlp.vocab)

     # adding entities
     mykb.add_entity(entity_id="Q1", prob=0.9)
@@ -36,9 +42,9 @@ def test_kb_invalid_entities():
         mykb.add_alias(alias="douglas", entities=["Q2", "Q342"], probabilities=[0.8, 0.2])


-def test_kb_invalid_probabilities():
+def test_kb_invalid_probabilities(nlp):
     """Test the invalid construction of a KB with wrong prior probabilities"""
-    mykb = KnowledgeBase()
+    mykb = KnowledgeBase(nlp.vocab)

     # adding entities
     mykb.add_entity(entity_id="Q1", prob=0.9)
@@ -50,9 +56,9 @@ def test_kb_invalid_probabilities():
         mykb.add_alias(alias="douglassss", entities=["Q2", "Q3"], probabilities=[0.8, 0.4])


-def test_kb_invalid_combination():
+def test_kb_invalid_combination(nlp):
     """Test the invalid construction of a KB with non-matching entity and probability lists"""
-    mykb = KnowledgeBase()
+    mykb = KnowledgeBase(nlp.vocab)

     # adding entities
     mykb.add_entity(entity_id="Q1", prob=0.9)
@@ -64,9 +70,9 @@ def test_kb_invalid_combination():
         mykb.add_alias(alias="douglassss", entities=["Q2", "Q3"], probabilities=[0.3, 0.4, 0.1])


-def test_candidate_generation():
+def test_candidate_generation(nlp):
     """Test correct candidate generation"""
-    mykb = KnowledgeBase()
+    mykb = KnowledgeBase(nlp.vocab)

     # adding entities
     mykb.add_entity(entity_id="Q1", prob=0.9)
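A hypothetical extra test in the same fixture style could pin down the behaviour this commit introduces: strings added through the KB end up interned in the pipeline's own StringStore. This test is not part of the commit; it is a sketch that assumes StringStore supports membership checks on raw strings (which spaCy's StringStore does):

def test_kb_shared_strings(nlp):
    """Hypothetical: entity IDs and aliases added to the KB are visible in nlp.vocab.strings."""
    mykb = KnowledgeBase(nlp.vocab)
    mykb.add_entity(entity_id="Q1", prob=0.9)
    mykb.add_alias(alias="douglas", entities=["Q1"], probabilities=[0.9])

    # both strings were interned via the shared vocab, not a KB-private store
    assert "Q1" in nlp.vocab.strings
    assert "douglas" in nlp.vocab.strings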