entity vectors in the KB + serialization of them

This commit is contained in:
svlandeg 2019-06-05 18:29:18 +02:00
parent 9abbd0899f
commit 5c723c32c3
10 changed files with 223 additions and 94 deletions

View File

@ -9,20 +9,20 @@ from spacy.kb import KnowledgeBase
def create_kb(vocab):
kb = KnowledgeBase(vocab=vocab)
kb = KnowledgeBase(vocab=vocab, entity_vector_length=1)
# adding entities
entity_0 = "Q1004791_Douglas"
print("adding entity", entity_0)
kb.add_entity(entity=entity_0, prob=0.5)
kb.add_entity(entity=entity_0, prob=0.5, entity_vector=[0])
entity_1 = "Q42_Douglas_Adams"
print("adding entity", entity_1)
kb.add_entity(entity=entity_1, prob=0.5)
kb.add_entity(entity=entity_1, prob=0.5, entity_vector=[1])
entity_2 = "Q5301561_Douglas_Haig"
print("adding entity", entity_2)
kb.add_entity(entity=entity_2, prob=0.5)
kb.add_entity(entity=entity_2, prob=0.5, entity_vector=[2])
# adding aliases
print()

View File

@ -16,7 +16,7 @@ def create_kb(vocab, max_entities_per_alias, min_occ,
count_input, prior_prob_input,
to_print=False, write_entity_defs=True):
""" Create the knowledge base from Wikidata entries """
kb = KnowledgeBase(vocab=vocab)
kb = KnowledgeBase(vocab=vocab, entity_vector_length=64) # TODO: entity vectors !
print()
print("1. _read_wikidata_entities", datetime.datetime.now())
@ -38,7 +38,8 @@ def create_kb(vocab, max_entities_per_alias, min_occ,
print()
print("3. adding", len(entity_list), "entities", datetime.datetime.now())
print()
kb.set_entities(entity_list=entity_list, prob_list=entity_frequencies, vector_list=None, feature_list=None)
# TODO: vector_list !
kb.set_entities(entity_list=entity_list, prob_list=entity_frequencies, vector_list=None)
print()
print("4. adding aliases", datetime.datetime.now())

View File

@ -19,7 +19,7 @@ class EntityEncoder:
DROP = 0
EPOCHS = 5
STOP_THRESHOLD = 0.05
STOP_THRESHOLD = 0.1
BATCH_SIZE = 1000
@ -38,6 +38,8 @@ class EntityEncoder:
# TODO: apply and write to file afterwards !
# self._apply_encoder(id_to_descr)
self._test_encoder()
def _train_model(self, entity_descr_output, id_to_descr):
# TODO: when loss gets too low, a 'mean of empty slice' warning is thrown by numpy
@ -111,3 +113,40 @@ class EntityEncoder:
def get_loss(golds, scores):
loss, gradients = get_cossim_loss(scores, golds)
return loss, gradients
def _test_encoder(self):
""" Test encoder on some dummy examples """
desc_A1 = "Fictional character in The Simpsons"
desc_A2 = "Simpsons - fictional human"
desc_A3 = "Fictional character in The Flintstones"
desc_A4 = "Politician from the US"
A1_doc_vector = np.asarray([self._get_doc_embedding(self.nlp(desc_A1))])
A2_doc_vector = np.asarray([self._get_doc_embedding(self.nlp(desc_A2))])
A3_doc_vector = np.asarray([self._get_doc_embedding(self.nlp(desc_A3))])
A4_doc_vector = np.asarray([self._get_doc_embedding(self.nlp(desc_A4))])
loss_a1_a1, _ = get_cossim_loss(A1_doc_vector, A1_doc_vector)
loss_a1_a2, _ = get_cossim_loss(A1_doc_vector, A2_doc_vector)
loss_a1_a3, _ = get_cossim_loss(A1_doc_vector, A3_doc_vector)
loss_a1_a4, _ = get_cossim_loss(A1_doc_vector, A4_doc_vector)
print("sim doc A1 A1", loss_a1_a1)
print("sim doc A1 A2", loss_a1_a2)
print("sim doc A1 A3", loss_a1_a3)
print("sim doc A1 A4", loss_a1_a4)
A1_encoded = self.encoder(A1_doc_vector)
A2_encoded = self.encoder(A2_doc_vector)
A3_encoded = self.encoder(A3_doc_vector)
A4_encoded = self.encoder(A4_doc_vector)
loss_a1_a1, _ = get_cossim_loss(A1_encoded, A1_encoded)
loss_a1_a2, _ = get_cossim_loss(A1_encoded, A2_encoded)
loss_a1_a3, _ = get_cossim_loss(A1_encoded, A3_encoded)
loss_a1_a4, _ = get_cossim_loss(A1_encoded, A4_encoded)
print("sim encoded A1 A1", loss_a1_a1)
print("sim encoded A1 A2", loss_a1_a2)
print("sim encoded A1 A3", loss_a1_a3)
print("sim encoded A1 A4", loss_a1_a4)

View File

@ -93,7 +93,7 @@ if __name__ == "__main__":
print("STEP 4: to_read_kb", datetime.datetime.now())
my_vocab = Vocab()
my_vocab.from_disk(VOCAB_DIR)
my_kb = KnowledgeBase(vocab=my_vocab)
my_kb = KnowledgeBase(vocab=my_vocab, entity_vector_length=64) # TODO entity vectors
my_kb.load_bulk(KB_FILE)
print("kb entities:", my_kb.get_size_entities())
print("kb aliases:", my_kb.get_size_aliases())

View File

@ -12,6 +12,8 @@ from .typedefs cimport hash_t
from .structs cimport EntryC, AliasC
ctypedef vector[EntryC] entry_vec
ctypedef vector[AliasC] alias_vec
ctypedef vector[float] float_vec
ctypedef vector[float_vec] float_matrix
# Object used by the Entity Linker that summarizes one entity-alias candidate combination.
@ -20,6 +22,7 @@ cdef class Candidate:
cdef readonly KnowledgeBase kb
cdef hash_t entity_hash
cdef float entity_freq
cdef vector[float] entity_vector
cdef hash_t alias_hash
cdef float prior_prob
@ -27,6 +30,7 @@ cdef class Candidate:
cdef class KnowledgeBase:
cdef Pool mem
cpdef readonly Vocab vocab
cdef int64_t entity_vector_length
# This maps 64bit keys (hash of unique entity string)
# to 64bit values (position of the _EntryC struct in the _entries vector).
@ -59,7 +63,7 @@ cdef class KnowledgeBase:
# model, that embeds different features of the entities into vectors. We'll
# still want some per-entity features, like the Wikipedia text or entity
# co-occurrence. Hopefully those vectors can be narrow, e.g. 64 dimensions.
cdef object _vectors_table
cdef float_matrix _vectors_table
# It's very useful to track categorical features, at least for output, even
# if they're not useful in the model itself. For instance, we should be
@ -69,8 +73,15 @@ cdef class KnowledgeBase:
cdef object _features_table
cdef inline int64_t c_add_vector(self, vector[float] entity_vector) nogil:
"""Add an entity vector to the vectors table."""
cdef int64_t new_index = self._vectors_table.size()
self._vectors_table.push_back(entity_vector)
return new_index
cdef inline int64_t c_add_entity(self, hash_t entity_hash, float prob,
int32_t* vector_rows, int feats_row) nogil:
int32_t vector_index, int feats_row) nogil:
"""Add an entry to the vector of entries.
After calling this method, make sure to update also the _entry_index using the return value"""
# This is what we'll map the entity hash key to. It's where the entry will sit
@ -80,7 +91,7 @@ cdef class KnowledgeBase:
# Avoid struct initializer to enable nogil, cf https://github.com/cython/cython/issues/1642
cdef EntryC entry
entry.entity_hash = entity_hash
entry.vector_rows = vector_rows
entry.vector_index = vector_index
entry.feats_row = feats_row
entry.prob = prob
@ -113,7 +124,7 @@ cdef class KnowledgeBase:
# Avoid struct initializer to enable nogil
cdef EntryC entry
entry.entity_hash = dummy_hash
entry.vector_rows = &dummy_value
entry.vector_index = dummy_value
entry.feats_row = dummy_value
entry.prob = dummy_value
@ -131,15 +142,16 @@ cdef class KnowledgeBase:
self._aliases_table.push_back(alias)
cpdef load_bulk(self, loc)
cpdef set_entities(self, entity_list, prob_list, vector_list, feature_list)
cpdef set_entities(self, entity_list, prob_list, vector_list)
cpdef set_aliases(self, alias_list, entities_list, probabilities_list)
cdef class Writer:
cdef FILE* _fp
cdef int write_header(self, int64_t nr_entries) except -1
cdef int write_entry(self, hash_t entry_hash, float entry_prob) except -1
cdef int write_header(self, int64_t nr_entries, int64_t entity_vector_length) except -1
cdef int write_vector_element(self, float element) except -1
cdef int write_entry(self, hash_t entry_hash, float entry_prob, int32_t vector_index) except -1
cdef int write_alias_length(self, int64_t alias_length) except -1
cdef int write_alias_header(self, hash_t alias_hash, int64_t candidate_length) except -1
@ -150,8 +162,9 @@ cdef class Writer:
cdef class Reader:
cdef FILE* _fp
cdef int read_header(self, int64_t* nr_entries) except -1
cdef int read_entry(self, hash_t* entity_hash, float* prob) except -1
cdef int read_header(self, int64_t* nr_entries, int64_t* entity_vector_length) except -1
cdef int read_vector_element(self, float* element) except -1
cdef int read_entry(self, hash_t* entity_hash, float* prob, int32_t* vector_index) except -1
cdef int read_alias_length(self, int64_t* alias_length) except -1
cdef int read_alias_header(self, hash_t* alias_hash, int64_t* candidate_length) except -1

