diff --git a/spacy/strings.pyx b/spacy/strings.pyx index 2e81bd87b..02eed4d2f 100644 --- a/spacy/strings.pyx +++ b/spacy/strings.pyx @@ -1,3 +1,4 @@ +# cython: infer_types=True from __future__ import unicode_literals, absolute_import cimport cython @@ -71,12 +72,14 @@ cdef Utf8Str _allocate(Pool mem, const unsigned char* chars, int length) except cdef class StringStore: '''Map strings to and from integer IDs.''' - def __init__(self, strings=None): + def __init__(self, strings=None, freeze=False): self.mem = Pool() self._map = PreshMap() + self._oov = PreshMap() self._resize_at = 10000 self.c = self.mem.alloc(self._resize_at, sizeof(Utf8Str)) self.size = 1 + self.is_frozen = False if strings is not None: for string in strings: _ = self[string] @@ -89,33 +92,37 @@ cdef class StringStore: return self.size-1 def __getitem__(self, object string_or_id): + if isinstance(string_or_id, basestring) and len(string_or_id) == 0: + return 0 + elif string_or_id == 0: + return u'' + cdef bytes byte_string cdef const Utf8Str* utf8str - cdef unsigned int int_id - + cdef uint64_t int_id if isinstance(string_or_id, (int, long)): - try: - int_id = string_or_id - except OverflowError: - raise IndexError(string_or_id) - if int_id == 0: - return u'' - elif int_id >= self.size: - raise IndexError(string_or_id) - utf8str = &self.c[int_id] - return _decode(utf8str) - elif isinstance(string_or_id, bytes): - byte_string = string_or_id - if len(byte_string) == 0: - return 0 + int_id = string_or_id + if int_id < self.size: + return _decode(&self.c[int_id]) + else: + utf8str = self._oov.get(int_id) + if utf8str is not NULL: + return _decode(utf8str) + else: + raise IndexError(string_or_id) + elif isinstance(string_or_id, basestring): + if isinstance(string_or_id, bytes): + byte_string = string_or_id + else: + byte_string = (string_or_id).encode('utf8') utf8str = self._intern_utf8(byte_string, len(byte_string)) - return utf8str - self.c - elif isinstance(string_or_id, unicode): - if len(string_or_id) == 0: - return 0 - byte_string = (string_or_id).encode('utf8') - utf8str = self._intern_utf8(byte_string, len(byte_string)) - return utf8str - self.c + if utf8str is NULL: + # TODO: We could get unlucky here, and hash into a value that + # collides with the 'real' strings. All we have to do is offset + # I think? + return _hash_utf8(byte_string, len(byte_string)) + else: + return utf8str - self.c else: raise TypeError(type(string_or_id)) @@ -129,6 +136,7 @@ cdef class StringStore: cdef int i for i in range(self.size): yield _decode(&self.c[i]) if i > 0 else u'' + # TODO: Iterate OOV here? def __reduce__(self): strings = [""] @@ -138,18 +146,36 @@ cdef class StringStore: strings.append(py_string) return (StringStore, (strings,), None, None, None) - cdef const Utf8Str* intern(self, unicode py_string) except NULL: + def set_frozen(self, bint is_frozen): + # TODO + self.is_frozen = is_frozen + + def flush_oov(self): + self._oov = PreshMap() + + cdef const Utf8Str* intern_unicode(self, unicode py_string): # 0 means missing, but we don't bother offsetting the index. cdef bytes byte_string = py_string.encode('utf8') return self._intern_utf8(byte_string, len(byte_string)) @cython.final - cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length) except NULL: + cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length): + # TODO: This function's API/behaviour is an unholy mess... # 0 means missing, but we don't bother offsetting the index. cdef hash_t key = _hash_utf8(utf8_string, length) - value = self._map.get(key) + cdef Utf8Str* value = self._map.get(key) if value is not NULL: return value + value = self._oov.get(key) + if value is not NULL: + return value + if self.is_frozen: + # Important: Make the OOV store own the memory. That way it's trivial + # to flush them all. + value = self._oov.mem.alloc(1, sizeof(Utf8Str)) + value[0] = _allocate(self._oov.mem, utf8_string, length) + self._oov.set(key, value) + return NULL if self.size == self._resize_at: self._realloc() @@ -162,6 +188,7 @@ cdef class StringStore: string_data = json.dumps(list(self)) if not isinstance(string_data, unicode): string_data = string_data.decode('utf8') + # TODO: OOV? file_.write(string_data) def load(self, file_): @@ -173,7 +200,7 @@ cdef class StringStore: # explicit None/len check instead of simple truth testing # (bug in Cython <= 0.23.4) if string is not None and len(string): - self.intern(string) + self.intern_unicode(string) def _realloc(self): # We want to map straight to pointers, but they'll be invalidated if