diff --git a/spacy/strings.pxd b/spacy/strings.pxd index d5e320642..0ad403cf1 100644 --- a/spacy/strings.pxd +++ b/spacy/strings.pxd @@ -1,4 +1,5 @@ from libc.stdint cimport int64_t +from libcpp.vector cimport vector from cymem.cymem cimport Pool from preshed.maps cimport PreshMap @@ -8,6 +9,9 @@ from .typedefs cimport attr_t, hash_t cpdef hash_t hash_string(unicode string) except 0 +cdef hash_t hash_utf8(char* utf8_string, int length) nogil + +cdef unicode decode_Utf8Str(const Utf8Str* string) ctypedef union Utf8Str: @@ -17,13 +21,11 @@ ctypedef union Utf8Str: cdef class StringStore: cdef Pool mem - cdef Utf8Str* c - cdef int64_t size cdef bint is_frozen + cdef vector[hash_t] keys cdef public PreshMap _map cdef public PreshMap _oov - cdef int64_t _resize_at cdef const Utf8Str* intern_unicode(self, unicode py_string) cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length) diff --git a/spacy/strings.pyx b/spacy/strings.pyx index b704ac789..3b5749097 100644 --- a/spacy/strings.pyx +++ b/spacy/strings.pyx @@ -28,7 +28,7 @@ cdef uint32_t hash32_utf8(char* utf8_string, int length) nogil: return hash32(utf8_string, length, 1) -cdef unicode _decode(const Utf8Str* string): +cdef unicode decode_Utf8Str(const Utf8Str* string): cdef int i, length if string.s[0] < sizeof(string.s) and string.s[0] != 0: return string.s[1:string.s[0]+1].decode('utf8') @@ -45,10 +45,10 @@ cdef unicode _decode(const Utf8Str* string): return string.p[i:length + i].decode('utf8') -cdef Utf8Str _allocate(Pool mem, const unsigned char* chars, uint32_t length) except *: +cdef Utf8Str* _allocate(Pool mem, const unsigned char* chars, uint32_t length) except *: cdef int n_length_bytes cdef int i - cdef Utf8Str string + cdef Utf8Str* string = mem.alloc(1, sizeof(Utf8Str)) cdef uint32_t ulength = length if length < sizeof(string.s): string.s[0] = length @@ -71,9 +71,9 @@ cdef Utf8Str _allocate(Pool mem, const unsigned char* chars, uint32_t length) ex assert string.s[0] >= sizeof(string.s) or string.s[0] == 0, string.s[0] return string - + cdef class StringStore: - """Map strings to and from integer IDs.""" + """Lookup strings by 64-bit hash""" def __init__(self, strings=None, freeze=False): """Create the StringStore. @@ -83,68 +83,56 @@ cdef class StringStore: self.mem = Pool() self._map = PreshMap() self._oov = PreshMap() - self._resize_at = 10000 - self.c = self.mem.alloc(self._resize_at, sizeof(Utf8Str)) - self.size = 1 self.is_frozen = freeze if strings is not None: for string in strings: - _ = self[string] - - property size: - def __get__(self): - return self.size -1 - - def __len__(self): - """The number of strings in the store. - - RETURNS (int): The number of strings in the store. - """ - return self.size-1 + self.add(string) def __getitem__(self, object string_or_id): - """Retrieve a string from a given integer ID, or vice versa. + """Retrieve a string from a given hash ID, or vice versa. - string_or_id (bytes or unicode or int): The value to encode. - Returns (unicode or int): The value to be retrieved. + string_or_id (bytes or unicode or uint64): The value to encode. + Returns (unicode or uint64): The value to be retrieved. """ if isinstance(string_or_id, basestring) and len(string_or_id) == 0: return 0 elif string_or_id == 0: return u'' - cdef bytes byte_string - cdef const Utf8Str* utf8str - cdef uint64_t int_id - cdef uint32_t oov_id - if isinstance(string_or_id, (int, long)): - int_id = string_or_id - oov_id = string_or_id - if int_id < self.size: - return _decode(&self.c[int_id]) - else: - utf8str = self._oov.get(oov_id) - if utf8str is not NULL: - return _decode(utf8str) - else: - raise IndexError(string_or_id) + cdef hash_t key + + if isinstance(string_or_id, unicode): + key = hash_string(string_or_id) + return key + elif isinstance(string_or_id, bytes): + key = hash_utf8(string_or_id, len(string_or_id)) + return key else: - if isinstance(string_or_id, bytes): - byte_string = string_or_id - elif isinstance(string_or_id, unicode): - byte_string = (string_or_id).encode('utf8') - else: - raise TypeError(type(string_or_id)) - utf8str = self._intern_utf8(byte_string, len(byte_string)) + key = string_or_id + utf8str = self._map.get(key) if utf8str is NULL: - # TODO: We need to use 32 bit here, for compatibility with the - # vocabulary values. This makes birthday paradox probabilities - # pretty bad. - # We could also get unlucky here, and hash into a value that - # collides with the 'real' strings. - return hash32_utf8(byte_string, len(byte_string)) + raise KeyError(string_or_id) else: - return utf8str - self.c + return decode_Utf8Str(utf8str) + + def add(self, string): + if isinstance(string, unicode): + key = hash_string(string) + self.intern_unicode(string) + elif isinstance(string, bytes): + key = hash_utf8(string, len(string)) + self._intern_utf8(string, len(string)) + else: + raise TypeError( + "Can only add unicode or bytes. Got type: %s" % type(string)) + return key + + def __len__(self): + """The number of strings in the store. + + RETURNS (int): The number of strings in the store. + """ + return self.keys.size() def __contains__(self, unicode string not None): """Check whether a string is in the store. @@ -163,16 +151,15 @@ cdef class StringStore: YIELDS (unicode): A string in the store. """ cdef int i - for i in range(self.size): - yield _decode(&self.c[i]) if i > 0 else u'' + cdef hash_t key + for i in range(self.keys.size()): + key = self.keys[i] + utf8str = self._map.get(key) + yield decode_Utf8Str(utf8str) # TODO: Iterate OOV here? def __reduce__(self): - strings = [""] - for i in range(1, self.size): - string = &self.c[i] - py_string = _decode(string) - strings.append(py_string) + strings = list(self) return (StringStore, (strings,), None, None, None) def to_disk(self, path): @@ -230,11 +217,9 @@ cdef class StringStore: self.mem = Pool() self._map = PreshMap() self._oov = PreshMap() - self._resize_at = 10000 - self.c = self.mem.alloc(self._resize_at, sizeof(Utf8Str)) - self.size = 1 + self.keys.clear() for string in strings: - _ = self[string] + self.add(string) self.is_frozen = freeze cdef const Utf8Str* intern_unicode(self, unicode py_string): @@ -258,39 +243,11 @@ cdef class StringStore: key32 = hash32_utf8(utf8_string, length) # Important: Make the OOV store own the memory. That way it's trivial # to flush them all. - value = self._oov.mem.alloc(1, sizeof(Utf8Str)) - value[0] = _allocate(self._oov.mem, utf8_string, length) + value = _allocate(self._oov.mem, utf8_string, length) self._oov.set(key32, value) return NULL - if self.size == self._resize_at: - self._realloc() - self.c[self.size] = _allocate(self.mem, utf8_string, length) - self._map.set(key, &self.c[self.size]) - self.size += 1 - return &self.c[self.size-1] - - def _realloc(self): - # We want to map straight to pointers, but they'll be invalidated if - # we resize our array. So, first we remap to indices, then we resize, - # then we can acquire the new pointers. - cdef Pool tmp_mem = Pool() - keys = tmp_mem.alloc(self.size, sizeof(key_t)) - cdef key_t key - cdef void* value - cdef const Utf8Str ptr - cdef int i = 0 - cdef size_t offset - while map_iter(self._map.c_map, &i, &key, &value): - # Find array index with pointer arithmetic - offset = ((value) - self.c) - keys[offset] = key - - self._resize_at *= 2 - cdef size_t new_size = self._resize_at * sizeof(Utf8Str) - self.c = self.mem.realloc(self.c, new_size) - - self._map = PreshMap(self.size) - for i in range(self.size): - if keys[i]: - self._map.set(keys[i], &self.c[i]) + value = _allocate(self.mem, utf8_string, length) + self._map.set(key, value) + self.keys.push_back(key) + return value diff --git a/spacy/tests/stringstore/test_stringstore.py b/spacy/tests/stringstore/test_stringstore.py index e3c94e33b..be2afd04e 100644 --- a/spacy/tests/stringstore/test_stringstore.py +++ b/spacy/tests/stringstore/test_stringstore.py @@ -8,69 +8,65 @@ import pytest @pytest.mark.parametrize('text1,text2,text3', [(b'Hello', b'goodbye', b'hello')]) def test_stringstore_save_bytes(stringstore, text1, text2, text3): - i = stringstore[text1] - assert i == 1 - assert stringstore[text1] == 1 - assert stringstore[text2] != i - assert stringstore[text3] != i - assert i == 1 + key = stringstore.add(text1) + assert stringstore[text1] == key + assert stringstore[text2] != key + assert stringstore[text3] != key @pytest.mark.parametrize('text1,text2,text3', [('Hello', 'goodbye', 'hello')]) def test_stringstore_save_unicode(stringstore, text1, text2, text3): - i = stringstore[text1] - assert i == 1 - assert stringstore[text1] == 1 - assert stringstore[text2] != i - assert stringstore[text3] != i - assert i == 1 + key = stringstore.add(text1) + assert stringstore[text1] == key + assert stringstore[text2] != key + assert stringstore[text3] != key @pytest.mark.parametrize('text', [b'A']) def test_stringstore_retrieve_id(stringstore, text): - i = stringstore[text] - assert stringstore.size == 1 - assert stringstore[1] == text.decode('utf8') - with pytest.raises(IndexError): + key = stringstore.add(text) + assert len(stringstore) == 1 + assert stringstore[key] == text.decode('utf8') + with pytest.raises(KeyError): stringstore[2] @pytest.mark.parametrize('text1,text2', [(b'0123456789', b'A')]) def test_stringstore_med_string(stringstore, text1, text2): - store = stringstore[text1] + store = stringstore.add(text1) assert stringstore[store] == text1.decode('utf8') - dummy = stringstore[text2] + dummy = stringstore.add(text2) assert stringstore[text1] == store def test_stringstore_long_string(stringstore): text = "INFORMATIVE](http://www.google.com/search?as_q=RedditMonkey&hl=en&num=50&btnG=Google+Search&as_epq=&as_oq=&as_eq=&lr=&as_ft=i&as_filetype=&as_qdr=all&as_nlo=&as_nhi=&as_occt=any&as_dt=i&as_sitesearch=&as_rights=&safe=off" - store = stringstore[text] + store = stringstore.add(text) assert stringstore[store] == text @pytest.mark.parametrize('factor', [254, 255, 256]) def test_stringstore_multiply(stringstore, factor): text = 'a' * factor - store = stringstore[text] + store = stringstore.add(text) assert stringstore[store] == text def test_stringstore_massive_strings(stringstore): text = 'a' * 511 - store = stringstore[text] + store = stringstore.add(text) assert stringstore[store] == text text2 = 'z' * 512 - store = stringstore[text2] + store = stringstore.add(text2) assert stringstore[store] == text2 text3 = '1' * 513 - store = stringstore[text3] + store = stringstore.add(text3) assert stringstore[store] == text3 @pytest.mark.parametrize('text', ["qqqqq"]) def test_stringstore_to_bytes(stringstore, text): - store = stringstore[text] + store = stringstore.add(text) serialized = stringstore.to_bytes() new_stringstore = StringStore().from_bytes(serialized) assert new_stringstore[store] == text diff --git a/spacy/typedefs.pxd b/spacy/typedefs.pxd index bd863d247..bd5b38958 100644 --- a/spacy/typedefs.pxd +++ b/spacy/typedefs.pxd @@ -4,7 +4,7 @@ from libc.stdint cimport uint8_t ctypedef uint64_t hash_t ctypedef char* utf8_t -ctypedef int32_t attr_t +ctypedef uint64_t attr_t ctypedef uint64_t flags_t ctypedef uint16_t len_t ctypedef uint16_t tag_t diff --git a/spacy/vocab.pyx b/spacy/vocab.pyx index 52fd0b35f..8f03470b0 100644 --- a/spacy/vocab.pyx +++ b/spacy/vocab.pyx @@ -172,7 +172,7 @@ cdef class Vocab: for attr, func in self.lex_attr_getters.items(): value = func(string) if isinstance(value, unicode): - value = self.strings[value] + value = self.strings.add(value) if attr == PROB: lex.prob = value elif value is not None: @@ -227,7 +227,7 @@ cdef class Vocab: """ cdef attr_t orth if type(id_or_string) == unicode: - orth = self.strings[id_or_string] + orth = self.strings.add(id_or_string) else: orth = id_or_string return Lexeme(self, orth) @@ -291,7 +291,7 @@ cdef class Vocab: with (path / 'vocab' / 'strings.json').open('r', encoding='utf8') as file_: strings_list = ujson.load(file_) for string in strings_list: - self.strings[string] + self.strings.add(string) self.load_lexemes(path / 'lexemes.bin') def to_bytes(self, **exclude):