mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
fix KB IO (#6118)
This commit is contained in:
parent
d53c84b6d6
commit
e0e793be4d
|
@ -140,7 +140,6 @@ cdef class KnowledgeBase:
|
||||||
self._entries.push_back(entry)
|
self._entries.push_back(entry)
|
||||||
self._aliases_table.push_back(alias)
|
self._aliases_table.push_back(alias)
|
||||||
|
|
||||||
cpdef from_disk(self, loc)
|
|
||||||
cpdef set_entities(self, entity_list, freq_list, vector_list)
|
cpdef set_entities(self, entity_list, freq_list, vector_list)
|
||||||
|
|
||||||
|
|
||||||
|
|
47
spacy/kb.pyx
47
spacy/kb.pyx
|
@ -9,7 +9,8 @@ from libcpp.vector cimport vector
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import warnings
|
import warnings
|
||||||
from os import path
|
|
||||||
|
from spacy import util
|
||||||
|
|
||||||
from .typedefs cimport hash_t
|
from .typedefs cimport hash_t
|
||||||
from .errors import Errors, Warnings
|
from .errors import Errors, Warnings
|
||||||
|
@ -319,8 +320,14 @@ cdef class KnowledgeBase:
|
||||||
return 0.0
|
return 0.0
|
||||||
|
|
||||||
|
|
||||||
def to_disk(self, loc):
|
def to_disk(self, path):
|
||||||
cdef Writer writer = Writer(loc)
|
path = util.ensure_path(path)
|
||||||
|
if path.is_dir():
|
||||||
|
raise ValueError(Errors.E928.format(loc=path))
|
||||||
|
if not path.parent.exists():
|
||||||
|
path.parent.mkdir(parents=True)
|
||||||
|
|
||||||
|
cdef Writer writer = Writer(path)
|
||||||
writer.write_header(self.get_size_entities(), self.entity_vector_length)
|
writer.write_header(self.get_size_entities(), self.entity_vector_length)
|
||||||
|
|
||||||
# dumping the entity vectors in their original order
|
# dumping the entity vectors in their original order
|
||||||
|
@ -359,7 +366,13 @@ cdef class KnowledgeBase:
|
||||||
|
|
||||||
writer.close()
|
writer.close()
|
||||||
|
|
||||||
cpdef from_disk(self, loc):
|
def from_disk(self, path):
|
||||||
|
path = util.ensure_path(path)
|
||||||
|
if path.is_dir():
|
||||||
|
raise ValueError(Errors.E928.format(loc=path))
|
||||||
|
if not path.exists():
|
||||||
|
raise ValueError(Errors.E929.format(loc=path))
|
||||||
|
|
||||||
cdef hash_t entity_hash
|
cdef hash_t entity_hash
|
||||||
cdef hash_t alias_hash
|
cdef hash_t alias_hash
|
||||||
cdef int64_t entry_index
|
cdef int64_t entry_index
|
||||||
|
@ -369,7 +382,7 @@ cdef class KnowledgeBase:
|
||||||
cdef AliasC alias
|
cdef AliasC alias
|
||||||
cdef float vector_element
|
cdef float vector_element
|
||||||
|
|
||||||
cdef Reader reader = Reader(loc)
|
cdef Reader reader = Reader(path)
|
||||||
|
|
||||||
# STEP 0: load header and initialize KB
|
# STEP 0: load header and initialize KB
|
||||||
cdef int64_t nr_entities
|
cdef int64_t nr_entities
|
||||||
|
@ -450,16 +463,13 @@ cdef class KnowledgeBase:
|
||||||
|
|
||||||
|
|
||||||
cdef class Writer:
|
cdef class Writer:
|
||||||
def __init__(self, object loc):
|
def __init__(self, path):
|
||||||
if isinstance(loc, Path):
|
assert isinstance(path, Path)
|
||||||
loc = bytes(loc)
|
content = bytes(path)
|
||||||
if path.exists(loc):
|
cdef bytes bytes_loc = content.encode('utf8') if type(content) == unicode else content
|
||||||
if path.isdir(loc):
|
|
||||||
raise ValueError(Errors.E928.format(loc=loc))
|
|
||||||
cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
|
|
||||||
self._fp = fopen(<char*>bytes_loc, 'wb')
|
self._fp = fopen(<char*>bytes_loc, 'wb')
|
||||||
if not self._fp:
|
if not self._fp:
|
||||||
raise IOError(Errors.E146.format(path=loc))
|
raise IOError(Errors.E146.format(path=path))
|
||||||
fseek(self._fp, 0, 0)
|
fseek(self._fp, 0, 0)
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
|
@ -496,14 +506,9 @@ cdef class Writer:
|
||||||
|
|
||||||
|
|
||||||
cdef class Reader:
|
cdef class Reader:
|
||||||
def __init__(self, object loc):
|
def __init__(self, path):
|
||||||
if isinstance(loc, Path):
|
content = bytes(path)
|
||||||
loc = bytes(loc)
|
cdef bytes bytes_loc = content.encode('utf8') if type(content) == unicode else content
|
||||||
if not path.exists(loc):
|
|
||||||
raise ValueError(Errors.E929.format(loc=loc))
|
|
||||||
if path.isdir(loc):
|
|
||||||
raise ValueError(Errors.E928.format(loc=loc))
|
|
||||||
cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
|
|
||||||
self._fp = fopen(<char*>bytes_loc, 'rb')
|
self._fp = fopen(<char*>bytes_loc, 'rb')
|
||||||
if not self._fp:
|
if not self._fp:
|
||||||
PyErr_SetFromErrno(IOError)
|
PyErr_SetFromErrno(IOError)
|
||||||
|
|
|
@ -144,6 +144,29 @@ def test_kb_empty(nlp):
|
||||||
entity_linker.begin_training(lambda: [])
|
entity_linker.begin_training(lambda: [])
|
||||||
|
|
||||||
|
|
||||||
|
def test_kb_serialize(nlp):
|
||||||
|
"""Test serialization of the KB"""
|
||||||
|
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||||
|
with make_tempdir() as d:
|
||||||
|
# normal read-write behaviour
|
||||||
|
mykb.to_disk(d / "kb")
|
||||||
|
mykb.from_disk(d / "kb")
|
||||||
|
mykb.to_disk(d / "kb.file")
|
||||||
|
mykb.from_disk(d / "kb.file")
|
||||||
|
mykb.to_disk(d / "new" / "kb")
|
||||||
|
mykb.from_disk(d / "new" / "kb")
|
||||||
|
# allow overwriting an existing file
|
||||||
|
mykb.to_disk(d / "kb.file")
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
# can not write to a directory
|
||||||
|
mykb.to_disk(d)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
# can not read from a directory
|
||||||
|
mykb.from_disk(d)
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
# can not read from an unknown file
|
||||||
|
mykb.from_disk(d / "unknown" / "kb")
|
||||||
|
|
||||||
def test_candidate_generation(nlp):
|
def test_candidate_generation(nlp):
|
||||||
"""Test correct candidate generation"""
|
"""Test correct candidate generation"""
|
||||||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user