From 446a3ecf34f3b0c139a326d7471b9b1b5346bcb0 Mon Sep 17 00:00:00 2001 From: Madeesh Kannan Date: Thu, 6 Oct 2022 10:51:06 +0200 Subject: [PATCH] `StringStore` refactoring (#11344) * `strings`: Remove unused `hash32_utf8` function * `strings`: Make `hash_utf8` and `decode_Utf8Str` private * `strings`: Reorganize private functions * 'strings': Raise error when non-string/-int types are passed to functions that don't accept them * `strings`: Add `items()` method, add type hints, remove unused methods, restrict inputs to specific types, reorganize methods * `Morphology`: Use `StringStore.items()` to enumerate features when pickling * `test_stringstore`: Update pre-Python 3 tests * Update `StringStore` docs * Fix `get_string_id` imports * Replace redundant test with tests for type checking * Rename `_retrieve_interned_str`, remove `.get` default arg * Add `get_string_id` to `strings.pyi` Remove `mypy` ignore directives from imports of the above * `strings.pyi`: Replace functions that consume `Union`-typed params with overloads * `strings.pyi`: Revert some function signatures * Update `SYMBOLS_BY_INT` lookups and error codes post-merge * Revert clobbered change introduced in a previous merge * Remove unnecessary type hint * Invert tuple order in `StringStore.items()` * Add test for `StringStore.items()` * Revert "`Morphology`: Use `StringStore.items()` to enumerate features when pickling" This reverts commit 1af9510ceb6b08cfdcfbf26df6896f26709fac0d. * Rename `keys` and `key_map` * Add `keys()` and `values()` * Add comment about the inverted key-value semantics in the API * Fix type hints * Implement `keys()`, `values()`, `items()` without generators * Fix type hints, remove unnecessary boxing * Update docs * Simplify `keys/values/items()` impl * `mypy` fix * Fix error message, doc fixes --- spacy/errors.py | 4 +- spacy/matcher/matcher.pyx | 2 +- spacy/strings.pxd | 21 +- spacy/strings.pyi | 23 +- spacy/strings.pyx | 426 +++++++++--------- spacy/tests/vocab_vectors/test_stringstore.py | 41 +- spacy/tokens/graph.pyx | 2 +- spacy/tokens/retokenizer.pyx | 2 +- website/docs/api/stringstore.md | 82 +++- 9 files changed, 339 insertions(+), 264 deletions(-) diff --git a/spacy/errors.py b/spacy/errors.py index 5fb59e2c5..856660106 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -252,7 +252,7 @@ class Errors(metaclass=ErrorsWithCodes): E012 = ("Cannot add pattern for zero tokens to matcher.\nKey: {key}") E016 = ("MultitaskObjective target should be function or one of: dep, " "tag, ent, dep_tag_offset, ent_tag.") - E017 = ("Can only add unicode or bytes. Got type: {value_type}") + E017 = ("Can only add 'str' inputs to StringStore. Got type: {value_type}") E018 = ("Can't retrieve string for hash '{hash_value}'. This usually " "refers to an issue with the `Vocab` or `StringStore`.") E019 = ("Can't create transition with unknown action ID: {action}. Action " @@ -955,6 +955,8 @@ class Errors(metaclass=ErrorsWithCodes): # v4 error strings E4000 = ("Expected a Doc as input, but got: '{type}'") + E4001 = ("Expected input to be one of the following types: ({expected_types}), " + "but got '{received_type}'") # Deprecated model shortcuts, only used in errors and warnings diff --git a/spacy/matcher/matcher.pyx b/spacy/matcher/matcher.pyx index 865e7594e..8bd05f25f 100644 --- a/spacy/matcher/matcher.pyx +++ b/spacy/matcher/matcher.pyx @@ -22,7 +22,7 @@ from ..attrs cimport ID, attr_id_t, NULL_ATTR, ORTH, POS, TAG, DEP, LEMMA, MORPH from ..schemas import validate_token_pattern from ..errors import Errors, MatchPatternError, Warnings -from ..strings import get_string_id +from ..strings cimport get_string_id from ..attrs import IDS diff --git a/spacy/strings.pxd b/spacy/strings.pxd index 5f03a9a28..0c1a30fe3 100644 --- a/spacy/strings.pxd +++ b/spacy/strings.pxd @@ -1,4 +1,4 @@ -from libc.stdint cimport int64_t +from libc.stdint cimport int64_t, uint32_t from libcpp.vector cimport vector from libcpp.set cimport set from cymem.cymem cimport Pool @@ -7,13 +7,6 @@ from murmurhash.mrmr cimport hash64 from .typedefs cimport attr_t, hash_t - -cpdef hash_t hash_string(str string) except 0 -cdef hash_t hash_utf8(char* utf8_string, int length) nogil - -cdef str decode_Utf8Str(const Utf8Str* string) - - ctypedef union Utf8Str: unsigned char[8] s unsigned char* p @@ -21,9 +14,13 @@ ctypedef union Utf8Str: cdef class StringStore: cdef Pool mem + cdef vector[hash_t] _keys + cdef PreshMap _map - cdef vector[hash_t] keys - cdef public PreshMap _map + cdef hash_t _intern_str(self, str string) + cdef Utf8Str* _allocate_str_repr(self, const unsigned char* chars, uint32_t length) except * + cdef str _decode_str_repr(self, const Utf8Str* string) - cdef const Utf8Str* intern_unicode(self, str py_string) - cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length, hash_t* precalculated_hash) + +cpdef hash_t hash_string(object string) except -1 +cpdef hash_t get_string_id(object string_or_hash) except -1 diff --git a/spacy/strings.pyi b/spacy/strings.pyi index b29389b9a..d9509ff57 100644 --- a/spacy/strings.pyi +++ b/spacy/strings.pyi @@ -1,21 +1,20 @@ -from typing import Optional, Iterable, Iterator, Union, Any, overload +from typing import List, Optional, Iterable, Iterator, Union, Any, Tuple, overload from pathlib import Path -def get_string_id(key: Union[str, int]) -> int: ... - class StringStore: - def __init__( - self, strings: Optional[Iterable[str]] = ..., freeze: bool = ... - ) -> None: ... + def __init__(self, strings: Optional[Iterable[str]]) -> None: ... @overload - def __getitem__(self, string_or_id: Union[bytes, str]) -> int: ... + def __getitem__(self, string_or_hash: str) -> int: ... @overload - def __getitem__(self, string_or_id: int) -> str: ... - def as_int(self, key: Union[bytes, str, int]) -> int: ... - def as_string(self, key: Union[bytes, str, int]) -> str: ... + def __getitem__(self, string_or_hash: int) -> str: ... + def as_int(self, string_or_hash: Union[str, int]) -> int: ... + def as_string(self, string_or_hash: Union[str, int]) -> str: ... def add(self, string: str) -> int: ... + def items(self) -> List[Tuple[str, int]]: ... + def keys(self) -> List[str]: ... + def values(self) -> List[int]: ... def __len__(self) -> int: ... - def __contains__(self, string: str) -> bool: ... + def __contains__(self, string_or_hash: Union[str, int]) -> bool: ... def __iter__(self) -> Iterator[str]: ... def __reduce__(self) -> Any: ... def to_disk(self, path: Union[str, Path]) -> None: ... @@ -23,3 +22,5 @@ class StringStore: def to_bytes(self, **kwargs: Any) -> bytes: ... def from_bytes(self, bytes_data: bytes, **kwargs: Any) -> StringStore: ... def _reset_and_load(self, strings: Iterable[str]) -> None: ... + +def get_string_id(string_or_hash: Union[str, int]) -> int: ... diff --git a/spacy/strings.pyx b/spacy/strings.pyx index e86682733..5a037eb9a 100644 --- a/spacy/strings.pyx +++ b/spacy/strings.pyx @@ -1,9 +1,10 @@ # cython: infer_types=True +from typing import Optional, Union, Iterable, Tuple, Callable, Any, List, Iterator cimport cython from libc.string cimport memcpy from libcpp.set cimport set from libc.stdint cimport uint32_t -from murmurhash.mrmr cimport hash64, hash32 +from murmurhash.mrmr cimport hash64 import srsly @@ -14,105 +15,13 @@ from .symbols import NAMES as SYMBOLS_BY_INT from .errors import Errors from . import util -# Not particularly elegant, but this is faster than `isinstance(key, numbers.Integral)` -cdef inline bint _try_coerce_to_hash(object key, hash_t* out_hash): - try: - out_hash[0] = key - return True - except: - return False - -def get_string_id(key): - """Get a string ID, handling the reserved symbols correctly. If the key is - already an ID, return it. - - This function optimises for convenience over performance, so shouldn't be - used in tight loops. - """ - cdef hash_t str_hash - if isinstance(key, str): - if len(key) == 0: - return 0 - - symbol = SYMBOLS_BY_STR.get(key, None) - if symbol is not None: - return symbol - else: - chars = key.encode("utf8") - return hash_utf8(chars, len(chars)) - elif _try_coerce_to_hash(key, &str_hash): - # Coerce the integral key to the expected primitive hash type. - # This ensures that custom/overloaded "primitive" data types - # such as those implemented by numpy are not inadvertently used - # downsteam (as these are internally implemented as custom PyObjects - # whose comparison operators can incur a significant overhead). - return str_hash - else: - # TODO: Raise an error instead - return key - - -cpdef hash_t hash_string(str string) except 0: - chars = string.encode("utf8") - return hash_utf8(chars, len(chars)) - - -cdef hash_t hash_utf8(char* utf8_string, int length) nogil: - return hash64(utf8_string, length, 1) - - -cdef uint32_t hash32_utf8(char* utf8_string, int length) nogil: - return hash32(utf8_string, length, 1) - - -cdef str decode_Utf8Str(const Utf8Str* string): - cdef int i, length - if string.s[0] < sizeof(string.s) and string.s[0] != 0: - return string.s[1:string.s[0]+1].decode("utf8") - elif string.p[0] < 255: - return string.p[1:string.p[0]+1].decode("utf8") - else: - i = 0 - length = 0 - while string.p[i] == 255: - i += 1 - length += 255 - length += string.p[i] - i += 1 - return string.p[i:length + i].decode("utf8") - - -cdef Utf8Str* _allocate(Pool mem, const unsigned char* chars, uint32_t length) except *: - cdef int n_length_bytes - cdef int i - cdef Utf8Str* string = mem.alloc(1, sizeof(Utf8Str)) - cdef uint32_t ulength = length - if length < sizeof(string.s): - string.s[0] = length - memcpy(&string.s[1], chars, length) - return string - elif length < 255: - string.p = mem.alloc(length + 1, sizeof(unsigned char)) - string.p[0] = length - memcpy(&string.p[1], chars, length) - return string - else: - i = 0 - n_length_bytes = (length // 255) + 1 - string.p = mem.alloc(length + n_length_bytes, sizeof(unsigned char)) - for i in range(n_length_bytes-1): - string.p[i] = 255 - string.p[n_length_bytes-1] = length % 255 - memcpy(&string.p[n_length_bytes], chars, length) - return string - cdef class StringStore: - """Look up strings by 64-bit hashes. + """Look up strings by 64-bit hashes. Implicitly handles reserved symbols. DOCS: https://spacy.io/api/stringstore """ - def __init__(self, strings=None, freeze=False): + def __init__(self, strings: Optional[Iterable[str]] = None): """Create the StringStore. strings (iterable): A sequence of unicode strings to add to the store. @@ -123,128 +32,127 @@ cdef class StringStore: for string in strings: self.add(string) - def __getitem__(self, object string_or_id): - """Retrieve a string from a given hash, or vice versa. + def __getitem__(self, string_or_hash: Union[str, int]) -> Union[str, int]: + """Retrieve a string from a given hash. If a string + is passed as the input, add it to the store and return + its hash. - string_or_id (bytes, str or uint64): The value to encode. - Returns (str / uint64): The value to be retrieved. + string_or_hash (int / str): The hash value to lookup or the string to store. + RETURNS (str / int): The stored string or the hash of the newly added string. """ - cdef hash_t str_hash - cdef Utf8Str* utf8str = NULL - - if isinstance(string_or_id, str): - if len(string_or_id) == 0: - return 0 - - # Return early if the string is found in the symbols LUT. - symbol = SYMBOLS_BY_STR.get(string_or_id, None) - if symbol is not None: - return symbol - else: - return hash_string(string_or_id) - elif isinstance(string_or_id, bytes): - return hash_utf8(string_or_id, len(string_or_id)) - elif _try_coerce_to_hash(string_or_id, &str_hash): - if str_hash == 0: - return "" - elif str_hash in SYMBOLS_BY_INT: - return SYMBOLS_BY_INT[str_hash] - else: - utf8str = self._map.get(str_hash) + if isinstance(string_or_hash, str): + return self.add(string_or_hash) else: - # TODO: Raise an error instead - utf8str = self._map.get(string_or_id) + return self._get_interned_str(string_or_hash) - if utf8str is NULL: - raise KeyError(Errors.E018.format(hash_value=string_or_id)) - else: - return decode_Utf8Str(utf8str) + def __contains__(self, string_or_hash: Union[str, int]) -> bool: + """Check whether a string or a hash is in the store. - def as_int(self, key): - """If key is an int, return it; otherwise, get the int value.""" - if not isinstance(key, str): - return key - else: - return self[key] - - def as_string(self, key): - """If key is a string, return it; otherwise, get the string value.""" - if isinstance(key, str): - return key - else: - return self[key] - - def add(self, string): - """Add a string to the StringStore. - - string (str): The string to add. - RETURNS (uint64): The string's hash value. - """ - cdef hash_t str_hash - if isinstance(string, str): - if string in SYMBOLS_BY_STR: - return SYMBOLS_BY_STR[string] - - string = string.encode("utf8") - str_hash = hash_utf8(string, len(string)) - self._intern_utf8(string, len(string), &str_hash) - elif isinstance(string, bytes): - if string in SYMBOLS_BY_STR: - return SYMBOLS_BY_STR[string] - str_hash = hash_utf8(string, len(string)) - self._intern_utf8(string, len(string), &str_hash) - else: - raise TypeError(Errors.E017.format(value_type=type(string))) - return str_hash - - def __len__(self): - """The number of strings in the store. - - RETURNS (int): The number of strings in the store. - """ - return self.keys.size() - - def __contains__(self, string_or_id not None): - """Check whether a string or ID is in the store. - - string_or_id (str or int): The string to check. + string (str / int): The string/hash to check. RETURNS (bool): Whether the store contains the string. """ - cdef hash_t str_hash - if isinstance(string_or_id, str): - if len(string_or_id) == 0: - return True - elif string_or_id in SYMBOLS_BY_STR: - return True - str_hash = hash_string(string_or_id) - elif _try_coerce_to_hash(string_or_id, &str_hash): - pass - else: - # TODO: Raise an error instead - return self._map.get(string_or_id) is not NULL - + cdef hash_t str_hash = get_string_id(string_or_hash) if str_hash in SYMBOLS_BY_INT: return True else: return self._map.get(str_hash) is not NULL - def __iter__(self): - """Iterate over the strings in the store, in order. + def __iter__(self) -> Iterator[str]: + """Iterate over the strings in the store in insertion order. - YIELDS (str): A string in the store. + RETURNS: An iterable collection of strings. """ - cdef int i - cdef hash_t key - for i in range(self.keys.size()): - key = self.keys[i] - utf8str = self._map.get(key) - yield decode_Utf8Str(utf8str) - # TODO: Iterate OOV here? + return iter(self.keys()) def __reduce__(self): strings = list(self) return (StringStore, (strings,), None, None, None) + def __len__(self) -> int: + """The number of strings in the store. + + RETURNS (int): The number of strings in the store. + """ + return self._keys.size() + + def add(self, string: str) -> int: + """Add a string to the StringStore. + + string (str): The string to add. + RETURNS (uint64): The string's hash value. + """ + if not isinstance(string, str): + raise TypeError(Errors.E017.format(value_type=type(string))) + + if string in SYMBOLS_BY_STR: + return SYMBOLS_BY_STR[string] + else: + return self._intern_str(string) + + def as_int(self, string_or_hash: Union[str, int]) -> str: + """If a hash value is passed as the input, return it as-is. If the input + is a string, return its corresponding hash. + + string_or_hash (str / int): The string to hash or a hash value. + RETURNS (int): The hash of the string or the input hash value. + """ + if isinstance(string_or_hash, int): + return string_or_hash + else: + return get_string_id(string_or_hash) + + def as_string(self, string_or_hash: Union[str, int]) -> str: + """If a string is passed as the input, return it as-is. If the input + is a hash value, return its corresponding string. + + string_or_hash (str / int): The hash value to lookup or a string. + RETURNS (str): The stored string or the input string. + """ + if isinstance(string_or_hash, str): + return string_or_hash + else: + return self._get_interned_str(string_or_hash) + + def items(self) -> List[Tuple[str, int]]: + """Iterate over the stored strings and their hashes in insertion order. + + RETURNS: A list of string-hash pairs. + """ + # Even though we internally store the hashes as keys and the strings as + # values, we invert the order in the public API to keep it consistent with + # the implementation of the `__iter__` method (where we wish to iterate over + # the strings in the store). + cdef int i + pairs = [None] * self._keys.size() + for i in range(self._keys.size()): + str_hash = self._keys[i] + utf8str = self._map.get(str_hash) + pairs[i] = (self._decode_str_repr(utf8str), str_hash) + return pairs + + def keys(self) -> List[str]: + """Iterate over the stored strings in insertion order. + + RETURNS: A list of strings. + """ + cdef int i + strings = [None] * self._keys.size() + for i in range(self._keys.size()): + utf8str = self._map.get(self._keys[i]) + strings[i] = self._decode_str_repr(utf8str) + return strings + + def values(self) -> List[int]: + """Iterate over the stored strings hashes in insertion order. + + RETURNS: A list of string hashs. + """ + cdef int i + hashes = [None] * self._keys.size() + for i in range(self._keys.size()): + hashes[i] = self._keys[i] + return hashes + def to_disk(self, path): """Save the current state to a directory. @@ -294,24 +202,122 @@ cdef class StringStore: def _reset_and_load(self, strings): self.mem = Pool() self._map = PreshMap() - self.keys.clear() + self._keys.clear() for string in strings: self.add(string) - cdef const Utf8Str* intern_unicode(self, str py_string): - # 0 means missing, but we don't bother offsetting the index. - cdef bytes byte_string = py_string.encode("utf8") - return self._intern_utf8(byte_string, len(byte_string), NULL) + def _get_interned_str(self, hash_value: int) -> str: + cdef hash_t str_hash + if not _try_coerce_to_hash(hash_value, &str_hash): + raise TypeError(Errors.E4001.format(expected_types="'int'", received_type=type(hash_value))) - @cython.final - cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length, hash_t* precalculated_hash): + # Handle reserved symbols and empty strings correctly. + if str_hash == 0: + return "" + + symbol = SYMBOLS_BY_INT.get(str_hash) + if symbol is not None: + return symbol + + utf8str = self._map.get(str_hash) + if utf8str is NULL: + raise KeyError(Errors.E018.format(hash_value=str_hash)) + else: + return self._decode_str_repr(utf8str) + + cdef hash_t _intern_str(self, str string): # TODO: This function's API/behaviour is an unholy mess... # 0 means missing, but we don't bother offsetting the index. - cdef hash_t key = precalculated_hash[0] if precalculated_hash is not NULL else hash_utf8(utf8_string, length) + chars = string.encode('utf-8') + cdef hash_t key = hash64(chars, len(chars), 1) cdef Utf8Str* value = self._map.get(key) if value is not NULL: - return value - value = _allocate(self.mem, utf8_string, length) + return key + + value = self._allocate_str_repr(chars, len(chars)) self._map.set(key, value) - self.keys.push_back(key) - return value + self._keys.push_back(key) + return key + + cdef Utf8Str* _allocate_str_repr(self, const unsigned char* chars, uint32_t length) except *: + cdef int n_length_bytes + cdef int i + cdef Utf8Str* string = self.mem.alloc(1, sizeof(Utf8Str)) + cdef uint32_t ulength = length + if length < sizeof(string.s): + string.s[0] = length + memcpy(&string.s[1], chars, length) + return string + elif length < 255: + string.p = self.mem.alloc(length + 1, sizeof(unsigned char)) + string.p[0] = length + memcpy(&string.p[1], chars, length) + return string + else: + i = 0 + n_length_bytes = (length // 255) + 1 + string.p = self.mem.alloc(length + n_length_bytes, sizeof(unsigned char)) + for i in range(n_length_bytes-1): + string.p[i] = 255 + string.p[n_length_bytes-1] = length % 255 + memcpy(&string.p[n_length_bytes], chars, length) + return string + + cdef str _decode_str_repr(self, const Utf8Str* string): + cdef int i, length + if string.s[0] < sizeof(string.s) and string.s[0] != 0: + return string.s[1:string.s[0]+1].decode('utf-8') + elif string.p[0] < 255: + return string.p[1:string.p[0]+1].decode('utf-8') + else: + i = 0 + length = 0 + while string.p[i] == 255: + i += 1 + length += 255 + length += string.p[i] + i += 1 + return string.p[i:length + i].decode('utf-8') + + +cpdef hash_t hash_string(object string) except -1: + if not isinstance(string, str): + raise TypeError(Errors.E4001.format(expected_types="'str'", received_type=type(string))) + + # Handle reserved symbols and empty strings correctly. + if len(string) == 0: + return 0 + + symbol = SYMBOLS_BY_STR.get(string) + if symbol is not None: + return symbol + + chars = string.encode('utf-8') + return hash64(chars, len(chars), 1) + + +cpdef hash_t get_string_id(object string_or_hash) except -1: + cdef hash_t str_hash + + try: + return hash_string(string_or_hash) + except: + if _try_coerce_to_hash(string_or_hash, &str_hash): + # Coerce the integral key to the expected primitive hash type. + # This ensures that custom/overloaded "primitive" data types + # such as those implemented by numpy are not inadvertently used + # downsteam (as these are internally implemented as custom PyObjects + # whose comparison operators can incur a significant overhead). + return str_hash + else: + raise TypeError(Errors.E4001.format(expected_types="'str','int'", received_type=type(string_or_hash))) + + +# Not particularly elegant, but this is faster than `isinstance(key, numbers.Integral)` +cdef inline bint _try_coerce_to_hash(object key, hash_t* out_hash): + try: + out_hash[0] = key + return True + except: + return False + diff --git a/spacy/tests/vocab_vectors/test_stringstore.py b/spacy/tests/vocab_vectors/test_stringstore.py index a0f8016af..f86c0f10d 100644 --- a/spacy/tests/vocab_vectors/test_stringstore.py +++ b/spacy/tests/vocab_vectors/test_stringstore.py @@ -24,6 +24,14 @@ def test_stringstore_from_api_docs(stringstore): stringstore.add("orange") all_strings = [s for s in stringstore] assert all_strings == ["apple", "orange"] + assert all_strings == list(stringstore.keys()) + all_strings_and_hashes = list(stringstore.items()) + assert all_strings_and_hashes == [ + ("apple", 8566208034543834098), + ("orange", 2208928596161743350), + ] + all_hashes = list(stringstore.values()) + assert all_hashes == [8566208034543834098, 2208928596161743350] banana_hash = stringstore.add("banana") assert len(stringstore) == 3 assert banana_hash == 2525716904149915114 @@ -31,12 +39,25 @@ def test_stringstore_from_api_docs(stringstore): assert stringstore["banana"] == banana_hash -@pytest.mark.parametrize("text1,text2,text3", [(b"Hello", b"goodbye", b"hello")]) -def test_stringstore_save_bytes(stringstore, text1, text2, text3): - key = stringstore.add(text1) - assert stringstore[text1] == key - assert stringstore[text2] != key - assert stringstore[text3] != key +@pytest.mark.parametrize( + "val_bytes,val_float,val_list,val_text,val_hash", + [(b"Hello", 1.1, ["abc"], "apple", 8566208034543834098)], +) +def test_stringstore_type_checking( + stringstore, val_bytes, val_float, val_list, val_text, val_hash +): + with pytest.raises(TypeError): + assert stringstore[val_bytes] + + with pytest.raises(TypeError): + stringstore.add(val_float) + + with pytest.raises(TypeError): + assert val_list not in stringstore + + key = stringstore.add(val_text) + assert val_hash == key + assert stringstore[val_hash] == val_text @pytest.mark.parametrize("text1,text2,text3", [("Hello", "goodbye", "hello")]) @@ -47,19 +68,19 @@ def test_stringstore_save_unicode(stringstore, text1, text2, text3): assert stringstore[text3] != key -@pytest.mark.parametrize("text", [b"A"]) +@pytest.mark.parametrize("text", ["A"]) def test_stringstore_retrieve_id(stringstore, text): key = stringstore.add(text) assert len(stringstore) == 1 - assert stringstore[key] == text.decode("utf8") + assert stringstore[key] == text with pytest.raises(KeyError): stringstore[20000] -@pytest.mark.parametrize("text1,text2", [(b"0123456789", b"A")]) +@pytest.mark.parametrize("text1,text2", [("0123456789", "A")]) def test_stringstore_med_string(stringstore, text1, text2): store = stringstore.add(text1) - assert stringstore[store] == text1.decode("utf8") + assert stringstore[store] == text1 stringstore.add(text2) assert stringstore[text1] == store diff --git a/spacy/tokens/graph.pyx b/spacy/tokens/graph.pyx index adc4d23c8..0ae0d94c7 100644 --- a/spacy/tokens/graph.pyx +++ b/spacy/tokens/graph.pyx @@ -12,7 +12,7 @@ from murmurhash.mrmr cimport hash64 from .. import Errors from ..typedefs cimport hash_t -from ..strings import get_string_id +from ..strings cimport get_string_id from ..structs cimport EdgeC, GraphC from .token import Token diff --git a/spacy/tokens/retokenizer.pyx b/spacy/tokens/retokenizer.pyx index 43e6d4aa7..29143bed3 100644 --- a/spacy/tokens/retokenizer.pyx +++ b/spacy/tokens/retokenizer.pyx @@ -18,7 +18,7 @@ from .underscore import is_writable_attr from ..attrs import intify_attrs from ..util import SimpleFrozenDict from ..errors import Errors -from ..strings import get_string_id +from ..strings cimport get_string_id cdef class Retokenizer: diff --git a/website/docs/api/stringstore.md b/website/docs/api/stringstore.md index cd414b1f0..b509659ef 100644 --- a/website/docs/api/stringstore.md +++ b/website/docs/api/stringstore.md @@ -40,7 +40,8 @@ Get the number of strings in the store. ## StringStore.\_\_getitem\_\_ {#getitem tag="method"} -Retrieve a string from a given hash, or vice versa. +Retrieve a string from a given hash. If a string is passed as the input, add it +to the store and return its hash. > #### Example > @@ -51,14 +52,14 @@ Retrieve a string from a given hash, or vice versa. > assert stringstore[apple_hash] == "apple" > ``` -| Name | Description | -| -------------- | ----------------------------------------------- | -| `string_or_id` | The value to encode. ~~Union[bytes, str, int]~~ | -| **RETURNS** | The value to be retrieved. ~~Union[str, int]~~ | +| Name | Description | +| ---------------- | ---------------------------------------------------------------------------- | +| `string_or_hash` | The hash value to lookup or the string to store. ~~Union[str, int]~~ | +| **RETURNS** | The stored string or the hash of the newly added string. ~~Union[str, int]~~ | ## StringStore.\_\_contains\_\_ {#contains tag="method"} -Check whether a string is in the store. +Check whether a string or a hash is in the store. > #### Example > @@ -68,15 +69,14 @@ Check whether a string is in the store. > assert not "cherry" in stringstore > ``` -| Name | Description | -| ----------- | ----------------------------------------------- | -| `string` | The string to check. ~~str~~ | -| **RETURNS** | Whether the store contains the string. ~~bool~~ | +| Name | Description | +| ---------------- | ------------------------------------------------------- | +| `string_or_hash` | The string or hash to check. ~~Union[str, int]~~ | +| **RETURNS** | Whether the store contains the string or hash. ~~bool~~ | ## StringStore.\_\_iter\_\_ {#iter tag="method"} -Iterate over the strings in the store, in order. Note that a newly initialized -store will always include an empty string `""` at position `0`. +Iterate over the stored strings in insertion order. > #### Example > @@ -86,11 +86,59 @@ store will always include an empty string `""` at position `0`. > assert all_strings == ["apple", "orange"] > ``` -| Name | Description | -| ---------- | ------------------------------ | -| **YIELDS** | A string in the store. ~~str~~ | +| Name | Description | +| ----------- | ------------------------------ | +| **RETURNS** | A string in the store. ~~str~~ | -## StringStore.add {#add tag="method" new="2"} +## StringStore.items {#iter tag="method" new="4"} + +Iterate over the stored string-hash pairs in insertion order. + +> #### Example +> +> ```python +> stringstore = StringStore(["apple", "orange"]) +> all_strings_and_hashes = stringstore.items() +> assert all_strings_and_hashes == [("apple", 8566208034543834098), ("orange", 2208928596161743350)] +> ``` + +| Name | Description | +| ----------- | ------------------------------------------------------ | +| **RETURNS** | A list of string-hash pairs. ~~List[Tuple[str, int]]~~ | + +## StringStore.keys {#iter tag="method" new="4"} + +Iterate over the stored strings in insertion order. + +> #### Example +> +> ```python +> stringstore = StringStore(["apple", "orange"]) +> all_strings = stringstore.keys() +> assert all_strings == ["apple", "orange"] +> ``` + +| Name | Description | +| ----------- | -------------------------------- | +| **RETURNS** | A list of strings. ~~List[str]~~ | + +## StringStore.values {#iter tag="method" new="4"} + +Iterate over the stored string hashes in insertion order. + +> #### Example +> +> ```python +> stringstore = StringStore(["apple", "orange"]) +> all_hashes = stringstore.values() +> assert all_hashes == [8566208034543834098, 2208928596161743350] +> ``` + +| Name | Description | +| ----------- | -------------------------------------- | +| **RETURNS** | A list of string hashes. ~~List[int]~~ | + +## StringStore.add {#add tag="method"} Add a string to the `StringStore`. @@ -110,7 +158,7 @@ Add a string to the `StringStore`. | `string` | The string to add. ~~str~~ | | **RETURNS** | The string's hash value. ~~int~~ | -## StringStore.to_disk {#to_disk tag="method" new="2"} +## StringStore.to_disk {#to_disk tag="method"} Save the current state to a directory.