Add support for pickling StringStore.

This commit is contained in:
Matthew Honnibal 2017-03-07 17:15:18 +01:00
parent 4e75e74247
commit 5de7e712b7
2 changed files with 38 additions and 16 deletions

View File

@ -3,7 +3,7 @@ from __future__ import unicode_literals, absolute_import
cimport cython cimport cython
from libc.string cimport memcpy from libc.string cimport memcpy
from libc.stdint cimport uint64_t from libc.stdint cimport uint64_t, uint32_t
from murmurhash.mrmr cimport hash64, hash32 from murmurhash.mrmr cimport hash64, hash32
@ -12,22 +12,19 @@ from preshed.maps cimport map_iter, key_t
from .typedefs cimport hash_t from .typedefs cimport hash_t
from libc.stdint cimport uint32_t from libc.stdint cimport uint32_t
try: import ujson
import ujson as json
except ImportError:
import json
cpdef hash_t hash_string(unicode string) except 0: cpdef hash_t hash_string(unicode string) except 0:
chars = string.encode('utf8') chars = string.encode('utf8')
return _hash_utf8(chars, len(chars)) return hash_utf8(chars, len(chars))
cdef hash_t _hash_utf8(char* utf8_string, int length): cdef hash_t hash_utf8(char* utf8_string, int length) nogil:
return hash64(utf8_string, length, 1) return hash64(utf8_string, length, 1)
cdef uint32_t _hash32_utf8(char* utf8_string, int length): cdef uint32_t hash32_utf8(char* utf8_string, int length) nogil:
return hash32(utf8_string, length, 1) return hash32(utf8_string, length, 1)
@ -48,11 +45,11 @@ cdef unicode _decode(const Utf8Str* string):
return string.p[i:length + i].decode('utf8') return string.p[i:length + i].decode('utf8')
cdef Utf8Str _allocate(Pool mem, const unsigned char* chars, int length) except *: cdef Utf8Str _allocate(Pool mem, const unsigned char* chars, uint32_t length) except *:
cdef int n_length_bytes cdef int n_length_bytes
cdef int i cdef int i
cdef Utf8Str string cdef Utf8Str string
assert length != 0 cdef uint32_t ulength = length
if length < sizeof(string.s): if length < sizeof(string.s):
string.s[0] = <unsigned char>length string.s[0] = <unsigned char>length
memcpy(&string.s[1], chars, length) memcpy(&string.s[1], chars, length)
@ -98,6 +95,14 @@ cdef class StringStore:
def __get__(self): def __get__(self):
return self.size -1 return self.size -1
def __reduce__(self):
# TODO: OOV words, for the is_frozen stuff?
if self.is_frozen:
raise NotImplementedError(
"Currently missing support for pickling StringStore when "
"is_frozen=True")
return (StringStore, (list(self),))
def __len__(self): def __len__(self):
"""The number of strings in the store. """The number of strings in the store.
@ -149,7 +154,7 @@ cdef class StringStore:
# pretty bad. # pretty bad.
# We could also get unlucky here, and hash into a value that # We could also get unlucky here, and hash into a value that
# collides with the 'real' strings. # collides with the 'real' strings.
return _hash32_utf8(byte_string, len(byte_string)) return hash32_utf8(byte_string, len(byte_string))
else: else:
return utf8str - self.c return utf8str - self.c
@ -200,7 +205,7 @@ cdef class StringStore:
cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length): cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length):
# TODO: This function's API/behaviour is an unholy mess... # TODO: This function's API/behaviour is an unholy mess...
# 0 means missing, but we don't bother offsetting the index. # 0 means missing, but we don't bother offsetting the index.
cdef hash_t key = _hash_utf8(utf8_string, length) cdef hash_t key = hash_utf8(utf8_string, length)
cdef Utf8Str* value = <Utf8Str*>self._map.get(key) cdef Utf8Str* value = <Utf8Str*>self._map.get(key)
if value is not NULL: if value is not NULL:
return value return value
@ -209,7 +214,7 @@ cdef class StringStore:
return value return value
if self.is_frozen: if self.is_frozen:
# OOV store uses 32 bit hashes. Pretty ugly :( # OOV store uses 32 bit hashes. Pretty ugly :(
key32 = _hash32_utf8(utf8_string, length) key32 = hash32_utf8(utf8_string, length)
# Important: Make the OOV store own the memory. That way it's trivial # Important: Make the OOV store own the memory. That way it's trivial
# to flush them all. # to flush them all.
value = <Utf8Str*>self._oov.mem.alloc(1, sizeof(Utf8Str)) value = <Utf8Str*>self._oov.mem.alloc(1, sizeof(Utf8Str))
@ -232,7 +237,7 @@ cdef class StringStore:
Returns: Returns:
None None
""" """
string_data = json.dumps(list(self)) string_data = ujson.dumps(list(self))
if not isinstance(string_data, unicode): if not isinstance(string_data, unicode):
string_data = string_data.decode('utf8') string_data = string_data.decode('utf8')
# TODO: OOV? # TODO: OOV?
@ -246,7 +251,7 @@ cdef class StringStore:
Returns: Returns:
None None
""" """
strings = json.load(file_) strings = ujson.load(file_)
if strings == ['']: if strings == ['']:
return None return None
cdef unicode string cdef unicode string

View File

@ -0,0 +1,17 @@
from __future__ import unicode_literals
import io
import pickle
from ..strings import StringStore
def test_pickle_string_store():
sstore = StringStore()
hello = sstore['hello']
bye = sstore['bye']
bdata = pickle.dumps(sstore, protocol=-1)
unpickled = pickle.loads(bdata)
assert unpickled['hello'] == hello
assert unpickled['bye'] == bye