mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-27 01:34:30 +03:00
Workaround Issue #285: Allow the StringStore to be 'frozen', in which case strings will be pushed into an OOV map. We can then flush this OOV map, freeing all of the OOV strings.
This commit is contained in:
parent
d3a617aa99
commit
d8134817ff
|
@ -1,3 +1,4 @@
|
||||||
|
# cython: infer_types=True
|
||||||
from __future__ import unicode_literals, absolute_import
|
from __future__ import unicode_literals, absolute_import
|
||||||
|
|
||||||
cimport cython
|
cimport cython
|
||||||
|
@ -71,12 +72,14 @@ cdef Utf8Str _allocate(Pool mem, const unsigned char* chars, int length) except
|
||||||
|
|
||||||
cdef class StringStore:
|
cdef class StringStore:
|
||||||
'''Map strings to and from integer IDs.'''
|
'''Map strings to and from integer IDs.'''
|
||||||
def __init__(self, strings=None):
|
def __init__(self, strings=None, freeze=False):
|
||||||
self.mem = Pool()
|
self.mem = Pool()
|
||||||
self._map = PreshMap()
|
self._map = PreshMap()
|
||||||
|
self._oov = PreshMap()
|
||||||
self._resize_at = 10000
|
self._resize_at = 10000
|
||||||
self.c = <Utf8Str*>self.mem.alloc(self._resize_at, sizeof(Utf8Str))
|
self.c = <Utf8Str*>self.mem.alloc(self._resize_at, sizeof(Utf8Str))
|
||||||
self.size = 1
|
self.size = 1
|
||||||
|
self.is_frozen = False
|
||||||
if strings is not None:
|
if strings is not None:
|
||||||
for string in strings:
|
for string in strings:
|
||||||
_ = self[string]
|
_ = self[string]
|
||||||
|
@ -89,33 +92,37 @@ cdef class StringStore:
|
||||||
return self.size-1
|
return self.size-1
|
||||||
|
|
||||||
def __getitem__(self, object string_or_id):
|
def __getitem__(self, object string_or_id):
|
||||||
|
if isinstance(string_or_id, basestring) and len(string_or_id) == 0:
|
||||||
|
return 0
|
||||||
|
elif string_or_id == 0:
|
||||||
|
return u''
|
||||||
|
|
||||||
cdef bytes byte_string
|
cdef bytes byte_string
|
||||||
cdef const Utf8Str* utf8str
|
cdef const Utf8Str* utf8str
|
||||||
cdef unsigned int int_id
|
cdef uint64_t int_id
|
||||||
|
|
||||||
if isinstance(string_or_id, (int, long)):
|
if isinstance(string_or_id, (int, long)):
|
||||||
try:
|
int_id = string_or_id
|
||||||
int_id = string_or_id
|
if int_id < <uint64_t>self.size:
|
||||||
except OverflowError:
|
return _decode(&self.c[int_id])
|
||||||
raise IndexError(string_or_id)
|
else:
|
||||||
if int_id == 0:
|
utf8str = <Utf8Str*>self._oov.get(int_id)
|
||||||
return u''
|
if utf8str is not NULL:
|
||||||
elif int_id >= <uint64_t>self.size:
|
return _decode(utf8str)
|
||||||
raise IndexError(string_or_id)
|
else:
|
||||||
utf8str = &self.c[int_id]
|
raise IndexError(string_or_id)
|
||||||
return _decode(utf8str)
|
elif isinstance(string_or_id, basestring):
|
||||||
elif isinstance(string_or_id, bytes):
|
if isinstance(string_or_id, bytes):
|
||||||
byte_string = <bytes>string_or_id
|
byte_string = <bytes>string_or_id
|
||||||
if len(byte_string) == 0:
|
else:
|
||||||
return 0
|
byte_string = (<unicode>string_or_id).encode('utf8')
|
||||||
utf8str = self._intern_utf8(byte_string, len(byte_string))
|
utf8str = self._intern_utf8(byte_string, len(byte_string))
|
||||||
return utf8str - self.c
|
if utf8str is NULL:
|
||||||
elif isinstance(string_or_id, unicode):
|
# TODO: We could get unlucky here, and hash into a value that
|
||||||
if len(<unicode>string_or_id) == 0:
|
# collides with the 'real' strings. All we have to do is offset
|
||||||
return 0
|
# I think?
|
||||||
byte_string = (<unicode>string_or_id).encode('utf8')
|
return _hash_utf8(byte_string, len(byte_string))
|
||||||
utf8str = self._intern_utf8(byte_string, len(byte_string))
|
else:
|
||||||
return utf8str - self.c
|
return utf8str - self.c
|
||||||
else:
|
else:
|
||||||
raise TypeError(type(string_or_id))
|
raise TypeError(type(string_or_id))
|
||||||
|
|
||||||
|
@ -129,6 +136,7 @@ cdef class StringStore:
|
||||||
cdef int i
|
cdef int i
|
||||||
for i in range(self.size):
|
for i in range(self.size):
|
||||||
yield _decode(&self.c[i]) if i > 0 else u''
|
yield _decode(&self.c[i]) if i > 0 else u''
|
||||||
|
# TODO: Iterate OOV here?
|
||||||
|
|
||||||
def __reduce__(self):
|
def __reduce__(self):
|
||||||
strings = [""]
|
strings = [""]
|
||||||
|
@ -138,18 +146,36 @@ cdef class StringStore:
|
||||||
strings.append(py_string)
|
strings.append(py_string)
|
||||||
return (StringStore, (strings,), None, None, None)
|
return (StringStore, (strings,), None, None, None)
|
||||||
|
|
||||||
cdef const Utf8Str* intern(self, unicode py_string) except NULL:
|
def set_frozen(self, bint is_frozen):
|
||||||
|
# TODO
|
||||||
|
self.is_frozen = is_frozen
|
||||||
|
|
||||||
|
def flush_oov(self):
|
||||||
|
self._oov = PreshMap()
|
||||||
|
|
||||||
|
cdef const Utf8Str* intern_unicode(self, unicode py_string):
|
||||||
# 0 means missing, but we don't bother offsetting the index.
|
# 0 means missing, but we don't bother offsetting the index.
|
||||||
cdef bytes byte_string = py_string.encode('utf8')
|
cdef bytes byte_string = py_string.encode('utf8')
|
||||||
return self._intern_utf8(byte_string, len(byte_string))
|
return self._intern_utf8(byte_string, len(byte_string))
|
||||||
|
|
||||||
@cython.final
|
@cython.final
|
||||||
cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length) except NULL:
|
cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length):
|
||||||
|
# TODO: This function's API/behaviour is an unholy mess...
|
||||||
# 0 means missing, but we don't bother offsetting the index.
|
# 0 means missing, but we don't bother offsetting the index.
|
||||||
cdef hash_t key = _hash_utf8(utf8_string, length)
|
cdef hash_t key = _hash_utf8(utf8_string, length)
|
||||||
value = <Utf8Str*>self._map.get(key)
|
cdef Utf8Str* value = <Utf8Str*>self._map.get(key)
|
||||||
if value is not NULL:
|
if value is not NULL:
|
||||||
return value
|
return value
|
||||||
|
value = <Utf8Str*>self._oov.get(key)
|
||||||
|
if value is not NULL:
|
||||||
|
return value
|
||||||
|
if self.is_frozen:
|
||||||
|
# Important: Make the OOV store own the memory. That way it's trivial
|
||||||
|
# to flush them all.
|
||||||
|
value = <Utf8Str*>self._oov.mem.alloc(1, sizeof(Utf8Str))
|
||||||
|
value[0] = _allocate(self._oov.mem, <unsigned char*>utf8_string, length)
|
||||||
|
self._oov.set(key, value)
|
||||||
|
return NULL
|
||||||
|
|
||||||
if self.size == self._resize_at:
|
if self.size == self._resize_at:
|
||||||
self._realloc()
|
self._realloc()
|
||||||
|
@ -162,6 +188,7 @@ cdef class StringStore:
|
||||||
string_data = json.dumps(list(self))
|
string_data = json.dumps(list(self))
|
||||||
if not isinstance(string_data, unicode):
|
if not isinstance(string_data, unicode):
|
||||||
string_data = string_data.decode('utf8')
|
string_data = string_data.decode('utf8')
|
||||||
|
# TODO: OOV?
|
||||||
file_.write(string_data)
|
file_.write(string_data)
|
||||||
|
|
||||||
def load(self, file_):
|
def load(self, file_):
|
||||||
|
@ -173,7 +200,7 @@ cdef class StringStore:
|
||||||
# explicit None/len check instead of simple truth testing
|
# explicit None/len check instead of simple truth testing
|
||||||
# (bug in Cython <= 0.23.4)
|
# (bug in Cython <= 0.23.4)
|
||||||
if string is not None and len(string):
|
if string is not None and len(string):
|
||||||
self.intern(string)
|
self.intern_unicode(string)
|
||||||
|
|
||||||
def _realloc(self):
|
def _realloc(self):
|
||||||
# We want to map straight to pointers, but they'll be invalidated if
|
# We want to map straight to pointers, but they'll be invalidated if
|
||||||
|
|
Loading…
Reference in New Issue
Block a user