mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-01 00:17:44 +03:00 
			
		
		
		
	StringStore refactoring (#11344)
				
					
				
			* `strings`: Remove unused `hash32_utf8` function
* `strings`: Make `hash_utf8` and `decode_Utf8Str` private
* `strings`: Reorganize private functions
* 'strings': Raise error when non-string/-int types are passed to functions that don't accept them
* `strings`: Add `items()` method, add type hints, remove unused methods, restrict inputs to specific types, reorganize methods
* `Morphology`: Use `StringStore.items()` to enumerate features when pickling
* `test_stringstore`: Update pre-Python 3 tests
* Update `StringStore` docs
* Fix `get_string_id` imports
* Replace redundant test with tests for type checking
* Rename `_retrieve_interned_str`, remove `.get` default arg
* Add `get_string_id` to `strings.pyi`
Remove `mypy` ignore directives from imports of the above
* `strings.pyi`: Replace functions that consume `Union`-typed params with overloads
* `strings.pyi`: Revert some function signatures
* Update `SYMBOLS_BY_INT` lookups and error codes post-merge
* Revert clobbered change introduced in a previous merge
* Remove unnecessary type hint
* Invert tuple order in `StringStore.items()`
* Add test for `StringStore.items()`
* Revert "`Morphology`: Use `StringStore.items()` to enumerate features when pickling"
This reverts commit 1af9510ceb.
* Rename `keys` and `key_map`
* Add `keys()` and `values()`
* Add comment about the inverted key-value semantics in the API
* Fix type hints
* Implement `keys()`, `values()`, `items()` without generators
* Fix type hints, remove unnecessary boxing
* Update docs
* Simplify `keys/values/items()` impl
* `mypy` fix
* Fix error message, doc fixes
			
			
This commit is contained in:
		
							parent
							
								
									c6704f368c
								
							
						
					
					
						commit
						446a3ecf34
					
				|  | @ -252,7 +252,7 @@ class Errors(metaclass=ErrorsWithCodes): | ||||||
|     E012 = ("Cannot add pattern for zero tokens to matcher.\nKey: {key}") |     E012 = ("Cannot add pattern for zero tokens to matcher.\nKey: {key}") | ||||||
|     E016 = ("MultitaskObjective target should be function or one of: dep, " |     E016 = ("MultitaskObjective target should be function or one of: dep, " | ||||||
|             "tag, ent, dep_tag_offset, ent_tag.") |             "tag, ent, dep_tag_offset, ent_tag.") | ||||||
|     E017 = ("Can only add unicode or bytes. Got type: {value_type}") |     E017 = ("Can only add 'str' inputs to StringStore. Got type: {value_type}") | ||||||
|     E018 = ("Can't retrieve string for hash '{hash_value}'. This usually " |     E018 = ("Can't retrieve string for hash '{hash_value}'. This usually " | ||||||
|             "refers to an issue with the `Vocab` or `StringStore`.") |             "refers to an issue with the `Vocab` or `StringStore`.") | ||||||
|     E019 = ("Can't create transition with unknown action ID: {action}. Action " |     E019 = ("Can't create transition with unknown action ID: {action}. Action " | ||||||
|  | @ -955,6 +955,8 @@ class Errors(metaclass=ErrorsWithCodes): | ||||||
| 
 | 
 | ||||||
|     # v4 error strings |     # v4 error strings | ||||||
|     E4000 = ("Expected a Doc as input, but got: '{type}'") |     E4000 = ("Expected a Doc as input, but got: '{type}'") | ||||||
|  |     E4001 = ("Expected input to be one of the following types: ({expected_types}), " | ||||||
|  |              "but got '{received_type}'") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| # Deprecated model shortcuts, only used in errors and warnings | # Deprecated model shortcuts, only used in errors and warnings | ||||||
|  |  | ||||||
|  | @ -22,7 +22,7 @@ from ..attrs cimport ID, attr_id_t, NULL_ATTR, ORTH, POS, TAG, DEP, LEMMA, MORPH | ||||||
| 
 | 
 | ||||||
| from ..schemas import validate_token_pattern | from ..schemas import validate_token_pattern | ||||||
| from ..errors import Errors, MatchPatternError, Warnings | from ..errors import Errors, MatchPatternError, Warnings | ||||||
| from ..strings import get_string_id | from ..strings cimport get_string_id | ||||||
| from ..attrs import IDS | from ..attrs import IDS | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -1,4 +1,4 @@ | ||||||
| from libc.stdint cimport int64_t | from libc.stdint cimport int64_t, uint32_t | ||||||
| from libcpp.vector cimport vector | from libcpp.vector cimport vector | ||||||
| from libcpp.set cimport set | from libcpp.set cimport set | ||||||
| from cymem.cymem cimport Pool | from cymem.cymem cimport Pool | ||||||
|  | @ -7,13 +7,6 @@ from murmurhash.mrmr cimport hash64 | ||||||
| 
 | 
 | ||||||
| from .typedefs cimport attr_t, hash_t | from .typedefs cimport attr_t, hash_t | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| cpdef hash_t hash_string(str string) except 0 |  | ||||||
| cdef hash_t hash_utf8(char* utf8_string, int length) nogil |  | ||||||
| 
 |  | ||||||
| cdef str decode_Utf8Str(const Utf8Str* string) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| ctypedef union Utf8Str: | ctypedef union Utf8Str: | ||||||
|     unsigned char[8] s |     unsigned char[8] s | ||||||
|     unsigned char* p |     unsigned char* p | ||||||
|  | @ -21,9 +14,13 @@ ctypedef union Utf8Str: | ||||||
| 
 | 
 | ||||||
| cdef class StringStore: | cdef class StringStore: | ||||||
|     cdef Pool mem |     cdef Pool mem | ||||||
|  |     cdef vector[hash_t] _keys | ||||||
|  |     cdef PreshMap _map | ||||||
| 
 | 
 | ||||||
|     cdef vector[hash_t] keys |     cdef hash_t _intern_str(self, str string) | ||||||
|     cdef public PreshMap _map |     cdef Utf8Str* _allocate_str_repr(self, const unsigned char* chars, uint32_t length) except * | ||||||
|  |     cdef str _decode_str_repr(self, const Utf8Str* string) | ||||||
| 
 | 
 | ||||||
|     cdef const Utf8Str* intern_unicode(self, str py_string) | 
 | ||||||
|     cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length, hash_t* precalculated_hash) | cpdef hash_t hash_string(object string) except -1 | ||||||
|  | cpdef hash_t get_string_id(object string_or_hash) except -1 | ||||||
|  |  | ||||||
|  | @ -1,21 +1,20 @@ | ||||||
| from typing import Optional, Iterable, Iterator, Union, Any, overload | from typing import List, Optional, Iterable, Iterator, Union, Any, Tuple, overload | ||||||
| from pathlib import Path | from pathlib import Path | ||||||
| 
 | 
 | ||||||
| def get_string_id(key: Union[str, int]) -> int: ... |  | ||||||
| 
 |  | ||||||
| class StringStore: | class StringStore: | ||||||
|     def __init__( |     def __init__(self, strings: Optional[Iterable[str]]) -> None: ... | ||||||
|         self, strings: Optional[Iterable[str]] = ..., freeze: bool = ... |  | ||||||
|     ) -> None: ... |  | ||||||
|     @overload |     @overload | ||||||
|     def __getitem__(self, string_or_id: Union[bytes, str]) -> int: ... |     def __getitem__(self, string_or_hash: str) -> int: ... | ||||||
|     @overload |     @overload | ||||||
|     def __getitem__(self, string_or_id: int) -> str: ... |     def __getitem__(self, string_or_hash: int) -> str: ... | ||||||
|     def as_int(self, key: Union[bytes, str, int]) -> int: ... |     def as_int(self, string_or_hash: Union[str, int]) -> int: ... | ||||||
|     def as_string(self, key: Union[bytes, str, int]) -> str: ... |     def as_string(self, string_or_hash: Union[str, int]) -> str: ... | ||||||
|     def add(self, string: str) -> int: ... |     def add(self, string: str) -> int: ... | ||||||
|  |     def items(self) -> List[Tuple[str, int]]: ... | ||||||
|  |     def keys(self) -> List[str]: ... | ||||||
|  |     def values(self) -> List[int]: ... | ||||||
|     def __len__(self) -> int: ... |     def __len__(self) -> int: ... | ||||||
|     def __contains__(self, string: str) -> bool: ... |     def __contains__(self, string_or_hash: Union[str, int]) -> bool: ... | ||||||
|     def __iter__(self) -> Iterator[str]: ... |     def __iter__(self) -> Iterator[str]: ... | ||||||
|     def __reduce__(self) -> Any: ... |     def __reduce__(self) -> Any: ... | ||||||
|     def to_disk(self, path: Union[str, Path]) -> None: ... |     def to_disk(self, path: Union[str, Path]) -> None: ... | ||||||
|  | @ -23,3 +22,5 @@ class StringStore: | ||||||
|     def to_bytes(self, **kwargs: Any) -> bytes: ... |     def to_bytes(self, **kwargs: Any) -> bytes: ... | ||||||
|     def from_bytes(self, bytes_data: bytes, **kwargs: Any) -> StringStore: ... |     def from_bytes(self, bytes_data: bytes, **kwargs: Any) -> StringStore: ... | ||||||
|     def _reset_and_load(self, strings: Iterable[str]) -> None: ... |     def _reset_and_load(self, strings: Iterable[str]) -> None: ... | ||||||
|  | 
 | ||||||
|  | def get_string_id(string_or_hash: Union[str, int]) -> int: ... | ||||||
|  |  | ||||||
|  | @ -1,9 +1,10 @@ | ||||||
| # cython: infer_types=True | # cython: infer_types=True | ||||||
|  | from typing import Optional, Union, Iterable, Tuple, Callable, Any, List, Iterator | ||||||
| cimport cython | cimport cython | ||||||
| from libc.string cimport memcpy | from libc.string cimport memcpy | ||||||
| from libcpp.set cimport set | from libcpp.set cimport set | ||||||
| from libc.stdint cimport uint32_t | from libc.stdint cimport uint32_t | ||||||
| from murmurhash.mrmr cimport hash64, hash32 | from murmurhash.mrmr cimport hash64 | ||||||
| 
 | 
 | ||||||
| import srsly | import srsly | ||||||
| 
 | 
 | ||||||
|  | @ -14,105 +15,13 @@ from .symbols import NAMES as SYMBOLS_BY_INT | ||||||
| from .errors import Errors | from .errors import Errors | ||||||
| from . import util | from . import util | ||||||
| 
 | 
 | ||||||
| # Not particularly elegant, but this is faster than `isinstance(key, numbers.Integral)` |  | ||||||
| cdef inline bint _try_coerce_to_hash(object key, hash_t* out_hash): |  | ||||||
|     try: |  | ||||||
|         out_hash[0] = key |  | ||||||
|         return True |  | ||||||
|     except: |  | ||||||
|         return False |  | ||||||
| 
 |  | ||||||
| def get_string_id(key): |  | ||||||
|     """Get a string ID, handling the reserved symbols correctly. If the key is |  | ||||||
|     already an ID, return it. |  | ||||||
| 
 |  | ||||||
|     This function optimises for convenience over performance, so shouldn't be |  | ||||||
|     used in tight loops. |  | ||||||
|     """ |  | ||||||
|     cdef hash_t str_hash     |  | ||||||
|     if isinstance(key, str): |  | ||||||
|         if len(key) == 0: |  | ||||||
|             return 0 |  | ||||||
| 
 |  | ||||||
|         symbol = SYMBOLS_BY_STR.get(key, None) |  | ||||||
|         if symbol is not None: |  | ||||||
|             return symbol |  | ||||||
|         else: |  | ||||||
|             chars = key.encode("utf8") |  | ||||||
|             return hash_utf8(chars, len(chars)) |  | ||||||
|     elif _try_coerce_to_hash(key, &str_hash): |  | ||||||
|         # Coerce the integral key to the expected primitive hash type. |  | ||||||
|         # This ensures that custom/overloaded "primitive" data types |  | ||||||
|         # such as those implemented by numpy are not inadvertently used  |  | ||||||
|         # downsteam (as these are internally implemented as custom PyObjects  |  | ||||||
|         # whose comparison operators can incur a significant overhead). |  | ||||||
|         return str_hash |  | ||||||
|     else: |  | ||||||
|         # TODO: Raise an error instead |  | ||||||
|         return key |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| cpdef hash_t hash_string(str string) except 0: |  | ||||||
|     chars = string.encode("utf8") |  | ||||||
|     return hash_utf8(chars, len(chars)) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| cdef hash_t hash_utf8(char* utf8_string, int length) nogil: |  | ||||||
|     return hash64(utf8_string, length, 1) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| cdef uint32_t hash32_utf8(char* utf8_string, int length) nogil: |  | ||||||
|     return hash32(utf8_string, length, 1) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| cdef str decode_Utf8Str(const Utf8Str* string): |  | ||||||
|     cdef int i, length |  | ||||||
|     if string.s[0] < sizeof(string.s) and string.s[0] != 0: |  | ||||||
|         return string.s[1:string.s[0]+1].decode("utf8") |  | ||||||
|     elif string.p[0] < 255: |  | ||||||
|         return string.p[1:string.p[0]+1].decode("utf8") |  | ||||||
|     else: |  | ||||||
|         i = 0 |  | ||||||
|         length = 0 |  | ||||||
|         while string.p[i] == 255: |  | ||||||
|             i += 1 |  | ||||||
|             length += 255 |  | ||||||
|         length += string.p[i] |  | ||||||
|         i += 1 |  | ||||||
|         return string.p[i:length + i].decode("utf8") |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| cdef Utf8Str* _allocate(Pool mem, const unsigned char* chars, uint32_t length) except *: |  | ||||||
|     cdef int n_length_bytes |  | ||||||
|     cdef int i |  | ||||||
|     cdef Utf8Str* string = <Utf8Str*>mem.alloc(1, sizeof(Utf8Str)) |  | ||||||
|     cdef uint32_t ulength = length |  | ||||||
|     if length < sizeof(string.s): |  | ||||||
|         string.s[0] = <unsigned char>length |  | ||||||
|         memcpy(&string.s[1], chars, length) |  | ||||||
|         return string |  | ||||||
|     elif length < 255: |  | ||||||
|         string.p = <unsigned char*>mem.alloc(length + 1, sizeof(unsigned char)) |  | ||||||
|         string.p[0] = length |  | ||||||
|         memcpy(&string.p[1], chars, length) |  | ||||||
|         return string |  | ||||||
|     else: |  | ||||||
|         i = 0 |  | ||||||
|         n_length_bytes = (length // 255) + 1 |  | ||||||
|         string.p = <unsigned char*>mem.alloc(length + n_length_bytes, sizeof(unsigned char)) |  | ||||||
|         for i in range(n_length_bytes-1): |  | ||||||
|             string.p[i] = 255 |  | ||||||
|         string.p[n_length_bytes-1] = length % 255 |  | ||||||
|         memcpy(&string.p[n_length_bytes], chars, length) |  | ||||||
|         return string |  | ||||||
| 
 |  | ||||||
| 
 | 
 | ||||||
| cdef class StringStore: | cdef class StringStore: | ||||||
|     """Look up strings by 64-bit hashes. |     """Look up strings by 64-bit hashes. Implicitly handles reserved symbols. | ||||||
| 
 | 
 | ||||||
|     DOCS: https://spacy.io/api/stringstore |     DOCS: https://spacy.io/api/stringstore | ||||||
|     """ |     """ | ||||||
|     def __init__(self, strings=None, freeze=False): |     def __init__(self, strings: Optional[Iterable[str]] = None): | ||||||
|         """Create the StringStore. |         """Create the StringStore. | ||||||
| 
 | 
 | ||||||
|         strings (iterable): A sequence of unicode strings to add to the store. |         strings (iterable): A sequence of unicode strings to add to the store. | ||||||
|  | @ -123,128 +32,127 @@ cdef class StringStore: | ||||||
|             for string in strings: |             for string in strings: | ||||||
|                 self.add(string) |                 self.add(string) | ||||||
| 
 | 
 | ||||||
|     def __getitem__(self, object string_or_id): |     def __getitem__(self, string_or_hash: Union[str, int]) -> Union[str, int]: | ||||||
|         """Retrieve a string from a given hash, or vice versa. |         """Retrieve a string from a given hash. If a string | ||||||
|  |         is passed as the input, add it to the store and return | ||||||
|  |         its hash. | ||||||
| 
 | 
 | ||||||
|         string_or_id (bytes, str or uint64): The value to encode. |         string_or_hash (int / str): The hash value to lookup or the string to store. | ||||||
|         Returns (str / uint64): The value to be retrieved. |         RETURNS (str / int): The stored string or the hash of the newly added string. | ||||||
|         """ |         """ | ||||||
|         cdef hash_t str_hash |         if isinstance(string_or_hash, str): | ||||||
|         cdef Utf8Str* utf8str = NULL |             return self.add(string_or_hash) | ||||||
| 
 |  | ||||||
|         if isinstance(string_or_id, str): |  | ||||||
|             if len(string_or_id) == 0: |  | ||||||
|                 return 0 |  | ||||||
| 
 |  | ||||||
|             # Return early if the string is found in the symbols LUT. |  | ||||||
|             symbol = SYMBOLS_BY_STR.get(string_or_id, None) |  | ||||||
|             if symbol is not None: |  | ||||||
|                 return symbol |  | ||||||
|             else: |  | ||||||
|                 return hash_string(string_or_id) |  | ||||||
|         elif isinstance(string_or_id, bytes): |  | ||||||
|             return hash_utf8(string_or_id, len(string_or_id)) |  | ||||||
|         elif _try_coerce_to_hash(string_or_id, &str_hash): |  | ||||||
|             if str_hash == 0: |  | ||||||
|                 return "" |  | ||||||
|             elif str_hash in SYMBOLS_BY_INT: |  | ||||||
|                 return SYMBOLS_BY_INT[str_hash] |  | ||||||
|             else: |  | ||||||
|                 utf8str = <Utf8Str*>self._map.get(str_hash) |  | ||||||
|         else: |         else: | ||||||
|             # TODO: Raise an error instead |             return self._get_interned_str(string_or_hash) | ||||||
|             utf8str = <Utf8Str*>self._map.get(string_or_id) |  | ||||||
| 
 | 
 | ||||||
|         if utf8str is NULL: |     def __contains__(self, string_or_hash: Union[str, int]) -> bool: | ||||||
|             raise KeyError(Errors.E018.format(hash_value=string_or_id)) |         """Check whether a string or a hash is in the store. | ||||||
|         else: |  | ||||||
|             return decode_Utf8Str(utf8str) |  | ||||||
| 
 | 
 | ||||||
|     def as_int(self, key): |         string (str / int): The string/hash to check. | ||||||
|         """If key is an int, return it; otherwise, get the int value.""" |  | ||||||
|         if not isinstance(key, str): |  | ||||||
|             return key |  | ||||||
|         else: |  | ||||||
|             return self[key] |  | ||||||
| 
 |  | ||||||
|     def as_string(self, key): |  | ||||||
|         """If key is a string, return it; otherwise, get the string value.""" |  | ||||||
|         if isinstance(key, str): |  | ||||||
|             return key |  | ||||||
|         else: |  | ||||||
|             return self[key] |  | ||||||
| 
 |  | ||||||
|     def add(self, string): |  | ||||||
|         """Add a string to the StringStore. |  | ||||||
| 
 |  | ||||||
|         string (str): The string to add. |  | ||||||
|         RETURNS (uint64): The string's hash value. |  | ||||||
|         """ |  | ||||||
|         cdef hash_t str_hash |  | ||||||
|         if isinstance(string, str): |  | ||||||
|             if string in SYMBOLS_BY_STR: |  | ||||||
|                 return SYMBOLS_BY_STR[string] |  | ||||||
| 
 |  | ||||||
|             string = string.encode("utf8") |  | ||||||
|             str_hash = hash_utf8(string, len(string)) |  | ||||||
|             self._intern_utf8(string, len(string), &str_hash) |  | ||||||
|         elif isinstance(string, bytes): |  | ||||||
|             if string in SYMBOLS_BY_STR: |  | ||||||
|                 return SYMBOLS_BY_STR[string] |  | ||||||
|             str_hash = hash_utf8(string, len(string)) |  | ||||||
|             self._intern_utf8(string, len(string), &str_hash) |  | ||||||
|         else: |  | ||||||
|             raise TypeError(Errors.E017.format(value_type=type(string))) |  | ||||||
|         return str_hash |  | ||||||
| 
 |  | ||||||
|     def __len__(self): |  | ||||||
|         """The number of strings in the store. |  | ||||||
| 
 |  | ||||||
|         RETURNS (int): The number of strings in the store. |  | ||||||
|         """ |  | ||||||
|         return self.keys.size() |  | ||||||
| 
 |  | ||||||
|     def __contains__(self, string_or_id not None): |  | ||||||
|         """Check whether a string or ID is in the store. |  | ||||||
| 
 |  | ||||||
|         string_or_id (str or int): The string to check. |  | ||||||
|         RETURNS (bool): Whether the store contains the string. |         RETURNS (bool): Whether the store contains the string. | ||||||
|         """ |         """ | ||||||
|         cdef hash_t str_hash |         cdef hash_t str_hash = get_string_id(string_or_hash) | ||||||
|         if isinstance(string_or_id, str): |  | ||||||
|             if len(string_or_id) == 0: |  | ||||||
|                 return True |  | ||||||
|             elif string_or_id in SYMBOLS_BY_STR: |  | ||||||
|                 return True |  | ||||||
|             str_hash = hash_string(string_or_id) |  | ||||||
|         elif _try_coerce_to_hash(string_or_id, &str_hash): |  | ||||||
|             pass |  | ||||||
|         else: |  | ||||||
|             # TODO: Raise an error instead |  | ||||||
|             return self._map.get(string_or_id) is not NULL |  | ||||||
| 
 |  | ||||||
|         if str_hash in SYMBOLS_BY_INT: |         if str_hash in SYMBOLS_BY_INT: | ||||||
|             return True |             return True | ||||||
|         else: |         else: | ||||||
|             return self._map.get(str_hash) is not NULL |             return self._map.get(str_hash) is not NULL | ||||||
| 
 | 
 | ||||||
|     def __iter__(self): |     def __iter__(self) -> Iterator[str]: | ||||||
|         """Iterate over the strings in the store, in order. |         """Iterate over the strings in the store in insertion order. | ||||||
| 
 | 
 | ||||||
|         YIELDS (str): A string in the store. |         RETURNS: An iterable collection of strings. | ||||||
|         """ |         """ | ||||||
|         cdef int i |         return iter(self.keys()) | ||||||
|         cdef hash_t key |  | ||||||
|         for i in range(self.keys.size()): |  | ||||||
|             key = self.keys[i] |  | ||||||
|             utf8str = <Utf8Str*>self._map.get(key) |  | ||||||
|             yield decode_Utf8Str(utf8str) |  | ||||||
|         # TODO: Iterate OOV here? |  | ||||||
| 
 | 
 | ||||||
|     def __reduce__(self): |     def __reduce__(self): | ||||||
|         strings = list(self) |         strings = list(self) | ||||||
|         return (StringStore, (strings,), None, None, None) |         return (StringStore, (strings,), None, None, None) | ||||||
| 
 | 
 | ||||||
|  |     def __len__(self) -> int: | ||||||
|  |         """The number of strings in the store. | ||||||
|  | 
 | ||||||
|  |         RETURNS (int): The number of strings in the store. | ||||||
|  |         """ | ||||||
|  |         return self._keys.size() | ||||||
|  | 
 | ||||||
|  |     def add(self, string: str) -> int: | ||||||
|  |         """Add a string to the StringStore. | ||||||
|  | 
 | ||||||
|  |         string (str): The string to add. | ||||||
|  |         RETURNS (uint64): The string's hash value. | ||||||
|  |         """ | ||||||
|  |         if not isinstance(string, str): | ||||||
|  |             raise TypeError(Errors.E017.format(value_type=type(string))) | ||||||
|  | 
 | ||||||
|  |         if string in SYMBOLS_BY_STR: | ||||||
|  |             return SYMBOLS_BY_STR[string] | ||||||
|  |         else: | ||||||
|  |             return self._intern_str(string) | ||||||
|  | 
 | ||||||
|  |     def as_int(self, string_or_hash: Union[str, int]) -> str: | ||||||
|  |         """If a hash value is passed as the input, return it as-is. If the input | ||||||
|  |         is a string, return its corresponding hash. | ||||||
|  | 
 | ||||||
|  |         string_or_hash (str / int): The string to hash or a hash value. | ||||||
|  |         RETURNS (int): The hash of the string or the input hash value. | ||||||
|  |         """ | ||||||
|  |         if isinstance(string_or_hash, int): | ||||||
|  |             return string_or_hash | ||||||
|  |         else: | ||||||
|  |             return get_string_id(string_or_hash) | ||||||
|  | 
 | ||||||
|  |     def as_string(self, string_or_hash: Union[str, int]) -> str: | ||||||
|  |         """If a string is passed as the input, return it as-is. If the input | ||||||
|  |         is a hash value, return its corresponding string. | ||||||
|  | 
 | ||||||
|  |         string_or_hash (str / int): The hash value to lookup or a string. | ||||||
|  |         RETURNS (str): The stored string or the input string. | ||||||
|  |         """ | ||||||
|  |         if isinstance(string_or_hash, str): | ||||||
|  |             return string_or_hash | ||||||
|  |         else: | ||||||
|  |             return self._get_interned_str(string_or_hash) | ||||||
|  | 
 | ||||||
|  |     def items(self) -> List[Tuple[str, int]]: | ||||||
|  |         """Iterate over the stored strings and their hashes in insertion order. | ||||||
|  | 
 | ||||||
|  |         RETURNS: A list of string-hash pairs. | ||||||
|  |         """ | ||||||
|  |         # Even though we internally store the hashes as keys and the strings as | ||||||
|  |         # values, we invert the order in the public API to keep it consistent with | ||||||
|  |         # the implementation of the `__iter__` method (where we wish to iterate over | ||||||
|  |         # the strings in the store). | ||||||
|  |         cdef int i | ||||||
|  |         pairs = [None] * self._keys.size() | ||||||
|  |         for i in range(self._keys.size()): | ||||||
|  |             str_hash = self._keys[i] | ||||||
|  |             utf8str = <Utf8Str*>self._map.get(str_hash) | ||||||
|  |             pairs[i] = (self._decode_str_repr(utf8str), str_hash) | ||||||
|  |         return pairs | ||||||
|  | 
 | ||||||
|  |     def keys(self) -> List[str]: | ||||||
|  |         """Iterate over the stored strings in insertion order. | ||||||
|  | 
 | ||||||
|  |         RETURNS: A list of strings. | ||||||
|  |         """ | ||||||
|  |         cdef int i | ||||||
|  |         strings = [None] * self._keys.size() | ||||||
|  |         for i in range(self._keys.size()): | ||||||
|  |             utf8str = <Utf8Str*>self._map.get(self._keys[i]) | ||||||
|  |             strings[i] = self._decode_str_repr(utf8str) | ||||||
|  |         return strings | ||||||
|  | 
 | ||||||
|  |     def values(self) -> List[int]: | ||||||
|  |         """Iterate over the stored strings hashes in insertion order. | ||||||
|  | 
 | ||||||
|  |         RETURNS: A list of string hashs. | ||||||
|  |         """ | ||||||
|  |         cdef int i | ||||||
|  |         hashes = [None] * self._keys.size() | ||||||
|  |         for i in range(self._keys.size()): | ||||||
|  |             hashes[i] = self._keys[i] | ||||||
|  |         return hashes | ||||||
|  | 
 | ||||||
|     def to_disk(self, path): |     def to_disk(self, path): | ||||||
|         """Save the current state to a directory. |         """Save the current state to a directory. | ||||||
| 
 | 
 | ||||||
|  | @ -294,24 +202,122 @@ cdef class StringStore: | ||||||
|     def _reset_and_load(self, strings): |     def _reset_and_load(self, strings): | ||||||
|         self.mem = Pool() |         self.mem = Pool() | ||||||
|         self._map = PreshMap() |         self._map = PreshMap() | ||||||
|         self.keys.clear() |         self._keys.clear() | ||||||
|         for string in strings: |         for string in strings: | ||||||
|             self.add(string) |             self.add(string) | ||||||
| 
 | 
 | ||||||
|     cdef const Utf8Str* intern_unicode(self, str py_string): |     def _get_interned_str(self, hash_value: int) -> str: | ||||||
|         # 0 means missing, but we don't bother offsetting the index. |         cdef hash_t str_hash | ||||||
|         cdef bytes byte_string = py_string.encode("utf8") |         if not _try_coerce_to_hash(hash_value, &str_hash): | ||||||
|         return self._intern_utf8(byte_string, len(byte_string), NULL) |             raise TypeError(Errors.E4001.format(expected_types="'int'", received_type=type(hash_value))) | ||||||
| 
 | 
 | ||||||
|     @cython.final |         # Handle reserved symbols and empty strings correctly. | ||||||
|     cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length, hash_t* precalculated_hash): |         if str_hash == 0: | ||||||
|  |             return "" | ||||||
|  | 
 | ||||||
|  |         symbol = SYMBOLS_BY_INT.get(str_hash) | ||||||
|  |         if symbol is not None: | ||||||
|  |             return symbol | ||||||
|  | 
 | ||||||
|  |         utf8str = <Utf8Str*>self._map.get(str_hash) | ||||||
|  |         if utf8str is NULL: | ||||||
|  |             raise KeyError(Errors.E018.format(hash_value=str_hash)) | ||||||
|  |         else: | ||||||
|  |             return self._decode_str_repr(utf8str) | ||||||
|  | 
 | ||||||
|  |     cdef hash_t _intern_str(self, str string): | ||||||
|         # TODO: This function's API/behaviour is an unholy mess... |         # TODO: This function's API/behaviour is an unholy mess... | ||||||
|         # 0 means missing, but we don't bother offsetting the index. |         # 0 means missing, but we don't bother offsetting the index. | ||||||
|         cdef hash_t key = precalculated_hash[0] if precalculated_hash is not NULL else hash_utf8(utf8_string, length) |         chars = string.encode('utf-8') | ||||||
|  |         cdef hash_t key = hash64(<unsigned char*>chars, len(chars), 1) | ||||||
|         cdef Utf8Str* value = <Utf8Str*>self._map.get(key) |         cdef Utf8Str* value = <Utf8Str*>self._map.get(key) | ||||||
|         if value is not NULL: |         if value is not NULL: | ||||||
|             return value |             return key | ||||||
|         value = _allocate(self.mem, <unsigned char*>utf8_string, length) | 
 | ||||||
|  |         value = self._allocate_str_repr(<unsigned char*>chars, len(chars)) | ||||||
|         self._map.set(key, value) |         self._map.set(key, value) | ||||||
|         self.keys.push_back(key) |         self._keys.push_back(key) | ||||||
|         return value |         return key | ||||||
|  | 
 | ||||||
|  |     cdef Utf8Str* _allocate_str_repr(self, const unsigned char* chars, uint32_t length) except *: | ||||||
|  |         cdef int n_length_bytes | ||||||
|  |         cdef int i | ||||||
|  |         cdef Utf8Str* string = <Utf8Str*>self.mem.alloc(1, sizeof(Utf8Str)) | ||||||
|  |         cdef uint32_t ulength = length | ||||||
|  |         if length < sizeof(string.s): | ||||||
|  |             string.s[0] = <unsigned char>length | ||||||
|  |             memcpy(&string.s[1], chars, length) | ||||||
|  |             return string | ||||||
|  |         elif length < 255: | ||||||
|  |             string.p = <unsigned char*>self.mem.alloc(length + 1, sizeof(unsigned char)) | ||||||
|  |             string.p[0] = length | ||||||
|  |             memcpy(&string.p[1], chars, length) | ||||||
|  |             return string | ||||||
|  |         else: | ||||||
|  |             i = 0 | ||||||
|  |             n_length_bytes = (length // 255) + 1 | ||||||
|  |             string.p = <unsigned char*>self.mem.alloc(length + n_length_bytes, sizeof(unsigned char)) | ||||||
|  |             for i in range(n_length_bytes-1): | ||||||
|  |                 string.p[i] = 255 | ||||||
|  |             string.p[n_length_bytes-1] = length % 255 | ||||||
|  |             memcpy(&string.p[n_length_bytes], chars, length) | ||||||
|  |             return string | ||||||
|  | 
 | ||||||
|  |     cdef str _decode_str_repr(self, const Utf8Str* string): | ||||||
|  |         cdef int i, length | ||||||
|  |         if string.s[0] < sizeof(string.s) and string.s[0] != 0: | ||||||
|  |             return string.s[1:string.s[0]+1].decode('utf-8') | ||||||
|  |         elif string.p[0] < 255: | ||||||
|  |             return string.p[1:string.p[0]+1].decode('utf-8') | ||||||
|  |         else: | ||||||
|  |             i = 0 | ||||||
|  |             length = 0 | ||||||
|  |             while string.p[i] == 255: | ||||||
|  |                 i += 1 | ||||||
|  |                 length += 255 | ||||||
|  |             length += string.p[i] | ||||||
|  |             i += 1 | ||||||
|  |             return string.p[i:length + i].decode('utf-8') | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | cpdef hash_t hash_string(object string) except -1: | ||||||
|  |     if not isinstance(string, str): | ||||||
|  |         raise TypeError(Errors.E4001.format(expected_types="'str'", received_type=type(string))) | ||||||
|  | 
 | ||||||
|  |     # Handle reserved symbols and empty strings correctly. | ||||||
|  |     if len(string) == 0: | ||||||
|  |         return 0 | ||||||
|  | 
 | ||||||
|  |     symbol = SYMBOLS_BY_STR.get(string) | ||||||
|  |     if symbol is not None: | ||||||
|  |         return symbol | ||||||
|  | 
 | ||||||
|  |     chars = string.encode('utf-8') | ||||||
|  |     return hash64(<unsigned char*>chars, len(chars), 1) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | cpdef hash_t get_string_id(object string_or_hash) except -1: | ||||||
|  |     cdef hash_t str_hash | ||||||
|  | 
 | ||||||
|  |     try: | ||||||
|  |         return hash_string(string_or_hash) | ||||||
|  |     except: | ||||||
|  |         if _try_coerce_to_hash(string_or_hash, &str_hash): | ||||||
|  |             # Coerce the integral key to the expected primitive hash type. | ||||||
|  |             # This ensures that custom/overloaded "primitive" data types | ||||||
|  |             # such as those implemented by numpy are not inadvertently used | ||||||
|  |             # downsteam (as these are internally implemented as custom PyObjects | ||||||
|  |             # whose comparison operators can incur a significant overhead). | ||||||
|  |             return str_hash | ||||||
|  |         else: | ||||||
|  |             raise TypeError(Errors.E4001.format(expected_types="'str','int'", received_type=type(string_or_hash))) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | # Not particularly elegant, but this is faster than `isinstance(key, numbers.Integral)` | ||||||
|  | cdef inline bint _try_coerce_to_hash(object key, hash_t* out_hash): | ||||||
|  |     try: | ||||||
|  |         out_hash[0] = key | ||||||
|  |         return True | ||||||
|  |     except: | ||||||
|  |         return False | ||||||
|  | 
 | ||||||
|  |  | ||||||
|  | @ -24,6 +24,14 @@ def test_stringstore_from_api_docs(stringstore): | ||||||
|     stringstore.add("orange") |     stringstore.add("orange") | ||||||
|     all_strings = [s for s in stringstore] |     all_strings = [s for s in stringstore] | ||||||
|     assert all_strings == ["apple", "orange"] |     assert all_strings == ["apple", "orange"] | ||||||
|  |     assert all_strings == list(stringstore.keys()) | ||||||
|  |     all_strings_and_hashes = list(stringstore.items()) | ||||||
|  |     assert all_strings_and_hashes == [ | ||||||
|  |         ("apple", 8566208034543834098), | ||||||
|  |         ("orange", 2208928596161743350), | ||||||
|  |     ] | ||||||
|  |     all_hashes = list(stringstore.values()) | ||||||
|  |     assert all_hashes == [8566208034543834098, 2208928596161743350] | ||||||
|     banana_hash = stringstore.add("banana") |     banana_hash = stringstore.add("banana") | ||||||
|     assert len(stringstore) == 3 |     assert len(stringstore) == 3 | ||||||
|     assert banana_hash == 2525716904149915114 |     assert banana_hash == 2525716904149915114 | ||||||
|  | @ -31,12 +39,25 @@ def test_stringstore_from_api_docs(stringstore): | ||||||
|     assert stringstore["banana"] == banana_hash |     assert stringstore["banana"] == banana_hash | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.parametrize("text1,text2,text3", [(b"Hello", b"goodbye", b"hello")]) | @pytest.mark.parametrize( | ||||||
| def test_stringstore_save_bytes(stringstore, text1, text2, text3): |     "val_bytes,val_float,val_list,val_text,val_hash", | ||||||
|     key = stringstore.add(text1) |     [(b"Hello", 1.1, ["abc"], "apple", 8566208034543834098)], | ||||||
|     assert stringstore[text1] == key | ) | ||||||
|     assert stringstore[text2] != key | def test_stringstore_type_checking( | ||||||
|     assert stringstore[text3] != key |     stringstore, val_bytes, val_float, val_list, val_text, val_hash | ||||||
|  | ): | ||||||
|  |     with pytest.raises(TypeError): | ||||||
|  |         assert stringstore[val_bytes] | ||||||
|  | 
 | ||||||
|  |     with pytest.raises(TypeError): | ||||||
|  |         stringstore.add(val_float) | ||||||
|  | 
 | ||||||
|  |     with pytest.raises(TypeError): | ||||||
|  |         assert val_list not in stringstore | ||||||
|  | 
 | ||||||
|  |     key = stringstore.add(val_text) | ||||||
|  |     assert val_hash == key | ||||||
|  |     assert stringstore[val_hash] == val_text | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.parametrize("text1,text2,text3", [("Hello", "goodbye", "hello")]) | @pytest.mark.parametrize("text1,text2,text3", [("Hello", "goodbye", "hello")]) | ||||||
|  | @ -47,19 +68,19 @@ def test_stringstore_save_unicode(stringstore, text1, text2, text3): | ||||||
|     assert stringstore[text3] != key |     assert stringstore[text3] != key | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.parametrize("text", [b"A"]) | @pytest.mark.parametrize("text", ["A"]) | ||||||
| def test_stringstore_retrieve_id(stringstore, text): | def test_stringstore_retrieve_id(stringstore, text): | ||||||
|     key = stringstore.add(text) |     key = stringstore.add(text) | ||||||
|     assert len(stringstore) == 1 |     assert len(stringstore) == 1 | ||||||
|     assert stringstore[key] == text.decode("utf8") |     assert stringstore[key] == text | ||||||
|     with pytest.raises(KeyError): |     with pytest.raises(KeyError): | ||||||
|         stringstore[20000] |         stringstore[20000] | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.mark.parametrize("text1,text2", [(b"0123456789", b"A")]) | @pytest.mark.parametrize("text1,text2", [("0123456789", "A")]) | ||||||
| def test_stringstore_med_string(stringstore, text1, text2): | def test_stringstore_med_string(stringstore, text1, text2): | ||||||
|     store = stringstore.add(text1) |     store = stringstore.add(text1) | ||||||
|     assert stringstore[store] == text1.decode("utf8") |     assert stringstore[store] == text1 | ||||||
|     stringstore.add(text2) |     stringstore.add(text2) | ||||||
|     assert stringstore[text1] == store |     assert stringstore[text1] == store | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -12,7 +12,7 @@ from murmurhash.mrmr cimport hash64 | ||||||
| 
 | 
 | ||||||
| from .. import Errors | from .. import Errors | ||||||
| from ..typedefs cimport hash_t | from ..typedefs cimport hash_t | ||||||
| from ..strings import get_string_id | from ..strings cimport get_string_id | ||||||
| from ..structs cimport EdgeC, GraphC | from ..structs cimport EdgeC, GraphC | ||||||
| from .token import Token | from .token import Token | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -18,7 +18,7 @@ from .underscore import is_writable_attr | ||||||
| from ..attrs import intify_attrs | from ..attrs import intify_attrs | ||||||
| from ..util import SimpleFrozenDict | from ..util import SimpleFrozenDict | ||||||
| from ..errors import Errors | from ..errors import Errors | ||||||
| from ..strings import get_string_id | from ..strings cimport get_string_id | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| cdef class Retokenizer: | cdef class Retokenizer: | ||||||
|  |  | ||||||
|  | @ -40,7 +40,8 @@ Get the number of strings in the store. | ||||||
| 
 | 
 | ||||||
| ## StringStore.\_\_getitem\_\_ {#getitem tag="method"} | ## StringStore.\_\_getitem\_\_ {#getitem tag="method"} | ||||||
| 
 | 
 | ||||||
| Retrieve a string from a given hash, or vice versa. | Retrieve a string from a given hash. If a string is passed as the input, add it | ||||||
|  | to the store and return its hash. | ||||||
| 
 | 
 | ||||||
| > #### Example | > #### Example | ||||||
| > | > | ||||||
|  | @ -51,14 +52,14 @@ Retrieve a string from a given hash, or vice versa. | ||||||
| > assert stringstore[apple_hash] == "apple" | > assert stringstore[apple_hash] == "apple" | ||||||
| > ``` | > ``` | ||||||
| 
 | 
 | ||||||
| | Name           | Description                                     | | | Name             | Description                                                                  | | ||||||
| | -------------- | ----------------------------------------------- | | | ---------------- | ---------------------------------------------------------------------------- | | ||||||
| | `string_or_id` | The value to encode. ~~Union[bytes, str, int]~~ | | | `string_or_hash` | The hash value to lookup or the string to store. ~~Union[str, int]~~         | | ||||||
| | **RETURNS**    | The value to be retrieved. ~~Union[str, int]~~  | | | **RETURNS**      | The stored string or the hash of the newly added string. ~~Union[str, int]~~ | | ||||||
| 
 | 
 | ||||||
| ## StringStore.\_\_contains\_\_ {#contains tag="method"} | ## StringStore.\_\_contains\_\_ {#contains tag="method"} | ||||||
| 
 | 
 | ||||||
| Check whether a string is in the store. | Check whether a string or a hash is in the store. | ||||||
| 
 | 
 | ||||||
| > #### Example | > #### Example | ||||||
| > | > | ||||||
|  | @ -68,15 +69,14 @@ Check whether a string is in the store. | ||||||
| > assert not "cherry" in stringstore | > assert not "cherry" in stringstore | ||||||
| > ``` | > ``` | ||||||
| 
 | 
 | ||||||
| | Name        | Description                                     | | | Name             | Description                                             | | ||||||
| | ----------- | ----------------------------------------------- | | | ---------------- | ------------------------------------------------------- | | ||||||
| | `string`    | The string to check. ~~str~~                    | | | `string_or_hash` | The string or hash to check. ~~Union[str, int]~~        | | ||||||
| | **RETURNS** | Whether the store contains the string. ~~bool~~ | | | **RETURNS**      | Whether the store contains the string or hash. ~~bool~~ | | ||||||
| 
 | 
 | ||||||
| ## StringStore.\_\_iter\_\_ {#iter tag="method"} | ## StringStore.\_\_iter\_\_ {#iter tag="method"} | ||||||
| 
 | 
 | ||||||
| Iterate over the strings in the store, in order. Note that a newly initialized | Iterate over the stored strings in insertion order. | ||||||
| store will always include an empty string `""` at position `0`. |  | ||||||
| 
 | 
 | ||||||
| > #### Example | > #### Example | ||||||
| > | > | ||||||
|  | @ -86,11 +86,59 @@ store will always include an empty string `""` at position `0`. | ||||||
| > assert all_strings == ["apple", "orange"] | > assert all_strings == ["apple", "orange"] | ||||||
| > ``` | > ``` | ||||||
| 
 | 
 | ||||||
| | Name       | Description                    | | | Name        | Description                    | | ||||||
| | ---------- | ------------------------------ | | | ----------- | ------------------------------ | | ||||||
| | **YIELDS** | A string in the store. ~~str~~ | | | **RETURNS** | A string in the store. ~~str~~ | | ||||||
| 
 | 
 | ||||||
| ## StringStore.add {#add tag="method" new="2"} | ## StringStore.items {#iter tag="method" new="4"} | ||||||
|  | 
 | ||||||
|  | Iterate over the stored string-hash pairs in insertion order. | ||||||
|  | 
 | ||||||
|  | > #### Example | ||||||
|  | > | ||||||
|  | > ```python | ||||||
|  | > stringstore = StringStore(["apple", "orange"]) | ||||||
|  | > all_strings_and_hashes = stringstore.items() | ||||||
|  | > assert all_strings_and_hashes == [("apple", 8566208034543834098), ("orange", 2208928596161743350)] | ||||||
|  | > ``` | ||||||
|  | 
 | ||||||
|  | | Name        | Description                                            | | ||||||
|  | | ----------- | ------------------------------------------------------ | | ||||||
|  | | **RETURNS** | A list of string-hash pairs. ~~List[Tuple[str, int]]~~ | | ||||||
|  | 
 | ||||||
|  | ## StringStore.keys {#iter tag="method" new="4"} | ||||||
|  | 
 | ||||||
|  | Iterate over the stored strings in insertion order. | ||||||
|  | 
 | ||||||
|  | > #### Example | ||||||
|  | > | ||||||
|  | > ```python | ||||||
|  | > stringstore = StringStore(["apple", "orange"]) | ||||||
|  | > all_strings = stringstore.keys() | ||||||
|  | > assert all_strings == ["apple", "orange"] | ||||||
|  | > ``` | ||||||
|  | 
 | ||||||
|  | | Name        | Description                      | | ||||||
|  | | ----------- | -------------------------------- | | ||||||
|  | | **RETURNS** | A list of strings. ~~List[str]~~ | | ||||||
|  | 
 | ||||||
|  | ## StringStore.values {#iter tag="method" new="4"} | ||||||
|  | 
 | ||||||
|  | Iterate over the stored string hashes in insertion order. | ||||||
|  | 
 | ||||||
|  | > #### Example | ||||||
|  | > | ||||||
|  | > ```python | ||||||
|  | > stringstore = StringStore(["apple", "orange"]) | ||||||
|  | > all_hashes = stringstore.values() | ||||||
|  | > assert all_hashes == [8566208034543834098, 2208928596161743350] | ||||||
|  | > ``` | ||||||
|  | 
 | ||||||
|  | | Name        | Description                            | | ||||||
|  | | ----------- | -------------------------------------- | | ||||||
|  | | **RETURNS** | A list of string hashes. ~~List[int]~~ | | ||||||
|  | 
 | ||||||
|  | ## StringStore.add {#add tag="method"} | ||||||
| 
 | 
 | ||||||
| Add a string to the `StringStore`. | Add a string to the `StringStore`. | ||||||
| 
 | 
 | ||||||
|  | @ -110,7 +158,7 @@ Add a string to the `StringStore`. | ||||||
| | `string`    | The string to add. ~~str~~       | | | `string`    | The string to add. ~~str~~       | | ||||||
| | **RETURNS** | The string's hash value. ~~int~~ | | | **RETURNS** | The string's hash value. ~~int~~ | | ||||||
| 
 | 
 | ||||||
| ## StringStore.to_disk {#to_disk tag="method" new="2"} | ## StringStore.to_disk {#to_disk tag="method"} | ||||||
| 
 | 
 | ||||||
| Save the current state to a directory. | Save the current state to a directory. | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user