Workaround Issue #285: Allow the StringStore to be 'frozen', in which case strings will be pushed into an OOV map. We can then flush this OOV map, freeing all of the OOV strings.

This commit is contained in:
Matthew Honnibal 2016-10-24 13:49:03 +02:00
parent d3a617aa99
commit d8134817ff

View File

@ -1,3 +1,4 @@
# cython: infer_types=True
from __future__ import unicode_literals, absolute_import
cimport cython
@ -71,12 +72,14 @@ cdef Utf8Str _allocate(Pool mem, const unsigned char* chars, int length) except
cdef class StringStore:
'''Map strings to and from integer IDs.'''
def __init__(self, strings=None):
def __init__(self, strings=None, freeze=False):
self.mem = Pool()
self._map = PreshMap()
self._oov = PreshMap()
self._resize_at = 10000
self.c = <Utf8Str*>self.mem.alloc(self._resize_at, sizeof(Utf8Str))
self.size = 1
self.is_frozen = False
if strings is not None:
for string in strings:
_ = self[string]
@ -89,32 +92,36 @@ cdef class StringStore:
return self.size-1
def __getitem__(self, object string_or_id):
if isinstance(string_or_id, basestring) and len(string_or_id) == 0:
return 0
elif string_or_id == 0:
return u''
cdef bytes byte_string
cdef const Utf8Str* utf8str
cdef unsigned int int_id
cdef uint64_t int_id
if isinstance(string_or_id, (int, long)):
try:
int_id = string_or_id
except OverflowError:
raise IndexError(string_or_id)
if int_id == 0:
return u''
elif int_id >= <uint64_t>self.size:
raise IndexError(string_or_id)
utf8str = &self.c[int_id]
if int_id < <uint64_t>self.size:
return _decode(&self.c[int_id])
else:
utf8str = <Utf8Str*>self._oov.get(int_id)
if utf8str is not NULL:
return _decode(utf8str)
elif isinstance(string_or_id, bytes):
else:
raise IndexError(string_or_id)
elif isinstance(string_or_id, basestring):
if isinstance(string_or_id, bytes):
byte_string = <bytes>string_or_id
if len(byte_string) == 0:
return 0
utf8str = self._intern_utf8(byte_string, len(byte_string))
return utf8str - self.c
elif isinstance(string_or_id, unicode):
if len(<unicode>string_or_id) == 0:
return 0
else:
byte_string = (<unicode>string_or_id).encode('utf8')
utf8str = self._intern_utf8(byte_string, len(byte_string))
if utf8str is NULL:
# TODO: We could get unlucky here: the hash returned for an OOV string
# might collide with the integer ID of one of the 'real' strings. All we
# have to do is offset the hash, I think?
return _hash_utf8(byte_string, len(byte_string))
else:
return utf8str - self.c
else:
raise TypeError(type(string_or_id))
@ -129,6 +136,7 @@ cdef class StringStore:
cdef int i
for i in range(self.size):
yield _decode(&self.c[i]) if i > 0 else u''
# TODO: Iterate OOV here?
def __reduce__(self):
strings = [""]
@ -138,18 +146,36 @@ cdef class StringStore:
strings.append(py_string)
return (StringStore, (strings,), None, None, None)
cdef const Utf8Str* intern(self, unicode py_string) except NULL:
def set_frozen(self, bint is_frozen):
# Switch the store in or out of 'frozen' mode by setting self.is_frozen.
# While the flag is set, _intern_utf8 stores previously unseen strings in
# the separate _oov map and returns NULL.
# TODO
self.is_frozen = is_frozen
def flush_oov(self):
# Discard every out-of-vocabulary string by replacing the _oov map with a
# fresh, empty PreshMap. The OOV entries are allocated from the old map's
# own memory pool (self._oov.mem in _intern_utf8), so dropping the map is
# meant to release them all at once.
self._oov = PreshMap()
cdef const Utf8Str* intern_unicode(self, unicode py_string):
# Intern a unicode string: encode it as UTF-8 and hand the bytes to
# _intern_utf8, returning that call's result unchanged. The result can be
# NULL: _intern_utf8 returns NULL after storing a new string in the OOV
# map while the store is frozen.
# 0 means missing, but we don't bother offsetting the index.
cdef bytes byte_string = py_string.encode('utf8')
return self._intern_utf8(byte_string, len(byte_string))
@cython.final
cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length) except NULL:
cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length):
# TODO: This function's API/behaviour is an unholy mess...
# 0 means missing, but we don't bother offsetting the index.
cdef hash_t key = _hash_utf8(utf8_string, length)
value = <Utf8Str*>self._map.get(key)
cdef Utf8Str* value = <Utf8Str*>self._map.get(key)
if value is not NULL:
return value
value = <Utf8Str*>self._oov.get(key)
if value is not NULL:
return value
if self.is_frozen:
# Important: Make the OOV store own the memory. That way it's trivial
# to flush them all.
value = <Utf8Str*>self._oov.mem.alloc(1, sizeof(Utf8Str))
value[0] = _allocate(self._oov.mem, <unsigned char*>utf8_string, length)
self._oov.set(key, value)
return NULL
if self.size == self._resize_at:
self._realloc()
@ -162,6 +188,7 @@ cdef class StringStore:
string_data = json.dumps(list(self))
if not isinstance(string_data, unicode):
string_data = string_data.decode('utf8')
# TODO: Should the OOV strings be dumped as well?
file_.write(string_data)
def load(self, file_):
@ -173,7 +200,7 @@ cdef class StringStore:
# explicit None/len check instead of simple truth testing
# (bug in Cython <= 0.23.4)
if string is not None and len(string):
self.intern(string)
self.intern_unicode(string)
def _realloc(self):
# We want to map straight to pointers, but they'll be invalidated if