Implement StringStore serialization, and update tests

This commit is contained in:
Matthew Honnibal 2017-05-22 12:38:00 +02:00
parent aae97f00e9
commit d8bb5bb959
2 changed files with 31 additions and 18 deletions

View File

@ -7,9 +7,12 @@ from libc.string cimport memcpy
from libc.stdint cimport uint64_t, uint32_t
from murmurhash.mrmr cimport hash64, hash32
from preshed.maps cimport map_iter, key_t
from libc.stdint cimport uint32_t
import ujson
import dill
from .typedefs cimport hash_t
from libc.stdint cimport uint32_t
from . import util
cpdef hash_t hash_string(unicode string) except 0:
@ -92,14 +95,6 @@ cdef class StringStore:
def __get__(self):
return self.size -1
def __reduce__(self):
# TODO: OOV words, for the is_frozen stuff?
if self.is_frozen:
raise NotImplementedError(
"Currently missing support for pickling StringStore when "
"is_frozen=True")
return (StringStore, (list(self),))
def __len__(self):
"""The number of strings in the store.
@ -186,7 +181,10 @@ cdef class StringStore:
path (unicode or Path): A path to a directory, which will be created if
it doesn't exist. Paths may be either strings or `Path`-like objects.
"""
raise NotImplementedError()
path = util.ensure_path(path)
strings = list(self)
with path.open('w') as file_:
ujson.dump(strings, file_)
def from_disk(self, path):
"""Loads state from a directory. Modifies the object in place and
@ -196,7 +194,11 @@ cdef class StringStore:
strings or `Path`-like objects.
RETURNS (StringStore): The modified `StringStore` object.
"""
raise NotImplementedError()
path = util.ensure_path(path)
with path.open('r') as file_:
strings = ujson.load(file_)
self._reset_and_load(strings)
return self
def to_bytes(self, **exclude):
"""Serialize the current state to a binary string.
@ -204,7 +206,7 @@ cdef class StringStore:
**exclude: Named attributes to prevent from being serialized.
RETURNS (bytes): The serialized form of the `StringStore` object.
"""
raise NotImplementedError()
return ujson.dumps(list(self))
def from_bytes(self, bytes_data, **exclude):
"""Load state from a binary string.
@ -213,7 +215,9 @@ cdef class StringStore:
**exclude: Named attributes to prevent from being loaded.
RETURNS (StringStore): The `StringStore` object.
"""
raise NotImplementedError()
strings = ujson.loads(bytes_data)
self._reset_and_load(strings)
return self
def set_frozen(self, bint is_frozen):
# TODO
@ -222,6 +226,17 @@ cdef class StringStore:
def flush_oov(self):
self._oov = PreshMap()
def _reset_and_load(self, strings, freeze=False):
self.mem = Pool()
self._map = PreshMap()
self._oov = PreshMap()
self._resize_at = 10000
self.c = <Utf8Str*>self.mem.alloc(self._resize_at, sizeof(Utf8Str))
self.size = 1
for string in strings:
_ = self[string]
self.is_frozen = freeze
cdef const Utf8Str* intern_unicode(self, unicode py_string):
# 0 means missing, but we don't bother offsetting the index.
cdef bytes byte_string = py_string.encode('utf8')

View File

@ -69,10 +69,8 @@ def test_stringstore_massive_strings(stringstore):
@pytest.mark.parametrize('text', ["qqqqq"])
def test_stringstore_dump_load(stringstore, text_file, text):
def test_stringstore_to_bytes(stringstore, text):
store = stringstore[text]
stringstore.dump(text_file)
text_file.seek(0)
new_stringstore = StringStore()
new_stringstore.load(text_file)
serialized = stringstore.to_bytes()
new_stringstore = StringStore().from_bytes(serialized)
assert new_stringstore[store] == text