View File

@ -26,10 +26,11 @@ from libcpp.vector cimport vector
cdef class Candidate:
def __init__(self, KnowledgeBase kb, entity_hash, entity_freq, alias_hash, prior_prob):
def __init__(self, KnowledgeBase kb, entity_hash, entity_freq, entity_vector, alias_hash, prior_prob):
self.kb = kb
self.entity_hash = entity_hash
self.entity_freq = entity_freq
self.entity_vector = entity_vector
self.alias_hash = alias_hash
self.prior_prob = prior_prob
@ -57,19 +58,26 @@ cdef class Candidate:
def entity_freq(self):
return self.entity_freq
@property
def entity_vector(self):
return self.entity_vector
@property
def prior_prob(self):
return self.prior_prob
cdef class KnowledgeBase:
def __init__(self, Vocab vocab):
def __init__(self, Vocab vocab, entity_vector_length):
self.vocab = vocab
self.mem = Pool()
self.entity_vector_length = entity_vector_length
self._entry_index = PreshMap()
self._alias_index = PreshMap()
# TODO initialize self._entries and self._aliases_table ?
# Should we initialize self._entries and self._aliases_table to specific starting size ?
self.vocab.strings.add("")
self._create_empty_vectors(dummy_hash=self.vocab.strings[""])
@ -89,10 +97,10 @@ cdef class KnowledgeBase:
def get_alias_strings(self):
return [self.vocab.strings[x] for x in self._alias_index]
def add_entity(self, unicode entity, float prob=0.5, vectors=None, features=None):
def add_entity(self, unicode entity, float prob, vector[float] entity_vector):
"""
Add an entity to the KB, optionally specifying its log probability based on corpus frequency
Return the hash of the entity ID/name at the end
Return the hash of the entity ID/name at the end.
"""
cdef hash_t entity_hash = self.vocab.strings.add(entity)
@ -101,31 +109,41 @@ cdef class KnowledgeBase:
user_warning(Warnings.W018.format(entity=entity))
return
cdef int32_t dummy_value = 342
new_index = self.c_add_entity(entity_hash=entity_hash, prob=prob,
vector_rows=&dummy_value, feats_row=dummy_value)
self._entry_index[entity_hash] = new_index
if len(entity_vector) != self.entity_vector_length:
# TODO: proper error
raise ValueError("Entity vector length should have been", self.entity_vector_length)
# TODO self._vectors_table.get_pointer(vectors),
# self._features_table.get(features))
vector_index = self.c_add_vector(entity_vector=entity_vector)
new_index = self.c_add_entity(entity_hash=entity_hash,
prob=prob,
vector_index=vector_index,
feats_row=-1) # Features table currently not implemented
self._entry_index[entity_hash] = new_index
return entity_hash
cpdef set_entities(self, entity_list, prob_list, vector_list, feature_list):
cpdef set_entities(self, entity_list, prob_list, vector_list):
nr_entities = len(entity_list)
self._entry_index = PreshMap(nr_entities+1)
self._entries = entry_vec(nr_entities+1)
i = 0
cdef EntryC entry
cdef int32_t dummy_value = 342
while i < nr_entities:
# TODO features and vectors
entity_hash = self.vocab.strings.add(entity_list[i])
entity_vector = entity_list[i]
if len(entity_vector) != self.entity_vector_length:
# TODO: proper error
raise ValueError("Entity vector length should have been", self.entity_vector_length)
entity_hash = self.vocab.strings.add(entity_vector)
entry.entity_hash = entity_hash
entry.prob = prob_list[i]
entry.vector_rows = &dummy_value
entry.feats_row = dummy_value
vector_index = self.c_add_vector(entity_vector=vector_list[i])
entry.vector_index = vector_index
entry.feats_row = -1 # Features table currently not implemented
self._entries[i+1] = entry
self._entry_index[entity_hash] = i+1
@ -186,7 +204,7 @@ cdef class KnowledgeBase:
cdef hash_t alias_hash = self.vocab.strings.add(alias)
# Return if this alias was added before
# Check whether this alias was added before
if alias_hash in self._alias_index:
user_warning(Warnings.W017.format(alias=alias))
return
@ -208,9 +226,7 @@ cdef class KnowledgeBase:
return alias_hash
def get_candidates(self, unicode alias):
""" TODO: where to put this functionality ?"""
cdef hash_t alias_hash = self.vocab.strings[alias]
alias_index = <int64_t>self._alias_index.get(alias_hash)
alias_entry = self._aliases_table[alias_index]
@ -218,6 +234,7 @@ cdef class KnowledgeBase:
return [Candidate(kb=self,
entity_hash=self._entries[entry_index].entity_hash,
entity_freq=self._entries[entry_index].prob,
entity_vector=self._vectors_table[self._entries[entry_index].vector_index],
alias_hash=alias_hash,
prior_prob=prob)
for (entry_index, prob) in zip(alias_entry.entry_indices, alias_entry.probs)
@ -226,16 +243,23 @@ cdef class KnowledgeBase:
def dump(self, loc):
cdef Writer writer = Writer(loc)
writer.write_header(self.get_size_entities())
writer.write_header(self.get_size_entities(), self.entity_vector_length)
# dumping the entity vectors in their original order
i = 0
for entity_vector in self._vectors_table:
for element in entity_vector:
writer.write_vector_element(element)
i = i+1
# dumping the entry records in the order in which they are in the _entries vector.
# index 0 is a dummy object not stored in the _entry_index and can be ignored.
i = 1
for entry_hash, entry_index in sorted(self._entry_index.items(), key=lambda x: x[1]):
entry = self._entries[entry_index]
assert entry.entity_hash == entry_hash
assert entry.entity_hash == entry_hash
assert entry_index == i
writer.write_entry(entry.entity_hash, entry.prob)
writer.write_entry(entry.entity_hash, entry.prob, entry.vector_index)
i = i+1
writer.write_alias_length(self.get_size_aliases())
@ -262,31 +286,47 @@ cdef class KnowledgeBase:
cdef hash_t alias_hash
cdef int64_t entry_index
cdef float prob
cdef int32_t vector_index
cdef EntryC entry
cdef AliasC alias
cdef int32_t dummy_value = 342
cdef float vector_element
cdef Reader reader = Reader(loc)
# Step 1: load entities
# STEP 0: load header and initialize KB
cdef int64_t nr_entities
reader.read_header(&nr_entities)
cdef int64_t entity_vector_length
reader.read_header(&nr_entities, &entity_vector_length)
self.entity_vector_length = entity_vector_length
self._entry_index = PreshMap(nr_entities+1)
self._entries = entry_vec(nr_entities+1)
self._vectors_table = float_matrix(nr_entities+1)
# STEP 1: load entity vectors
cdef int i = 0
cdef int j = 0
while i < nr_entities:
entity_vector = float_vec(entity_vector_length)
j = 0
while j < entity_vector_length:
reader.read_vector_element(&vector_element)
entity_vector[j] = vector_element
j = j+1
self._vectors_table[i] = entity_vector
i = i+1
# STEP 2: load entities
# we assume that the entity data was written in sequence
# index 0 is a dummy object not stored in the _entry_index and can be ignored.
# TODO: should we initialize the dummy objects ?
cdef int i = 1
i = 1
while i <= nr_entities:
reader.read_entry(&entity_hash, &prob)
reader.read_entry(&entity_hash, &prob, &vector_index)
# TODO features and vectors
entry.entity_hash = entity_hash
entry.prob = prob
entry.vector_rows = &dummy_value
entry.feats_row = dummy_value
entry.vector_index = vector_index
entry.feats_row = -1 # Features table currently not implemented
self._entries[i] = entry
self._entry_index[entity_hash] = i
@ -296,7 +336,8 @@ cdef class KnowledgeBase:
# check that all entities were read in properly
assert nr_entities == self.get_size_entities()
# Step 2: load aliases
# STEP 3: load aliases
cdef int64_t nr_aliases
reader.read_alias_length(&nr_aliases)
self._alias_index = PreshMap(nr_aliases+1)
@ -344,13 +385,18 @@ cdef class Writer:
cdef size_t status = fclose(self._fp)
assert status == 0
cdef int write_header(self, int64_t nr_entries) except -1:
cdef int write_header(self, int64_t nr_entries, int64_t entity_vector_length) except -1:
self._write(&nr_entries, sizeof(nr_entries))
self._write(&entity_vector_length, sizeof(entity_vector_length))
cdef int write_entry(self, hash_t entry_hash, float entry_prob) except -1:
# TODO: feats_rows and vector rows
cdef int write_vector_element(self, float element) except -1:
self._write(&element, sizeof(element))
cdef int write_entry(self, hash_t entry_hash, float entry_prob, int32_t vector_index) except -1:
self._write(&entry_hash, sizeof(entry_hash))
self._write(&entry_prob, sizeof(entry_prob))
self._write(&vector_index, sizeof(vector_index))
# Features table currently not implemented and not written to file
cdef int write_alias_length(self, int64_t alias_length) except -1:
self._write(&alias_length, sizeof(alias_length))
@ -381,14 +427,27 @@ cdef class Reader:
def __dealloc__(self):
fclose(self._fp)
cdef int read_header(self, int64_t* nr_entries) except -1:
cdef int read_header(self, int64_t* nr_entries, int64_t* entity_vector_length) except -1:
status = self._read(nr_entries, sizeof(int64_t))
if status < 1:
if feof(self._fp):
return 0 # end of file
raise IOError("error reading header from input file")
cdef int read_entry(self, hash_t* entity_hash, float* prob) except -1:
status = self._read(entity_vector_length, sizeof(int64_t))
if status < 1:
if feof(self._fp):
return 0 # end of file
raise IOError("error reading header from input file")
cdef int read_vector_element(self, float* element) except -1:
status = self._read(element, sizeof(float))
if status < 1:
if feof(self._fp):
return 0 # end of file
raise IOError("error reading entity vector from input file")
cdef int read_entry(self, hash_t* entity_hash, float* prob, int32_t* vector_index) except -1:
status = self._read(entity_hash, sizeof(hash_t))
if status < 1:
if feof(self._fp):
@ -401,6 +460,12 @@ cdef class Reader:
return 0 # end of file
raise IOError("error reading entity prob from input file")
status = self._read(vector_index, sizeof(int32_t))
if status < 1:
if feof(self._fp):
return 0 # end of file
raise IOError("error reading entity vector from input file")
if feof(self._fp):
return 0
else:

View File

@ -3,7 +3,7 @@
# coding: utf8
from __future__ import unicode_literals
cimport numpy as np
import numpy as np
import numpy
import srsly

View File

@ -84,16 +84,12 @@ cdef struct EntryC:
# The hash of this entry's unique ID/name in the kB
hash_t entity_hash
# Allows retrieval of one or more vectors.
# Each element of vector_rows should be an index into a vectors table.
# Every entry should have the same number of vectors, so we can avoid storing
# the number of vectors in each knowledge-base struct
int32_t* vector_rows
# Allows retrieval of the entity vector, as an index into a vectors table of the KB.
# Can be expanded later to refer to multiple rows (compositional model to reduce storage footprint).
int32_t vector_index
# Allows retrieval of a struct of non-vector features. We could make this a
# pointer, but we have 32 bits left over in the struct after prob, so we'd
# like this to only be 32 bits. We can also set this to -1, for the common
# case where there are no features.
# Allows retrieval of a struct of non-vector features.
# This is currently not implemented and set to -1 for the common case where there are no features.
int32_t feats_row
# log probability of entity, based on corpus frequency

View File

@ -14,12 +14,12 @@ def nlp():
def test_kb_valid_entities(nlp):
"""Test the valid construction of a KB with 3 entities and two aliases"""
mykb = KnowledgeBase(nlp.vocab)
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities
mykb.add_entity(entity=u'Q1', prob=0.9)
mykb.add_entity(entity=u'Q2')
mykb.add_entity(entity=u'Q3', prob=0.5)
mykb.add_entity(entity=u'Q1', prob=0.9, entity_vector=[1])
mykb.add_entity(entity=u'Q2', prob=0.5, entity_vector=[2])
mykb.add_entity(entity=u'Q3', prob=0.5, entity_vector=[3])
# adding aliases
mykb.add_alias(alias=u'douglas', entities=[u'Q2', u'Q3'], probabilities=[0.8, 0.2])
@ -32,12 +32,12 @@ def test_kb_valid_entities(nlp):
def test_kb_invalid_entities(nlp):
"""Test the invalid construction of a KB with an alias linked to a non-existing entity"""
mykb = KnowledgeBase(nlp.vocab)
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities
mykb.add_entity(entity=u'Q1', prob=0.9)
mykb.add_entity(entity=u'Q2', prob=0.2)
mykb.add_entity(entity=u'Q3', prob=0.5)
mykb.add_entity(entity=u'Q1', prob=0.9, entity_vector=[1])
mykb.add_entity(entity=u'Q2', prob=0.2, entity_vector=[2])
mykb.add_entity(entity=u'Q3', prob=0.5, entity_vector=[3])
# adding aliases - should fail because one of the given IDs is not valid
with pytest.raises(ValueError):
@ -46,12 +46,12 @@ def test_kb_invalid_entities(nlp):
def test_kb_invalid_probabilities(nlp):
"""Test the invalid construction of a KB with wrong prior probabilities"""
mykb = KnowledgeBase(nlp.vocab)
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities
mykb.add_entity(entity=u'Q1', prob=0.9)
mykb.add_entity(entity=u'Q2', prob=0.2)
mykb.add_entity(entity=u'Q3', prob=0.5)
mykb.add_entity(entity=u'Q1', prob=0.9, entity_vector=[1])
mykb.add_entity(entity=u'Q2', prob=0.2, entity_vector=[2])
mykb.add_entity(entity=u'Q3', prob=0.5, entity_vector=[3])
# adding aliases - should fail because the sum of the probabilities exceeds 1
with pytest.raises(ValueError):
@ -60,26 +60,38 @@ def test_kb_invalid_probabilities(nlp):
def test_kb_invalid_combination(nlp):
"""Test the invalid construction of a KB with non-matching entity and probability lists"""
mykb = KnowledgeBase(nlp.vocab)
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities
mykb.add_entity(entity=u'Q1', prob=0.9)
mykb.add_entity(entity=u'Q2', prob=0.2)
mykb.add_entity(entity=u'Q3', prob=0.5)
mykb.add_entity(entity=u'Q1', prob=0.9, entity_vector=[1])
mykb.add_entity(entity=u'Q2', prob=0.2, entity_vector=[2])
mykb.add_entity(entity=u'Q3', prob=0.5, entity_vector=[3])
# adding aliases - should fail because the entities and probabilities vectors are not of equal length
with pytest.raises(ValueError):
mykb.add_alias(alias=u'douglas', entities=[u'Q2', u'Q3'], probabilities=[0.3, 0.4, 0.1])
def test_candidate_generation(nlp):
"""Test correct candidate generation"""
mykb = KnowledgeBase(nlp.vocab)
def test_kb_invalid_entity_vector(nlp):
"""Test the invalid construction of a KB with non-matching entity vector lengths"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
# adding entities
mykb.add_entity(entity=u'Q1', prob=0.9)
mykb.add_entity(entity=u'Q2', prob=0.2)
mykb.add_entity(entity=u'Q3', prob=0.5)
mykb.add_entity(entity=u'Q1', prob=0.9, entity_vector=[1, 2, 3])
# this should fail because the kb's expected entity vector length is 3
with pytest.raises(ValueError):
mykb.add_entity(entity=u'Q2', prob=0.2, entity_vector=[2])
def test_candidate_generation(nlp):
"""Test correct candidate generation"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities
mykb.add_entity(entity=u'Q1', prob=0.9, entity_vector=[1])
mykb.add_entity(entity=u'Q2', prob=0.2, entity_vector=[2])
mykb.add_entity(entity=u'Q3', prob=0.5, entity_vector=[3])
# adding aliases
mykb.add_alias(alias=u'douglas', entities=[u'Q2', u'Q3'], probabilities=[0.8, 0.2])

View File

@ -20,7 +20,7 @@ def test_serialize_kb_disk(en_vocab):
print(file_path, type(file_path))
kb1.dump(str(file_path))
kb2 = KnowledgeBase(vocab=en_vocab)
kb2 = KnowledgeBase(vocab=en_vocab, entity_vector_length=3)
kb2.load_bulk(str(file_path))
# final assertions
@ -28,12 +28,13 @@ def test_serialize_kb_disk(en_vocab):
def _get_dummy_kb(vocab):
kb = KnowledgeBase(vocab=vocab)
kb = KnowledgeBase(vocab=vocab, entity_vector_length=3)
kb.add_entity(entity="Q53", prob=0.33, entity_vector=[0, 5, 3])
kb.add_entity(entity="Q17", prob=0.2, entity_vector=[7, 1, 0])
kb.add_entity(entity="Q007", prob=0.7, entity_vector=[0, 0, 7])
kb.add_entity(entity="Q44", prob=0.4, entity_vector=[4, 4, 4])
kb.add_entity(entity="Q53", prob=0.33)
kb.add_entity(entity="Q17", prob=0.2)
kb.add_entity(entity="Q007", prob=0.7)
kb.add_entity(entity="Q44", prob=0.4)
kb.add_alias(alias="double07", entities=["Q17", "Q007"], probabilities=[0.1, 0.9])
kb.add_alias(alias="guy", entities=["Q53", "Q007", "Q17", "Q44"], probabilities=[0.3, 0.3, 0.2, 0.1])
kb.add_alias(alias="random", entities=["Q007"], probabilities=[1.0])
@ -62,10 +63,12 @@ def _check_kb(kb):
assert candidates[0].entity_ == "Q007"
assert 0.6999 < candidates[0].entity_freq < 0.701
assert candidates[0].entity_vector == [0, 0, 7]
assert candidates[0].alias_ == "double07"
assert 0.899 < candidates[0].prior_prob < 0.901
assert candidates[1].entity_ == "Q17"
assert 0.199 < candidates[1].entity_freq < 0.201
assert candidates[1].entity_vector == [7, 1, 0]
assert candidates[1].alias_ == "double07"
assert 0.099 < candidates[1].prior_prob < 0.101