Workaround Issue #285: Allow the StringStore to be 'frozen', in which case strings will be pushed into an OOV map. We can then flush this OOV map, freeing all of the OOV strings.

This commit is contained in:
Matthew Honnibal 2016-10-24 13:49:03 +02:00
parent d3a617aa99
commit d8134817ff

View File

@ -1,3 +1,4 @@
# cython: infer_types=True
from __future__ import unicode_literals, absolute_import from __future__ import unicode_literals, absolute_import
cimport cython cimport cython
@ -71,12 +72,14 @@ cdef Utf8Str _allocate(Pool mem, const unsigned char* chars, int length) except
cdef class StringStore: cdef class StringStore:
'''Map strings to and from integer IDs.''' '''Map strings to and from integer IDs.'''
def __init__(self, strings=None): def __init__(self, strings=None, freeze=False):
self.mem = Pool() self.mem = Pool()
self._map = PreshMap() self._map = PreshMap()
self._oov = PreshMap()
self._resize_at = 10000 self._resize_at = 10000
self.c = <Utf8Str*>self.mem.alloc(self._resize_at, sizeof(Utf8Str)) self.c = <Utf8Str*>self.mem.alloc(self._resize_at, sizeof(Utf8Str))
self.size = 1 self.size = 1
self.is_frozen = False
if strings is not None: if strings is not None:
for string in strings: for string in strings:
_ = self[string] _ = self[string]
@ -89,33 +92,37 @@ cdef class StringStore:
return self.size-1 return self.size-1
def __getitem__(self, object string_or_id): def __getitem__(self, object string_or_id):
if isinstance(string_or_id, basestring) and len(string_or_id) == 0:
return 0
elif string_or_id == 0:
return u''
cdef bytes byte_string cdef bytes byte_string
cdef const Utf8Str* utf8str cdef const Utf8Str* utf8str
cdef unsigned int int_id cdef uint64_t int_id
if isinstance(string_or_id, (int, long)): if isinstance(string_or_id, (int, long)):
try: int_id = string_or_id
int_id = string_or_id if int_id < <uint64_t>self.size:
except OverflowError: return _decode(&self.c[int_id])
raise IndexError(string_or_id) else:
if int_id == 0: utf8str = <Utf8Str*>self._oov.get(int_id)
return u'' if utf8str is not NULL:
elif int_id >= <uint64_t>self.size: return _decode(utf8str)
raise IndexError(string_or_id) else:
utf8str = &self.c[int_id] raise IndexError(string_or_id)
return _decode(utf8str) elif isinstance(string_or_id, basestring):
elif isinstance(string_or_id, bytes): if isinstance(string_or_id, bytes):
byte_string = <bytes>string_or_id byte_string = <bytes>string_or_id
if len(byte_string) == 0: else:
return 0 byte_string = (<unicode>string_or_id).encode('utf8')
utf8str = self._intern_utf8(byte_string, len(byte_string)) utf8str = self._intern_utf8(byte_string, len(byte_string))
return utf8str - self.c if utf8str is NULL:
elif isinstance(string_or_id, unicode): # TODO: We could get unlucky here, and hash into a value that
if len(<unicode>string_or_id) == 0: # collides with the 'real' strings. All we have to do is offset
return 0 # I think?
byte_string = (<unicode>string_or_id).encode('utf8') return _hash_utf8(byte_string, len(byte_string))
utf8str = self._intern_utf8(byte_string, len(byte_string)) else:
return utf8str - self.c return utf8str - self.c
else: else:
raise TypeError(type(string_or_id)) raise TypeError(type(string_or_id))
@ -129,6 +136,7 @@ cdef class StringStore:
cdef int i cdef int i
for i in range(self.size): for i in range(self.size):
yield _decode(&self.c[i]) if i > 0 else u'' yield _decode(&self.c[i]) if i > 0 else u''
# TODO: Iterate OOV here?
def __reduce__(self): def __reduce__(self):
strings = [""] strings = [""]
@ -138,18 +146,36 @@ cdef class StringStore:
strings.append(py_string) strings.append(py_string)
return (StringStore, (strings,), None, None, None) return (StringStore, (strings,), None, None, None)
cdef const Utf8Str* intern(self, unicode py_string) except NULL: def set_frozen(self, bint is_frozen):
# TODO
self.is_frozen = is_frozen
def flush_oov(self):
self._oov = PreshMap()
cdef const Utf8Str* intern_unicode(self, unicode py_string):
# 0 means missing, but we don't bother offsetting the index. # 0 means missing, but we don't bother offsetting the index.
cdef bytes byte_string = py_string.encode('utf8') cdef bytes byte_string = py_string.encode('utf8')
return self._intern_utf8(byte_string, len(byte_string)) return self._intern_utf8(byte_string, len(byte_string))
@cython.final @cython.final
cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length) except NULL: cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length):
# TODO: This function's API/behaviour is an unholy mess...
# 0 means missing, but we don't bother offsetting the index. # 0 means missing, but we don't bother offsetting the index.
cdef hash_t key = _hash_utf8(utf8_string, length) cdef hash_t key = _hash_utf8(utf8_string, length)
value = <Utf8Str*>self._map.get(key) cdef Utf8Str* value = <Utf8Str*>self._map.get(key)
if value is not NULL: if value is not NULL:
return value return value
value = <Utf8Str*>self._oov.get(key)
if value is not NULL:
return value
if self.is_frozen:
# Important: Make the OOV store own the memory. That way it's trivial
# to flush them all.
value = <Utf8Str*>self._oov.mem.alloc(1, sizeof(Utf8Str))
value[0] = _allocate(self._oov.mem, <unsigned char*>utf8_string, length)
self._oov.set(key, value)
return NULL
if self.size == self._resize_at: if self.size == self._resize_at:
self._realloc() self._realloc()
@ -162,6 +188,7 @@ cdef class StringStore:
string_data = json.dumps(list(self)) string_data = json.dumps(list(self))
if not isinstance(string_data, unicode): if not isinstance(string_data, unicode):
string_data = string_data.decode('utf8') string_data = string_data.decode('utf8')
# TODO: OOV?
file_.write(string_data) file_.write(string_data)
def load(self, file_): def load(self, file_):
@ -173,7 +200,7 @@ cdef class StringStore:
# explicit None/len check instead of simple truth testing # explicit None/len check instead of simple truth testing
# (bug in Cython <= 0.23.4) # (bug in Cython <= 0.23.4)
if string is not None and len(string): if string is not None and len(string):
self.intern(string) self.intern_unicode(string)
def _realloc(self): def _realloc(self):
# We want to map straight to pointers, but they'll be invalidated if # We want to map straight to pointers, but they'll be invalidated if