From e0e793be4d8146768e722c23d16cf7c5b170155e Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Tue, 22 Sep 2020 21:53:06 +0200 Subject: [PATCH] fix KB IO (#6118) --- spacy/kb.pxd | 1 - spacy/kb.pyx | 47 ++++++++++++---------- spacy/tests/pipeline/test_entity_linker.py | 23 +++++++++++ 3 files changed, 49 insertions(+), 22 deletions(-) diff --git a/spacy/kb.pxd b/spacy/kb.pxd index 695693666..4a71b26a2 100644 --- a/spacy/kb.pxd +++ b/spacy/kb.pxd @@ -140,7 +140,6 @@ cdef class KnowledgeBase: self._entries.push_back(entry) self._aliases_table.push_back(alias) - cpdef from_disk(self, loc) cpdef set_entities(self, entity_list, freq_list, vector_list) diff --git a/spacy/kb.pyx b/spacy/kb.pyx index b24ed3a20..ff5382c24 100644 --- a/spacy/kb.pyx +++ b/spacy/kb.pyx @@ -9,7 +9,8 @@ from libcpp.vector cimport vector from pathlib import Path import warnings -from os import path + +from spacy import util from .typedefs cimport hash_t from .errors import Errors, Warnings @@ -319,8 +320,14 @@ cdef class KnowledgeBase: return 0.0 - def to_disk(self, loc): - cdef Writer writer = Writer(loc) + def to_disk(self, path): + path = util.ensure_path(path) + if path.is_dir(): + raise ValueError(Errors.E928.format(loc=path)) + if not path.parent.exists(): + path.parent.mkdir(parents=True) + + cdef Writer writer = Writer(path) writer.write_header(self.get_size_entities(), self.entity_vector_length) # dumping the entity vectors in their original order @@ -359,7 +366,13 @@ cdef class KnowledgeBase: writer.close() - cpdef from_disk(self, loc): + def from_disk(self, path): + path = util.ensure_path(path) + if path.is_dir(): + raise ValueError(Errors.E928.format(loc=path)) + if not path.exists(): + raise ValueError(Errors.E929.format(loc=path)) + cdef hash_t entity_hash cdef hash_t alias_hash cdef int64_t entry_index @@ -369,7 +382,7 @@ cdef class KnowledgeBase: cdef AliasC alias cdef float vector_element - cdef Reader reader = Reader(loc) + cdef Reader reader = Reader(path) # STEP 0: load header and initialize KB cdef int64_t nr_entities @@ -450,16 +463,13 @@ cdef class KnowledgeBase: cdef class Writer: - def __init__(self, object loc): - if isinstance(loc, Path): - loc = bytes(loc) - if path.exists(loc): - if path.isdir(loc): - raise ValueError(Errors.E928.format(loc=loc)) - cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc + def __init__(self, path): + assert isinstance(path, Path) + content = bytes(path) + cdef bytes bytes_loc = content.encode('utf8') if type(content) == unicode else content self._fp = fopen(bytes_loc, 'wb') if not self._fp: - raise IOError(Errors.E146.format(path=loc)) + raise IOError(Errors.E146.format(path=path)) fseek(self._fp, 0, 0) def close(self): @@ -496,14 +506,9 @@ cdef class Writer: cdef class Reader: - def __init__(self, object loc): - if isinstance(loc, Path): - loc = bytes(loc) - if not path.exists(loc): - raise ValueError(Errors.E929.format(loc=loc)) - if path.isdir(loc): - raise ValueError(Errors.E928.format(loc=loc)) - cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc + def __init__(self, path): + content = bytes(path) + cdef bytes bytes_loc = content.encode('utf8') if type(content) == unicode else content self._fp = fopen(bytes_loc, 'rb') if not self._fp: PyErr_SetFromErrno(IOError) diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py index c43d2c58e..88e0646b3 100644 --- a/spacy/tests/pipeline/test_entity_linker.py +++ b/spacy/tests/pipeline/test_entity_linker.py @@ -144,6 +144,29 @@ def test_kb_empty(nlp): entity_linker.begin_training(lambda: []) +def test_kb_serialize(nlp): + """Test serialization of the KB""" + mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) + with make_tempdir() as d: + # normal read-write behaviour + mykb.to_disk(d / "kb") + mykb.from_disk(d / "kb") + mykb.to_disk(d / "kb.file") + mykb.from_disk(d / "kb.file") + mykb.to_disk(d / "new" / "kb") + mykb.from_disk(d / "new" / "kb") + # allow overwriting an existing file + mykb.to_disk(d / "kb.file") + with pytest.raises(ValueError): + # can not write to a directory + mykb.to_disk(d) + with pytest.raises(ValueError): + # can not read from a directory + mykb.from_disk(d) + with pytest.raises(ValueError): + # can not read from an unknown file + mykb.from_disk(d / "unknown" / "kb") + def test_candidate_generation(nlp): """Test correct candidate generation""" mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)