mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
fix KB IO (#6118)
This commit is contained in:
parent
d53c84b6d6
commit
e0e793be4d
|
@ -140,7 +140,6 @@ cdef class KnowledgeBase:
|
|||
self._entries.push_back(entry)
|
||||
self._aliases_table.push_back(alias)
|
||||
|
||||
cpdef from_disk(self, loc)
|
||||
cpdef set_entities(self, entity_list, freq_list, vector_list)
|
||||
|
||||
|
||||
|
|
47
spacy/kb.pyx
47
spacy/kb.pyx
|
@ -9,7 +9,8 @@ from libcpp.vector cimport vector
|
|||
|
||||
from pathlib import Path
|
||||
import warnings
|
||||
from os import path
|
||||
|
||||
from spacy import util
|
||||
|
||||
from .typedefs cimport hash_t
|
||||
from .errors import Errors, Warnings
|
||||
|
@ -319,8 +320,14 @@ cdef class KnowledgeBase:
|
|||
return 0.0
|
||||
|
||||
|
||||
def to_disk(self, loc):
|
||||
cdef Writer writer = Writer(loc)
|
||||
def to_disk(self, path):
|
||||
path = util.ensure_path(path)
|
||||
if path.is_dir():
|
||||
raise ValueError(Errors.E928.format(loc=path))
|
||||
if not path.parent.exists():
|
||||
path.parent.mkdir(parents=True)
|
||||
|
||||
cdef Writer writer = Writer(path)
|
||||
writer.write_header(self.get_size_entities(), self.entity_vector_length)
|
||||
|
||||
# dumping the entity vectors in their original order
|
||||
|
@ -359,7 +366,13 @@ cdef class KnowledgeBase:
|
|||
|
||||
writer.close()
|
||||
|
||||
cpdef from_disk(self, loc):
|
||||
def from_disk(self, path):
|
||||
path = util.ensure_path(path)
|
||||
if path.is_dir():
|
||||
raise ValueError(Errors.E928.format(loc=path))
|
||||
if not path.exists():
|
||||
raise ValueError(Errors.E929.format(loc=path))
|
||||
|
||||
cdef hash_t entity_hash
|
||||
cdef hash_t alias_hash
|
||||
cdef int64_t entry_index
|
||||
|
@ -369,7 +382,7 @@ cdef class KnowledgeBase:
|
|||
cdef AliasC alias
|
||||
cdef float vector_element
|
||||
|
||||
cdef Reader reader = Reader(loc)
|
||||
cdef Reader reader = Reader(path)
|
||||
|
||||
# STEP 0: load header and initialize KB
|
||||
cdef int64_t nr_entities
|
||||
|
@ -450,16 +463,13 @@ cdef class KnowledgeBase:
|
|||
|
||||
|
||||
cdef class Writer:
|
||||
def __init__(self, object loc):
|
||||
if isinstance(loc, Path):
|
||||
loc = bytes(loc)
|
||||
if path.exists(loc):
|
||||
if path.isdir(loc):
|
||||
raise ValueError(Errors.E928.format(loc=loc))
|
||||
cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
|
||||
def __init__(self, path):
|
||||
assert isinstance(path, Path)
|
||||
content = bytes(path)
|
||||
cdef bytes bytes_loc = content.encode('utf8') if type(content) == unicode else content
|
||||
self._fp = fopen(<char*>bytes_loc, 'wb')
|
||||
if not self._fp:
|
||||
raise IOError(Errors.E146.format(path=loc))
|
||||
raise IOError(Errors.E146.format(path=path))
|
||||
fseek(self._fp, 0, 0)
|
||||
|
||||
def close(self):
|
||||
|
@ -496,14 +506,9 @@ cdef class Writer:
|
|||
|
||||
|
||||
cdef class Reader:
|
||||
def __init__(self, object loc):
|
||||
if isinstance(loc, Path):
|
||||
loc = bytes(loc)
|
||||
if not path.exists(loc):
|
||||
raise ValueError(Errors.E929.format(loc=loc))
|
||||
if path.isdir(loc):
|
||||
raise ValueError(Errors.E928.format(loc=loc))
|
||||
cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
|
||||
def __init__(self, path):
|
||||
content = bytes(path)
|
||||
cdef bytes bytes_loc = content.encode('utf8') if type(content) == unicode else content
|
||||
self._fp = fopen(<char*>bytes_loc, 'rb')
|
||||
if not self._fp:
|
||||
PyErr_SetFromErrno(IOError)
|
||||
|
|
|
@ -144,6 +144,29 @@ def test_kb_empty(nlp):
|
|||
entity_linker.begin_training(lambda: [])
|
||||
|
||||
|
||||
def test_kb_serialize(nlp):
|
||||
"""Test serialization of the KB"""
|
||||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||
with make_tempdir() as d:
|
||||
# normal read-write behaviour
|
||||
mykb.to_disk(d / "kb")
|
||||
mykb.from_disk(d / "kb")
|
||||
mykb.to_disk(d / "kb.file")
|
||||
mykb.from_disk(d / "kb.file")
|
||||
mykb.to_disk(d / "new" / "kb")
|
||||
mykb.from_disk(d / "new" / "kb")
|
||||
# allow overwriting an existing file
|
||||
mykb.to_disk(d / "kb.file")
|
||||
with pytest.raises(ValueError):
|
||||
# can not write to a directory
|
||||
mykb.to_disk(d)
|
||||
with pytest.raises(ValueError):
|
||||
# can not read from a directory
|
||||
mykb.from_disk(d)
|
||||
with pytest.raises(ValueError):
|
||||
# can not read from an unknown file
|
||||
mykb.from_disk(d / "unknown" / "kb")
|
||||
|
||||
def test_candidate_generation(nlp):
|
||||
"""Test correct candidate generation"""
|
||||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||
|
|
Loading…
Reference in New Issue
Block a user