diff --git a/examples/pipeline/wikidata_entity_linking.py b/examples/pipeline/wikidata_entity_linking.py index 674c6166c..8628c54a9 100644 --- a/examples/pipeline/wikidata_entity_linking.py +++ b/examples/pipeline/wikidata_entity_linking.py @@ -432,6 +432,7 @@ if __name__ == "__main__": kb1.add_entity(entity="Q007", prob=0.7) kb1.add_entity(entity="Q44", prob=0.4) print("kb1 size:", len(kb1), kb1.get_size_entities(), kb1.get_size_aliases()) + print("dumping kb1") kb1.dump(KB_FILE) @@ -439,7 +440,10 @@ if __name__ == "__main__": nlp3 = spacy.load('en_core_web_sm') kb3 = KnowledgeBase(vocab=nlp3.vocab) - kb3.load_bulk(7, KB_FILE) + + kb3.load_bulk(KB_FILE) + + print("loading kb3") print("kb3 size:", len(kb3), kb3.get_size_entities(), kb3.get_size_aliases()) # STEP 5 : actually use the EL functionality diff --git a/spacy/kb.pxd b/spacy/kb.pxd index 817b7ff25..9c393e5f2 100644 --- a/spacy/kb.pxd +++ b/spacy/kb.pxd @@ -129,16 +129,20 @@ cdef class KnowledgeBase: self._entries.push_back(entry) self._aliases_table.push_back(alias) - cpdef load_bulk(self, int nr_entities, loc) + cpdef load_bulk(self, loc) cdef class Writer: cdef FILE* _fp - cdef int write(self, int64_t entry_id, hash_t entity_hash, float prob) except -1 + cdef int write_header(self, int64_t nr_entries) except -1 + cdef int write_entry(self, int64_t entry_id, hash_t entry_hash, float entry_prob) except -1 + cdef int _write(self, void* value, size_t size) except -1 cdef class Reader: cdef FILE* _fp - cdef int read(self, Pool mem, int64_t* entry_id, hash_t* entity_hash, float* prob) except -1 + cdef int read_header(self, int64_t* nr_entries) except -1 + cdef int read_entry(self, int64_t* entry_id, hash_t* entity_hash, float* prob) except -1 + cdef int _read(self, void* value, size_t size) except -1 diff --git a/spacy/kb.pyx b/spacy/kb.pyx index c967654d3..21c6d9049 100644 --- a/spacy/kb.pyx +++ b/spacy/kb.pyx @@ -64,6 +64,8 @@ cdef class KnowledgeBase: self._entry_index = PreshMap() self._alias_index = PreshMap() + # TODO initialize self._entries and self._aliases_table ? + self.vocab.strings.add("") self._create_empty_vectors(dummy_hash=self.vocab.strings[""]) @@ -162,26 +164,21 @@ cdef class KnowledgeBase: def dump(self, loc): cdef Writer writer = Writer(loc) + writer.write_header(self.get_size_entities()) # dumping the entry records in the order in which they are in the _entries vector. # index 0 is a dummy object not stored in the _entry_index and can be ignored. i = 1 for entry_hash, entry_index in sorted(self._entry_index.items(), key=lambda x: x[1]): entry = self._entries[entry_index] - print("dumping") - print("index", entry_index) - print("hash", entry.entity_hash) assert entry.entity_hash == entry_hash assert entry_index == i - print("prob", entry.prob) - print("") - writer.write(entry_index, entry.entity_hash, entry.prob) + writer.write_entry(entry_index, entry.entity_hash, entry.prob) i = i+1 writer.close() - cpdef load_bulk(self, int nr_entities, loc): - # TODO: nr_entities from header in file (Reader constructor) + cpdef load_bulk(self, loc): cdef int64_t entry_id cdef hash_t entity_hash cdef float prob @@ -189,7 +186,8 @@ cdef class KnowledgeBase: cdef int32_t dummy_value = 342 cdef Reader reader = Reader(loc) - to_read = self.get_size_entities() + cdef int64_t nr_entities + reader.read_header(&nr_entities) self._entry_index = PreshMap(nr_entities+1) self._entries = entry_vec(nr_entities+1) @@ -198,23 +196,15 @@ cdef class KnowledgeBase: # index 0 is a dummy object not stored in the _entry_index and can be ignored. # TODO: should we initialize the dummy objects ? cdef int i = 1 - while reader.read(self.mem, &entry_id, &entity_hash, &prob) and i <= nr_entities: + while reader.read_entry(&entry_id, &entity_hash, &prob) and i <= nr_entities: assert i == entry_id + # TODO features and vectors entry.entity_hash = entity_hash entry.prob = prob - - # TODO features and vectors entry.vector_rows = &dummy_value entry.feats_row = dummy_value - print("bulk loading") - print("i", i) - print("entryID", entry_id) - print("hash", entry.entity_hash) - print("prob", entry.prob) - print("") - self._entries[i] = entry self._entry_index[entity_hash] = i @@ -234,16 +224,18 @@ cdef class Writer: cdef size_t status = fclose(self._fp) assert status == 0 - cdef int write(self, int64_t entry_id, hash_t entry_hash, float entry_prob) except -1: + cdef int write_header(self, int64_t nr_entries) except -1: + self._write(&nr_entries, sizeof(nr_entries)) + + cdef int write_entry(self, int64_t entry_id, hash_t entry_hash, float entry_prob) except -1: # TODO: feats_rows and vector rows - _write(&entry_id, sizeof(entry_id), self._fp) - _write(&entry_hash, sizeof(entry_hash), self._fp) - _write(&entry_prob, sizeof(entry_prob), self._fp) + self._write(&entry_id, sizeof(entry_id)) + self._write(&entry_hash, sizeof(entry_hash)) + self._write(&entry_prob, sizeof(entry_prob)) - -cdef int _write(void* value, size_t size, FILE* fp) except -1: - status = fwrite(value, size, 1, fp) - assert status == 1, status + cdef int _write(self, void* value, size_t size) except -1: + status = fwrite(value, size, 1, self._fp) + assert status == 1, status cdef class Reader: @@ -259,20 +251,27 @@ cdef class Reader: def __dealloc__(self): fclose(self._fp) - cdef int read(self, Pool mem, int64_t* entry_id, hash_t* entity_hash, float* prob) except -1: - status = fread(entry_id, sizeof(int64_t), 1, self._fp) + cdef int read_header(self, int64_t* nr_entries) except -1: + status = self._read(nr_entries, sizeof(int64_t)) + if status < 1: + if feof(self._fp): + return 0 # end of file + raise IOError("error reading header from input file") + + cdef int read_entry(self, int64_t* entry_id, hash_t* entity_hash, float* prob) except -1: + status = self._read(entry_id, sizeof(int64_t)) if status < 1: if feof(self._fp): return 0 # end of file raise IOError("error reading entry ID from input file") - status = fread(entity_hash, sizeof(hash_t), 1, self._fp) + status = self._read(entity_hash, sizeof(hash_t)) if status < 1: if feof(self._fp): return 0 # end of file raise IOError("error reading entity hash from input file") - status = fread(prob, sizeof(float), 1, self._fp) + status = self._read(prob, sizeof(float)) if status < 1: if feof(self._fp): return 0 # end of file @@ -283,4 +282,8 @@ cdef class Reader: else: return 1 + cdef int _read(self, void* value, size_t size) except -1: + status = fread(value, size, 1, self._fp) + return status +