This commit is contained in:
Sofie Van Landeghem 2020-09-22 21:53:06 +02:00 committed by GitHub
parent d53c84b6d6
commit e0e793be4d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 49 additions and 22 deletions

View File

@ -140,7 +140,6 @@ cdef class KnowledgeBase:
self._entries.push_back(entry) self._entries.push_back(entry)
self._aliases_table.push_back(alias) self._aliases_table.push_back(alias)
cpdef from_disk(self, loc)
cpdef set_entities(self, entity_list, freq_list, vector_list) cpdef set_entities(self, entity_list, freq_list, vector_list)

View File

@ -9,7 +9,8 @@ from libcpp.vector cimport vector
from pathlib import Path from pathlib import Path
import warnings import warnings
from os import path
from spacy import util
from .typedefs cimport hash_t from .typedefs cimport hash_t
from .errors import Errors, Warnings from .errors import Errors, Warnings
@ -319,8 +320,14 @@ cdef class KnowledgeBase:
return 0.0 return 0.0
def to_disk(self, loc): def to_disk(self, path):
cdef Writer writer = Writer(loc) path = util.ensure_path(path)
if path.is_dir():
raise ValueError(Errors.E928.format(loc=path))
if not path.parent.exists():
path.parent.mkdir(parents=True)
cdef Writer writer = Writer(path)
writer.write_header(self.get_size_entities(), self.entity_vector_length) writer.write_header(self.get_size_entities(), self.entity_vector_length)
# dumping the entity vectors in their original order # dumping the entity vectors in their original order
@ -359,7 +366,13 @@ cdef class KnowledgeBase:
writer.close() writer.close()
cpdef from_disk(self, loc): def from_disk(self, path):
path = util.ensure_path(path)
if path.is_dir():
raise ValueError(Errors.E928.format(loc=path))
if not path.exists():
raise ValueError(Errors.E929.format(loc=path))
cdef hash_t entity_hash cdef hash_t entity_hash
cdef hash_t alias_hash cdef hash_t alias_hash
cdef int64_t entry_index cdef int64_t entry_index
@ -369,7 +382,7 @@ cdef class KnowledgeBase:
cdef AliasC alias cdef AliasC alias
cdef float vector_element cdef float vector_element
cdef Reader reader = Reader(loc) cdef Reader reader = Reader(path)
# STEP 0: load header and initialize KB # STEP 0: load header and initialize KB
cdef int64_t nr_entities cdef int64_t nr_entities
@ -450,16 +463,13 @@ cdef class KnowledgeBase:
cdef class Writer: cdef class Writer:
def __init__(self, object loc): def __init__(self, path):
if isinstance(loc, Path): assert isinstance(path, Path)
loc = bytes(loc) content = bytes(path)
if path.exists(loc): cdef bytes bytes_loc = content.encode('utf8') if type(content) == unicode else content
if path.isdir(loc):
raise ValueError(Errors.E928.format(loc=loc))
cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
self._fp = fopen(<char*>bytes_loc, 'wb') self._fp = fopen(<char*>bytes_loc, 'wb')
if not self._fp: if not self._fp:
raise IOError(Errors.E146.format(path=loc)) raise IOError(Errors.E146.format(path=path))
fseek(self._fp, 0, 0) fseek(self._fp, 0, 0)
def close(self): def close(self):
@ -496,14 +506,9 @@ cdef class Writer:
cdef class Reader: cdef class Reader:
def __init__(self, object loc): def __init__(self, path):
if isinstance(loc, Path): content = bytes(path)
loc = bytes(loc) cdef bytes bytes_loc = content.encode('utf8') if type(content) == unicode else content
if not path.exists(loc):
raise ValueError(Errors.E929.format(loc=loc))
if path.isdir(loc):
raise ValueError(Errors.E928.format(loc=loc))
cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
self._fp = fopen(<char*>bytes_loc, 'rb') self._fp = fopen(<char*>bytes_loc, 'rb')
if not self._fp: if not self._fp:
PyErr_SetFromErrno(IOError) PyErr_SetFromErrno(IOError)

View File

@ -144,6 +144,29 @@ def test_kb_empty(nlp):
entity_linker.begin_training(lambda: []) entity_linker.begin_training(lambda: [])
def test_kb_serialize(nlp):
"""Test serialization of the KB"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
with make_tempdir() as d:
# normal read-write behaviour
mykb.to_disk(d / "kb")
mykb.from_disk(d / "kb")
mykb.to_disk(d / "kb.file")
mykb.from_disk(d / "kb.file")
mykb.to_disk(d / "new" / "kb")
mykb.from_disk(d / "new" / "kb")
# allow overwriting an existing file
mykb.to_disk(d / "kb.file")
with pytest.raises(ValueError):
# can not write to a directory
mykb.to_disk(d)
with pytest.raises(ValueError):
# can not read from a directory
mykb.from_disk(d)
with pytest.raises(ValueError):
# can not read from an unknown file
mykb.from_disk(d / "unknown" / "kb")
def test_candidate_generation(nlp): def test_candidate_generation(nlp):
"""Test correct candidate generation""" """Test correct candidate generation"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)