diff --git a/examples/pipeline/dummy_entity_linking.py b/examples/pipeline/dummy_entity_linking.py
index ae36a57b3..3f1fabdfd 100644
--- a/examples/pipeline/dummy_entity_linking.py
+++ b/examples/pipeline/dummy_entity_linking.py
@@ -9,20 +9,20 @@ from spacy.kb import KnowledgeBase
 
 def create_kb(vocab):
-    kb = KnowledgeBase(vocab=vocab)
+    kb = KnowledgeBase(vocab=vocab, entity_vector_length=1)
 
     # adding entities
     entity_0 = "Q1004791_Douglas"
     print("adding entity", entity_0)
-    kb.add_entity(entity=entity_0, prob=0.5)
+    kb.add_entity(entity=entity_0, prob=0.5, entity_vector=[0])
 
     entity_1 = "Q42_Douglas_Adams"
     print("adding entity", entity_1)
-    kb.add_entity(entity=entity_1, prob=0.5)
+    kb.add_entity(entity=entity_1, prob=0.5, entity_vector=[1])
 
     entity_2 = "Q5301561_Douglas_Haig"
     print("adding entity", entity_2)
-    kb.add_entity(entity=entity_2, prob=0.5)
+    kb.add_entity(entity=entity_2, prob=0.5, entity_vector=[2])
 
     # adding aliases
     print()
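Taken together with the kb.pyx changes further below, the reworked construction API looks like this from user code (a minimal sketch against the signatures introduced in this patch; the entity IDs and values are made up):

    from spacy.kb import KnowledgeBase
    from spacy.vocab import Vocab

    vocab = Vocab()
    kb = KnowledgeBase(vocab=vocab, entity_vector_length=1)

    # every entity vector must have exactly entity_vector_length elements,
    # otherwise add_entity raises a ValueError
    kb.add_entity(entity="Q42", prob=0.5, entity_vector=[1])
    kb.add_entity(entity="Q1004791", prob=0.5, entity_vector=[0])
    kb.add_alias(alias="Douglas", entities=["Q42", "Q1004791"], probabilities=[0.6, 0.4])

    for candidate in kb.get_candidates("Douglas"):
        # candidates now carry the stored vector next to the frequency and prior
        print(candidate.entity_, candidate.entity_vector, candidate.prior_prob)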
diff --git a/examples/pipeline/wiki_entity_linking/kb_creator.py b/examples/pipeline/wiki_entity_linking/kb_creator.py
index bb00f918d..ae3422c91 100644
--- a/examples/pipeline/wiki_entity_linking/kb_creator.py
+++ b/examples/pipeline/wiki_entity_linking/kb_creator.py
@@ -16,7 +16,7 @@ def create_kb(vocab, max_entities_per_alias, min_occ,
               count_input, prior_prob_input,
               to_print=False, write_entity_defs=True):
     """ Create the knowledge base from Wikidata entries """
-    kb = KnowledgeBase(vocab=vocab)
+    kb = KnowledgeBase(vocab=vocab, entity_vector_length=64)  # TODO: entity vectors !
 
     print()
     print("1. _read_wikidata_entities", datetime.datetime.now())
@@ -38,7 +38,8 @@ def create_kb(vocab, max_entities_per_alias, min_occ,
     print()
     print("3. adding", len(entity_list), "entities", datetime.datetime.now())
     print()
-    kb.set_entities(entity_list=entity_list, prob_list=entity_frequencies, vector_list=None, feature_list=None)
+    # TODO: vector_list !
+    kb.set_entities(entity_list=entity_list, prob_list=entity_frequencies, vector_list=None)
 
     print()
     print("4. adding aliases", datetime.datetime.now())
diff --git a/examples/pipeline/wiki_entity_linking/train_descriptions.py b/examples/pipeline/wiki_entity_linking/train_descriptions.py
index 63149b5f7..88b1bf819 100644
--- a/examples/pipeline/wiki_entity_linking/train_descriptions.py
+++ b/examples/pipeline/wiki_entity_linking/train_descriptions.py
@@ -19,7 +19,7 @@ class EntityEncoder:
     DROP = 0
     EPOCHS = 5
-    STOP_THRESHOLD = 0.05
+    STOP_THRESHOLD = 0.1
 
     BATCH_SIZE = 1000
 
@@ -38,6 +38,8 @@ class EntityEncoder:
         # TODO: apply and write to file afterwards !
         # self._apply_encoder(id_to_descr)
 
+        self._test_encoder()
+
     def _train_model(self, entity_descr_output, id_to_descr):
         # TODO: when loss gets too low, a 'mean of empty slice' warning is thrown by numpy
 
@@ -111,3 +113,40 @@ class EntityEncoder:
     def get_loss(golds, scores):
         loss, gradients = get_cossim_loss(scores, golds)
         return loss, gradients
+
+    def _test_encoder(self):
+        """ Test encoder on some dummy examples """
+        desc_A1 = "Fictional character in The Simpsons"
+        desc_A2 = "Simpsons - fictional human"
+        desc_A3 = "Fictional character in The Flintstones"
+        desc_A4 = "Politician from the US"
+
+        A1_doc_vector = np.asarray([self._get_doc_embedding(self.nlp(desc_A1))])
+        A2_doc_vector = np.asarray([self._get_doc_embedding(self.nlp(desc_A2))])
+        A3_doc_vector = np.asarray([self._get_doc_embedding(self.nlp(desc_A3))])
+        A4_doc_vector = np.asarray([self._get_doc_embedding(self.nlp(desc_A4))])
+
+        loss_a1_a1, _ = get_cossim_loss(A1_doc_vector, A1_doc_vector)
+        loss_a1_a2, _ = get_cossim_loss(A1_doc_vector, A2_doc_vector)
+        loss_a1_a3, _ = get_cossim_loss(A1_doc_vector, A3_doc_vector)
+        loss_a1_a4, _ = get_cossim_loss(A1_doc_vector, A4_doc_vector)
+
+        print("sim doc A1 A1", loss_a1_a1)
+        print("sim doc A1 A2", loss_a1_a2)
+        print("sim doc A1 A3", loss_a1_a3)
+        print("sim doc A1 A4", loss_a1_a4)
+
+        A1_encoded = self.encoder(A1_doc_vector)
+        A2_encoded = self.encoder(A2_doc_vector)
+        A3_encoded = self.encoder(A3_doc_vector)
+        A4_encoded = self.encoder(A4_doc_vector)
+
+        loss_a1_a1, _ = get_cossim_loss(A1_encoded, A1_encoded)
+        loss_a1_a2, _ = get_cossim_loss(A1_encoded, A2_encoded)
+        loss_a1_a3, _ = get_cossim_loss(A1_encoded, A3_encoded)
+        loss_a1_a4, _ = get_cossim_loss(A1_encoded, A4_encoded)
+
+        print("sim encoded A1 A1", loss_a1_a1)
+        print("sim encoded A1 A2", loss_a1_a2)
+        print("sim encoded A1 A3", loss_a1_a3)
+        print("sim encoded A1 A4", loss_a1_a4)
diff --git a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
index 1f4b4b67e..d813238b7 100644
--- a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
+++ b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
@@ -93,7 +93,7 @@ if __name__ == "__main__":
         print("STEP 4: to_read_kb", datetime.datetime.now())
         my_vocab = Vocab()
         my_vocab.from_disk(VOCAB_DIR)
-        my_kb = KnowledgeBase(vocab=my_vocab)
+        my_kb = KnowledgeBase(vocab=my_vocab, entity_vector_length=64)  # TODO entity vectors
         my_kb.load_bulk(KB_FILE)
         print("kb entities:", my_kb.get_size_entities())
         print("kb aliases:", my_kb.get_size_aliases())
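Note that on the read path the constructor argument is effectively a placeholder: load_bulk reads entity_vector_length back from the dump's header (see STEP 0 in kb.pyx below) and overwrites it. A minimal sketch of that read path (the file locations are hypothetical):

    from spacy.kb import KnowledgeBase
    from spacy.vocab import Vocab

    my_vocab = Vocab()
    my_vocab.from_disk("path/to/vocab")  # hypothetical path
    # provisional value; load_bulk restores the real length from the header
    my_kb = KnowledgeBase(vocab=my_vocab, entity_vector_length=64)
    my_kb.load_bulk("path/to/kb_file")   # hypothetical path
    print("kb entities:", my_kb.get_size_entities())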
diff --git a/spacy/kb.pxd b/spacy/kb.pxd
index 494848e5e..9c5a73d59 100644
--- a/spacy/kb.pxd
+++ b/spacy/kb.pxd
@@ -12,6 +12,8 @@ from .typedefs cimport hash_t
 from .structs cimport EntryC, AliasC
 
 ctypedef vector[EntryC] entry_vec
 ctypedef vector[AliasC] alias_vec
+ctypedef vector[float] float_vec
+ctypedef vector[float_vec] float_matrix
 
 # Object used by the Entity Linker that summarizes one entity-alias candidate combination.
@@ -20,6 +22,7 @@ cdef class Candidate:
     cdef readonly KnowledgeBase kb
     cdef hash_t entity_hash
     cdef float entity_freq
+    cdef vector[float] entity_vector
     cdef hash_t alias_hash
     cdef float prior_prob
 
@@ -27,6 +30,7 @@ cdef class KnowledgeBase:
     cdef Pool mem
     cpdef readonly Vocab vocab
+    cdef int64_t entity_vector_length
 
     # This maps 64bit keys (hash of unique entity string)
     # to 64bit values (position of the _EntryC struct in the _entries vector).
@@ -59,7 +63,7 @@ cdef class KnowledgeBase:
     # model, that embeds different features of the entities into vectors. We'll
     # still want some per-entity features, like the Wikipedia text or entity
     # co-occurrence. Hopefully those vectors can be narrow, e.g. 64 dimensions.
-    cdef object _vectors_table
+    cdef float_matrix _vectors_table
 
     # It's very useful to track categorical features, at least for output, even
     # if they're not useful in the model itself. For instance, we should be
@@ -69,8 +73,15 @@ cdef class KnowledgeBase:
     cdef object _features_table
 
+    cdef inline int64_t c_add_vector(self, vector[float] entity_vector) nogil:
+        """Add an entity vector to the vectors table."""
+        cdef int64_t new_index = self._vectors_table.size()
+        self._vectors_table.push_back(entity_vector)
+        return new_index
+
+
     cdef inline int64_t c_add_entity(self, hash_t entity_hash, float prob,
-                                     int32_t* vector_rows, int feats_row) nogil:
+                                     int32_t vector_index, int feats_row) nogil:
         """Add an entry to the vector of entries.
         After calling this method, make sure to update also the _entry_index using the return value"""
         # This is what we'll map the entity hash key to. It's where the entry will sit
@@ -80,7 +91,7 @@ cdef class KnowledgeBase:
         # Avoid struct initializer to enable nogil, cf https://github.com/cython/cython/issues/1642
         cdef EntryC entry
         entry.entity_hash = entity_hash
-        entry.vector_rows = vector_rows
+        entry.vector_index = vector_index
         entry.feats_row = feats_row
         entry.prob = prob
 
@@ -113,7 +124,7 @@ cdef class KnowledgeBase:
         # Avoid struct initializer to enable nogil
         cdef EntryC entry
         entry.entity_hash = dummy_hash
-        entry.vector_rows = &dummy_value
+        entry.vector_index = dummy_value
         entry.feats_row = dummy_value
         entry.prob = dummy_value
 
@@ -131,15 +142,16 @@ cdef class KnowledgeBase:
         self._aliases_table.push_back(alias)
 
     cpdef load_bulk(self, loc)
-    cpdef set_entities(self, entity_list, prob_list, vector_list, feature_list)
+    cpdef set_entities(self, entity_list, prob_list, vector_list)
     cpdef set_aliases(self, alias_list, entities_list, probabilities_list)
 
 
 cdef class Writer:
     cdef FILE* _fp
 
-    cdef int write_header(self, int64_t nr_entries) except -1
-    cdef int write_entry(self, hash_t entry_hash, float entry_prob) except -1
+    cdef int write_header(self, int64_t nr_entries, int64_t entity_vector_length) except -1
+    cdef int write_vector_element(self, float element) except -1
+    cdef int write_entry(self, hash_t entry_hash, float entry_prob, int32_t vector_index) except -1
 
     cdef int write_alias_length(self, int64_t alias_length) except -1
     cdef int write_alias_header(self, hash_t alias_hash, int64_t candidate_length) except -1
@@ -150,8 +162,9 @@ cdef class Writer:
 cdef class Reader:
     cdef FILE* _fp
 
-    cdef int read_header(self, int64_t* nr_entries) except -1
-    cdef int read_entry(self, hash_t* entity_hash, float* prob) except -1
+    cdef int read_header(self, int64_t* nr_entries, int64_t* entity_vector_length) except -1
+    cdef int read_vector_element(self, float* element) except -1
+    cdef int read_entry(self, hash_t* entity_hash, float* prob, int32_t* vector_index) except -1
 
     cdef int read_alias_length(self, int64_t* alias_length) except -1
     cdef int read_alias_header(self, hash_t* alias_hash, int64_t* candidate_length) except -1
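The storage scheme above is a plain append-only table: c_add_vector pushes a vector onto _vectors_table and returns its row index, which the entry then records as vector_index. In plain Python terms (a sketch of the idea, not the actual Cython code; the entry values are made up):

    # _vectors_table: list of vectors; entries hold an int index into it
    vectors_table = []

    def add_vector(entity_vector):
        new_index = len(vectors_table)
        vectors_table.append(entity_vector)
        return new_index

    entry = {"entity_hash": 12345, "prob": 0.5,
             "vector_index": add_vector([0.1] * 64),
             "feats_row": -1}  # features table not implemented yet
    # retrieval during candidate generation:
    vector = vectors_table[entry["vector_index"]]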
diff --git a/spacy/kb.pyx b/spacy/kb.pyx
index d471130d0..790bb4992 100644
--- a/spacy/kb.pyx
+++ b/spacy/kb.pyx
@@ -26,10 +26,11 @@ from libcpp.vector cimport vector
 
 
 cdef class Candidate:
-    def __init__(self, KnowledgeBase kb, entity_hash, entity_freq, alias_hash, prior_prob):
+    def __init__(self, KnowledgeBase kb, entity_hash, entity_freq, entity_vector, alias_hash, prior_prob):
         self.kb = kb
         self.entity_hash = entity_hash
         self.entity_freq = entity_freq
+        self.entity_vector = entity_vector
         self.alias_hash = alias_hash
         self.prior_prob = prior_prob
 
@@ -57,19 +58,26 @@ cdef class Candidate:
     def entity_freq(self):
         return self.entity_freq
 
+    @property
+    def entity_vector(self):
+        return self.entity_vector
+
     @property
     def prior_prob(self):
         return self.prior_prob
 
 
 cdef class KnowledgeBase:
-    def __init__(self, Vocab vocab):
+
+    def __init__(self, Vocab vocab, entity_vector_length):
         self.vocab = vocab
         self.mem = Pool()
+        self.entity_vector_length = entity_vector_length
+
         self._entry_index = PreshMap()
         self._alias_index = PreshMap()
-        # TODO initialize self._entries and self._aliases_table ?
+        # Should we initialize self._entries and self._aliases_table to specific starting size ?
 
         self.vocab.strings.add("")
         self._create_empty_vectors(dummy_hash=self.vocab.strings[""])
 
@@ -89,10 +97,10 @@ cdef class KnowledgeBase:
     def get_alias_strings(self):
         return [self.vocab.strings[x] for x in self._alias_index]
 
-    def add_entity(self, unicode entity, float prob=0.5, vectors=None, features=None):
+    def add_entity(self, unicode entity, float prob, vector[float] entity_vector):
         """
-        Add an entity to the KB, optionally specifying its log probability based on corpus frequency
-        Return the hash of the entity ID/name at the end
+        Add an entity to the KB, specifying its log probability based on corpus frequency
+        and its entity vector. Return the hash of the entity ID/name at the end.
         """
         cdef hash_t entity_hash = self.vocab.strings.add(entity)
 
@@ -101,31 +109,41 @@ cdef class KnowledgeBase:
             user_warning(Warnings.W018.format(entity=entity))
             return
 
-        cdef int32_t dummy_value = 342
-        new_index = self.c_add_entity(entity_hash=entity_hash, prob=prob,
-                                      vector_rows=&dummy_value, feats_row=dummy_value)
-        self._entry_index[entity_hash] = new_index
+        if len(entity_vector) != self.entity_vector_length:
+            # TODO: proper error
+            raise ValueError("Entity vector length should have been", self.entity_vector_length)
 
-        # TODO self._vectors_table.get_pointer(vectors),
-        # self._features_table.get(features))
+        vector_index = self.c_add_vector(entity_vector=entity_vector)
+
+        new_index = self.c_add_entity(entity_hash=entity_hash,
+                                      prob=prob,
+                                      vector_index=vector_index,
+                                      feats_row=-1)  # Features table currently not implemented
+        self._entry_index[entity_hash] = new_index
 
         return entity_hash
 
-    cpdef set_entities(self, entity_list, prob_list, vector_list, feature_list):
+    cpdef set_entities(self, entity_list, prob_list, vector_list):
         nr_entities = len(entity_list)
         self._entry_index = PreshMap(nr_entities+1)
         self._entries = entry_vec(nr_entities+1)
 
         i = 0
         cdef EntryC entry
-        cdef int32_t dummy_value = 342
         while i < nr_entities:
-            # TODO features and vectors
-            entity_hash = self.vocab.strings.add(entity_list[i])
+            entity_vector = vector_list[i]
+            if len(entity_vector) != self.entity_vector_length:
+                # TODO: proper error
+                raise ValueError("Entity vector length should have been", self.entity_vector_length)
+
+            entity_hash = self.vocab.strings.add(entity_list[i])
             entry.entity_hash = entity_hash
             entry.prob = prob_list[i]
+
+            vector_index = self.c_add_vector(entity_vector=entity_vector)
+            entry.vector_index = vector_index
+
+            entry.feats_row = -1  # Features table currently not implemented
 
             self._entries[i+1] = entry
             self._entry_index[entity_hash] = i+1
 
@@ -186,7 +204,7 @@ cdef class KnowledgeBase:
 
         cdef hash_t alias_hash = self.vocab.strings.add(alias)
-        # Return if this alias was added before
+        # Check whether this alias was added before
         if alias_hash in self._alias_index:
             user_warning(Warnings.W017.format(alias=alias))
             return
 
@@ -208,9 +226,7 @@ cdef class KnowledgeBase:
 
         return alias_hash
 
-
     def get_candidates(self, unicode alias):
-        """ TODO: where to put this functionality ?"""
         cdef hash_t alias_hash = self.vocab.strings[alias]
         alias_index = self._alias_index.get(alias_hash)
         alias_entry = self._aliases_table[alias_index]
 
@@ -218,6 +234,7 @@ cdef class KnowledgeBase:
         return [Candidate(kb=self,
                           entity_hash=self._entries[entry_index].entity_hash,
                           entity_freq=self._entries[entry_index].prob,
+                          entity_vector=self._vectors_table[self._entries[entry_index].vector_index],
                           alias_hash=alias_hash,
                           prior_prob=prob)
                 for (entry_index, prob) in zip(alias_entry.entry_indices, alias_entry.probs)
 
@@ -226,16 +243,23 @@ cdef class KnowledgeBase:
 
     def dump(self, loc):
        cdef Writer writer = Writer(loc)
-        writer.write_header(self.get_size_entities())
+        writer.write_header(self.get_size_entities(), self.entity_vector_length)
+
+        # dumping the entity vectors in their original order
+        i = 0
+        for entity_vector in self._vectors_table:
+            for element in entity_vector:
+                writer.write_vector_element(element)
+            i = i+1
 
         # dumping the entry records in the order in which they are in the _entries vector.
         # index 0 is a dummy object not stored in the _entry_index and can be ignored.
         i = 1
         for entry_hash, entry_index in sorted(self._entry_index.items(), key=lambda x: x[1]):
             entry = self._entries[entry_index]
             assert entry.entity_hash == entry_hash
             assert entry_index == i
-            writer.write_entry(entry.entity_hash, entry.prob)
+            writer.write_entry(entry.entity_hash, entry.prob, entry.vector_index)
             i = i+1
 
         writer.write_alias_length(self.get_size_aliases())
 
@@ -262,31 +286,47 @@ cdef class KnowledgeBase:
         cdef hash_t alias_hash
         cdef int64_t entry_index
         cdef float prob
+        cdef int32_t vector_index
         cdef EntryC entry
         cdef AliasC alias
-        cdef int32_t dummy_value = 342
+        cdef float vector_element
 
         cdef Reader reader = Reader(loc)
 
-        # Step 1: load entities
-
+        # STEP 0: load header and initialize KB
         cdef int64_t nr_entities
-        reader.read_header(&nr_entities)
+        cdef int64_t entity_vector_length
+        reader.read_header(&nr_entities, &entity_vector_length)
+
+        self.entity_vector_length = entity_vector_length
         self._entry_index = PreshMap(nr_entities+1)
         self._entries = entry_vec(nr_entities+1)
+        self._vectors_table = float_matrix(nr_entities+1)
 
+        # STEP 1: load entity vectors
+        cdef int i = 0
+        cdef int j = 0
+        while i < nr_entities:
+            entity_vector = float_vec(entity_vector_length)
+            j = 0
+            while j < entity_vector_length:
+                reader.read_vector_element(&vector_element)
+                entity_vector[j] = vector_element
+                j = j+1
+            self._vectors_table[i] = entity_vector
+            i = i+1
+
+        # STEP 2: load entities
         # we assume that the entity data was written in sequence
         # index 0 is a dummy object not stored in the _entry_index and can be ignored.
-        # TODO: should we initialize the dummy objects ?
-        cdef int i = 1
+        i = 1
         while i <= nr_entities:
-            reader.read_entry(&entity_hash, &prob)
+            reader.read_entry(&entity_hash, &prob, &vector_index)
 
-            # TODO features and vectors
             entry.entity_hash = entity_hash
             entry.prob = prob
-            entry.vector_rows = &dummy_value
-            entry.feats_row = dummy_value
+            entry.vector_index = vector_index
+            entry.feats_row = -1  # Features table currently not implemented
 
             self._entries[i] = entry
             self._entry_index[entity_hash] = i
 
@@ -296,7 +336,8 @@ cdef class KnowledgeBase:
         # check that all entities were read in properly
         assert nr_entities == self.get_size_entities()
 
-        # Step 2: load aliases
+        # STEP 3: load aliases
+
         cdef int64_t nr_aliases
         reader.read_alias_length(&nr_aliases)
         self._alias_index = PreshMap(nr_aliases+1)
 
@@ -344,13 +385,18 @@ cdef class Writer:
         cdef size_t status = fclose(self._fp)
         assert status == 0
 
-    cdef int write_header(self, int64_t nr_entries) except -1:
+    cdef int write_header(self, int64_t nr_entries, int64_t entity_vector_length) except -1:
         self._write(&nr_entries, sizeof(nr_entries))
+        self._write(&entity_vector_length, sizeof(entity_vector_length))
 
-    cdef int write_entry(self, hash_t entry_hash, float entry_prob) except -1:
-        # TODO: feats_rows and vector rows
+    cdef int write_vector_element(self, float element) except -1:
+        self._write(&element, sizeof(element))
+
+    cdef int write_entry(self, hash_t entry_hash, float entry_prob, int32_t vector_index) except -1:
         self._write(&entry_hash, sizeof(entry_hash))
         self._write(&entry_prob, sizeof(entry_prob))
+        self._write(&vector_index, sizeof(vector_index))
+        # Features table currently not implemented and not written to file
 
     cdef int write_alias_length(self, int64_t alias_length) except -1:
         self._write(&alias_length, sizeof(alias_length))
 
@@ -381,14 +427,27 @@ cdef class Reader:
     def __dealloc__(self):
         fclose(self._fp)
 
-    cdef int read_header(self, int64_t* nr_entries) except -1:
+    cdef int read_header(self, int64_t* nr_entries, int64_t* entity_vector_length) except -1:
         status = self._read(nr_entries, sizeof(int64_t))
         if status < 1:
             if feof(self._fp):
                 return 0  # end of file
             raise IOError("error reading header from input file")
 
-    cdef int read_entry(self, hash_t* entity_hash, float* prob) except -1:
+        status = self._read(entity_vector_length, sizeof(int64_t))
+        if status < 1:
+            if feof(self._fp):
+                return 0  # end of file
+            raise IOError("error reading header from input file")
+
+    cdef int read_vector_element(self, float* element) except -1:
+        status = self._read(element, sizeof(float))
+        if status < 1:
+            if feof(self._fp):
+                return 0  # end of file
+            raise IOError("error reading entity vector from input file")
+
+    cdef int read_entry(self, hash_t* entity_hash, float* prob, int32_t* vector_index) except -1:
         status = self._read(entity_hash, sizeof(hash_t))
         if status < 1:
             if feof(self._fp):
 
@@ -401,6 +460,12 @@ cdef class Reader:
                 return 0  # end of file
             raise IOError("error reading entity prob from input file")
 
+        status = self._read(vector_index, sizeof(int32_t))
+        if status < 1:
+            if feof(self._fp):
+                return 0  # end of file
+            raise IOError("error reading entity vector from input file")
+
         if feof(self._fp):
             return 0
         else:
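With the Writer and Reader above, the dump is a flat binary file: an int64 entry count and int64 vector length, then nr_entities * entity_vector_length float32 vector elements, then one record per entry (hash, prob, vector index), followed by the alias section. A sketch of decoding everything up to the alias section with Python's struct module (assuming native byte order, a 64-bit unsigned hash_t, and a hypothetical file path):

    import struct

    def read_kb_dump(path):
        with open(path, "rb") as f:
            # header: int64 nr_entities, int64 entity_vector_length
            nr_entities, vector_length = struct.unpack("=qq", f.read(16))
            # vector section: nr_entities rows of float32 elements
            vectors = [struct.unpack("=%df" % vector_length, f.read(4 * vector_length))
                       for _ in range(nr_entities)]
            # entry section: uint64 entity_hash, float32 prob, int32 vector_index
            entries = [struct.unpack("=Qfi", f.read(16)) for _ in range(nr_entities)]
            # (the alias section follows; not parsed in this sketch)
        return vectors, entries

    vectors, entries = read_kb_dump("kb_file")  # hypothetical path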
diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx
index d0c83b56e..d9fbe59ff 100644
--- a/spacy/pipeline/pipes.pyx
+++ b/spacy/pipeline/pipes.pyx
@@ -3,7 +3,7 @@
 # coding: utf8
 from __future__ import unicode_literals
 
-cimport numpy as np
+import numpy as np
 
 import numpy
 import srsly
diff --git a/spacy/structs.pxd b/spacy/structs.pxd
index 69a1f4961..8de4d5f4c 100644
--- a/spacy/structs.pxd
+++ b/spacy/structs.pxd
@@ -84,16 +84,12 @@ cdef struct EntryC:
     # The hash of this entry's unique ID/name in the KB
     hash_t entity_hash
 
-    # Allows retrieval of one or more vectors.
-    # Each element of vector_rows should be an index into a vectors table.
-    # Every entry should have the same number of vectors, so we can avoid storing
-    # the number of vectors in each knowledge-base struct
-    int32_t* vector_rows
+    # Allows retrieval of the entity vector, as an index into a vectors table of the KB.
+    # Can be expanded later to refer to multiple rows (compositional model to reduce storage footprint).
+    int32_t vector_index
 
-    # Allows retrieval of a struct of non-vector features. We could make this a
-    # pointer, but we have 32 bits left over in the struct after prob, so we'd
-    # like this to only be 32 bits. We can also set this to -1, for the common
-    # case where there are no features.
+    # Allows retrieval of a struct of non-vector features.
+    # This is currently not implemented and set to -1 for the common case where there are no features.
     int32_t feats_row
 
     # log probability of entity, based on corpus frequency
diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py
index 61baece68..b44332df4 100644
--- a/spacy/tests/pipeline/test_entity_linker.py
+++ b/spacy/tests/pipeline/test_entity_linker.py
@@ -14,12 +14,12 @@ def nlp():
 
 def test_kb_valid_entities(nlp):
     """Test the valid construction of a KB with 3 entities and two aliases"""
-    mykb = KnowledgeBase(nlp.vocab)
+    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
 
     # adding entities
-    mykb.add_entity(entity=u'Q1', prob=0.9)
-    mykb.add_entity(entity=u'Q2')
-    mykb.add_entity(entity=u'Q3', prob=0.5)
+    mykb.add_entity(entity=u'Q1', prob=0.9, entity_vector=[1])
+    mykb.add_entity(entity=u'Q2', prob=0.5, entity_vector=[2])
+    mykb.add_entity(entity=u'Q3', prob=0.5, entity_vector=[3])
 
     # adding aliases
     mykb.add_alias(alias=u'douglas', entities=[u'Q2', u'Q3'], probabilities=[0.8, 0.2])
 
@@ -32,12 +32,12 @@ def test_kb_valid_entities(nlp):
 
 def test_kb_invalid_entities(nlp):
     """Test the invalid construction of a KB with an alias linked to a non-existing entity"""
-    mykb = KnowledgeBase(nlp.vocab)
+    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
 
     # adding entities
-    mykb.add_entity(entity=u'Q1', prob=0.9)
-    mykb.add_entity(entity=u'Q2', prob=0.2)
-    mykb.add_entity(entity=u'Q3', prob=0.5)
+    mykb.add_entity(entity=u'Q1', prob=0.9, entity_vector=[1])
+    mykb.add_entity(entity=u'Q2', prob=0.2, entity_vector=[2])
+    mykb.add_entity(entity=u'Q3', prob=0.5, entity_vector=[3])
 
     # adding aliases - should fail because one of the given IDs is not valid
     with pytest.raises(ValueError):
 
@@ -46,12 +46,12 @@ def test_kb_invalid_entities(nlp):
 
 def test_kb_invalid_probabilities(nlp):
     """Test the invalid construction of a KB with wrong prior probabilities"""
-    mykb = KnowledgeBase(nlp.vocab)
+    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
 
     # adding entities
-    mykb.add_entity(entity=u'Q1', prob=0.9)
-    mykb.add_entity(entity=u'Q2', prob=0.2)
-    mykb.add_entity(entity=u'Q3', prob=0.5)
+    mykb.add_entity(entity=u'Q1', prob=0.9, entity_vector=[1])
+    mykb.add_entity(entity=u'Q2', prob=0.2, entity_vector=[2])
+    mykb.add_entity(entity=u'Q3', prob=0.5, entity_vector=[3])
 
     # adding aliases - should fail because the sum of the probabilities exceeds 1
     with pytest.raises(ValueError):
 
@@ -60,26 +60,38 @@ def test_kb_invalid_probabilities(nlp):
 
 def test_kb_invalid_combination(nlp):
     """Test the invalid construction of a KB with non-matching entity and probability lists"""
-    mykb = KnowledgeBase(nlp.vocab)
+    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
 
     # adding entities
-    mykb.add_entity(entity=u'Q1', prob=0.9)
-    mykb.add_entity(entity=u'Q2', prob=0.2)
-    mykb.add_entity(entity=u'Q3', prob=0.5)
+    mykb.add_entity(entity=u'Q1', prob=0.9, entity_vector=[1])
+    mykb.add_entity(entity=u'Q2', prob=0.2, entity_vector=[2])
+    mykb.add_entity(entity=u'Q3', prob=0.5, entity_vector=[3])
 
     # adding aliases - should fail because the entities and probabilities vectors are not of equal length
     with pytest.raises(ValueError):
         mykb.add_alias(alias=u'douglas', entities=[u'Q2', u'Q3'], probabilities=[0.3, 0.4, 0.1])
 
 
-def test_candidate_generation(nlp):
-    """Test correct candidate generation"""
-    mykb = KnowledgeBase(nlp.vocab)
+def test_kb_invalid_entity_vector(nlp):
+    """Test the invalid construction of a KB with non-matching entity vector lengths"""
+    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
 
     # adding entities
-    mykb.add_entity(entity=u'Q1', prob=0.9)
-    mykb.add_entity(entity=u'Q2', prob=0.2)
-    mykb.add_entity(entity=u'Q3', prob=0.5)
+    mykb.add_entity(entity=u'Q1', prob=0.9, entity_vector=[1, 2, 3])
+
+    # this should fail because the kb's expected entity vector length is 3
+    with pytest.raises(ValueError):
+        mykb.add_entity(entity=u'Q2', prob=0.2, entity_vector=[2])
+
+
+def test_candidate_generation(nlp):
+    """Test correct candidate generation"""
+    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+
+    # adding entities
+    mykb.add_entity(entity=u'Q1', prob=0.9, entity_vector=[1])
+    mykb.add_entity(entity=u'Q2', prob=0.2, entity_vector=[2])
+    mykb.add_entity(entity=u'Q3', prob=0.5, entity_vector=[3])
 
     # adding aliases
     mykb.add_alias(alias=u'douglas', entities=[u'Q2', u'Q3'], probabilities=[0.8, 0.2])
diff --git a/spacy/tests/serialize/test_serialize_kb.py b/spacy/tests/serialize/test_serialize_kb.py
index 7b1380623..7a8022890 100644
--- a/spacy/tests/serialize/test_serialize_kb.py
+++ b/spacy/tests/serialize/test_serialize_kb.py
@@ -20,7 +20,7 @@ def test_serialize_kb_disk(en_vocab):
     print(file_path, type(file_path))
     kb1.dump(str(file_path))
 
-    kb2 = KnowledgeBase(vocab=en_vocab)
+    kb2 = KnowledgeBase(vocab=en_vocab, entity_vector_length=3)
     kb2.load_bulk(str(file_path))
 
     # final assertions
 
@@ -28,12 +28,13 @@ def test_serialize_kb_disk(en_vocab):
 
 def _get_dummy_kb(vocab):
-    kb = KnowledgeBase(vocab=vocab)
+    kb = KnowledgeBase(vocab=vocab, entity_vector_length=3)
+
+    kb.add_entity(entity="Q53", prob=0.33, entity_vector=[0, 5, 3])
+    kb.add_entity(entity="Q17", prob=0.2, entity_vector=[7, 1, 0])
+    kb.add_entity(entity="Q007", prob=0.7, entity_vector=[0, 0, 7])
+    kb.add_entity(entity="Q44", prob=0.4, entity_vector=[4, 4, 4])
 
-    kb.add_entity(entity="Q53", prob=0.33)
-    kb.add_entity(entity="Q17", prob=0.2)
-    kb.add_entity(entity="Q007", prob=0.7)
-    kb.add_entity(entity="Q44", prob=0.4)
     kb.add_alias(alias="double07", entities=["Q17", "Q007"], probabilities=[0.1, 0.9])
     kb.add_alias(alias="guy", entities=["Q53", "Q007", "Q17", "Q44"], probabilities=[0.3, 0.3, 0.2, 0.1])
     kb.add_alias(alias="random", entities=["Q007"], probabilities=[1.0])
 
@@ -62,10 +63,12 @@ def _check_kb(kb):
 
     assert candidates[0].entity_ == "Q007"
     assert 0.6999 < candidates[0].entity_freq < 0.701
+    assert candidates[0].entity_vector == [0, 0, 7]
     assert candidates[0].alias_ == "double07"
     assert 0.899 < candidates[0].prior_prob < 0.901
 
     assert candidates[1].entity_ == "Q17"
     assert 0.199 < candidates[1].entity_freq < 0.201
+    assert candidates[1].entity_vector == [7, 1, 0]
     assert candidates[1].alias_ == "double07"
     assert 0.099 < candidates[1].prior_prob < 0.101
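The round trip these tests exercise looks roughly like this in user code (a minimal sketch; the file path is hypothetical, and the vector comes back as a plain list of floats on the Candidate):

    from spacy.kb import KnowledgeBase
    from spacy.vocab import Vocab

    vocab = Vocab()
    kb1 = KnowledgeBase(vocab=vocab, entity_vector_length=3)
    kb1.add_entity(entity="Q007", prob=0.7, entity_vector=[0, 0, 7])
    kb1.add_alias(alias="double07", entities=["Q007"], probabilities=[1.0])
    kb1.dump("kb_file")  # hypothetical path

    kb2 = KnowledgeBase(vocab=vocab, entity_vector_length=3)  # placeholder length
    kb2.load_bulk("kb_file")
    candidate = kb2.get_candidates("double07")[0]
    assert candidate.entity_vector == [0, 0, 7]  # vectors survive the round trip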