mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-30 01:43:21 +03:00
StringStore
refactoring (#11344)
* `strings`: Remove unused `hash32_utf8` function
* `strings`: Make `hash_utf8` and `decode_Utf8Str` private
* `strings`: Reorganize private functions
* 'strings': Raise error when non-string/-int types are passed to functions that don't accept them
* `strings`: Add `items()` method, add type hints, remove unused methods, restrict inputs to specific types, reorganize methods
* `Morphology`: Use `StringStore.items()` to enumerate features when pickling
* `test_stringstore`: Update pre-Python 3 tests
* Update `StringStore` docs
* Fix `get_string_id` imports
* Replace redundant test with tests for type checking
* Rename `_retrieve_interned_str`, remove `.get` default arg
* Add `get_string_id` to `strings.pyi`
Remove `mypy` ignore directives from imports of the above
* `strings.pyi`: Replace functions that consume `Union`-typed params with overloads
* `strings.pyi`: Revert some function signatures
* Update `SYMBOLS_BY_INT` lookups and error codes post-merge
* Revert clobbered change introduced in a previous merge
* Remove unnecessary type hint
* Invert tuple order in `StringStore.items()`
* Add test for `StringStore.items()`
* Revert "`Morphology`: Use `StringStore.items()` to enumerate features when pickling"
This reverts commit 1af9510ceb
.
* Rename `keys` and `key_map`
* Add `keys()` and `values()`
* Add comment about the inverted key-value semantics in the API
* Fix type hints
* Implement `keys()`, `values()`, `items()` without generators
* Fix type hints, remove unnecessary boxing
* Update docs
* Simplify `keys/values/items()` impl
* `mypy` fix
* Fix error message, doc fixes
This commit is contained in:
parent
c6704f368c
commit
446a3ecf34
|
@ -252,7 +252,7 @@ class Errors(metaclass=ErrorsWithCodes):
|
||||||
E012 = ("Cannot add pattern for zero tokens to matcher.\nKey: {key}")
|
E012 = ("Cannot add pattern for zero tokens to matcher.\nKey: {key}")
|
||||||
E016 = ("MultitaskObjective target should be function or one of: dep, "
|
E016 = ("MultitaskObjective target should be function or one of: dep, "
|
||||||
"tag, ent, dep_tag_offset, ent_tag.")
|
"tag, ent, dep_tag_offset, ent_tag.")
|
||||||
E017 = ("Can only add unicode or bytes. Got type: {value_type}")
|
E017 = ("Can only add 'str' inputs to StringStore. Got type: {value_type}")
|
||||||
E018 = ("Can't retrieve string for hash '{hash_value}'. This usually "
|
E018 = ("Can't retrieve string for hash '{hash_value}'. This usually "
|
||||||
"refers to an issue with the `Vocab` or `StringStore`.")
|
"refers to an issue with the `Vocab` or `StringStore`.")
|
||||||
E019 = ("Can't create transition with unknown action ID: {action}. Action "
|
E019 = ("Can't create transition with unknown action ID: {action}. Action "
|
||||||
|
@ -955,6 +955,8 @@ class Errors(metaclass=ErrorsWithCodes):
|
||||||
|
|
||||||
# v4 error strings
|
# v4 error strings
|
||||||
E4000 = ("Expected a Doc as input, but got: '{type}'")
|
E4000 = ("Expected a Doc as input, but got: '{type}'")
|
||||||
|
E4001 = ("Expected input to be one of the following types: ({expected_types}), "
|
||||||
|
"but got '{received_type}'")
|
||||||
|
|
||||||
|
|
||||||
# Deprecated model shortcuts, only used in errors and warnings
|
# Deprecated model shortcuts, only used in errors and warnings
|
||||||
|
|
|
@ -22,7 +22,7 @@ from ..attrs cimport ID, attr_id_t, NULL_ATTR, ORTH, POS, TAG, DEP, LEMMA, MORPH
|
||||||
|
|
||||||
from ..schemas import validate_token_pattern
|
from ..schemas import validate_token_pattern
|
||||||
from ..errors import Errors, MatchPatternError, Warnings
|
from ..errors import Errors, MatchPatternError, Warnings
|
||||||
from ..strings import get_string_id
|
from ..strings cimport get_string_id
|
||||||
from ..attrs import IDS
|
from ..attrs import IDS
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from libc.stdint cimport int64_t
|
from libc.stdint cimport int64_t, uint32_t
|
||||||
from libcpp.vector cimport vector
|
from libcpp.vector cimport vector
|
||||||
from libcpp.set cimport set
|
from libcpp.set cimport set
|
||||||
from cymem.cymem cimport Pool
|
from cymem.cymem cimport Pool
|
||||||
|
@ -7,13 +7,6 @@ from murmurhash.mrmr cimport hash64
|
||||||
|
|
||||||
from .typedefs cimport attr_t, hash_t
|
from .typedefs cimport attr_t, hash_t
|
||||||
|
|
||||||
|
|
||||||
cpdef hash_t hash_string(str string) except 0
|
|
||||||
cdef hash_t hash_utf8(char* utf8_string, int length) nogil
|
|
||||||
|
|
||||||
cdef str decode_Utf8Str(const Utf8Str* string)
|
|
||||||
|
|
||||||
|
|
||||||
ctypedef union Utf8Str:
|
ctypedef union Utf8Str:
|
||||||
unsigned char[8] s
|
unsigned char[8] s
|
||||||
unsigned char* p
|
unsigned char* p
|
||||||
|
@ -21,9 +14,13 @@ ctypedef union Utf8Str:
|
||||||
|
|
||||||
cdef class StringStore:
|
cdef class StringStore:
|
||||||
cdef Pool mem
|
cdef Pool mem
|
||||||
|
cdef vector[hash_t] _keys
|
||||||
|
cdef PreshMap _map
|
||||||
|
|
||||||
cdef vector[hash_t] keys
|
cdef hash_t _intern_str(self, str string)
|
||||||
cdef public PreshMap _map
|
cdef Utf8Str* _allocate_str_repr(self, const unsigned char* chars, uint32_t length) except *
|
||||||
|
cdef str _decode_str_repr(self, const Utf8Str* string)
|
||||||
|
|
||||||
cdef const Utf8Str* intern_unicode(self, str py_string)
|
|
||||||
cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length, hash_t* precalculated_hash)
|
cpdef hash_t hash_string(object string) except -1
|
||||||
|
cpdef hash_t get_string_id(object string_or_hash) except -1
|
||||||
|
|
|
@ -1,21 +1,20 @@
|
||||||
from typing import Optional, Iterable, Iterator, Union, Any, overload
|
from typing import List, Optional, Iterable, Iterator, Union, Any, Tuple, overload
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
def get_string_id(key: Union[str, int]) -> int: ...
|
|
||||||
|
|
||||||
class StringStore:
|
class StringStore:
|
||||||
def __init__(
|
def __init__(self, strings: Optional[Iterable[str]]) -> None: ...
|
||||||
self, strings: Optional[Iterable[str]] = ..., freeze: bool = ...
|
|
||||||
) -> None: ...
|
|
||||||
@overload
|
@overload
|
||||||
def __getitem__(self, string_or_id: Union[bytes, str]) -> int: ...
|
def __getitem__(self, string_or_hash: str) -> int: ...
|
||||||
@overload
|
@overload
|
||||||
def __getitem__(self, string_or_id: int) -> str: ...
|
def __getitem__(self, string_or_hash: int) -> str: ...
|
||||||
def as_int(self, key: Union[bytes, str, int]) -> int: ...
|
def as_int(self, string_or_hash: Union[str, int]) -> int: ...
|
||||||
def as_string(self, key: Union[bytes, str, int]) -> str: ...
|
def as_string(self, string_or_hash: Union[str, int]) -> str: ...
|
||||||
def add(self, string: str) -> int: ...
|
def add(self, string: str) -> int: ...
|
||||||
|
def items(self) -> List[Tuple[str, int]]: ...
|
||||||
|
def keys(self) -> List[str]: ...
|
||||||
|
def values(self) -> List[int]: ...
|
||||||
def __len__(self) -> int: ...
|
def __len__(self) -> int: ...
|
||||||
def __contains__(self, string: str) -> bool: ...
|
def __contains__(self, string_or_hash: Union[str, int]) -> bool: ...
|
||||||
def __iter__(self) -> Iterator[str]: ...
|
def __iter__(self) -> Iterator[str]: ...
|
||||||
def __reduce__(self) -> Any: ...
|
def __reduce__(self) -> Any: ...
|
||||||
def to_disk(self, path: Union[str, Path]) -> None: ...
|
def to_disk(self, path: Union[str, Path]) -> None: ...
|
||||||
|
@ -23,3 +22,5 @@ class StringStore:
|
||||||
def to_bytes(self, **kwargs: Any) -> bytes: ...
|
def to_bytes(self, **kwargs: Any) -> bytes: ...
|
||||||
def from_bytes(self, bytes_data: bytes, **kwargs: Any) -> StringStore: ...
|
def from_bytes(self, bytes_data: bytes, **kwargs: Any) -> StringStore: ...
|
||||||
def _reset_and_load(self, strings: Iterable[str]) -> None: ...
|
def _reset_and_load(self, strings: Iterable[str]) -> None: ...
|
||||||
|
|
||||||
|
def get_string_id(string_or_hash: Union[str, int]) -> int: ...
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
# cython: infer_types=True
|
# cython: infer_types=True
|
||||||
|
from typing import Optional, Union, Iterable, Tuple, Callable, Any, List, Iterator
|
||||||
cimport cython
|
cimport cython
|
||||||
from libc.string cimport memcpy
|
from libc.string cimport memcpy
|
||||||
from libcpp.set cimport set
|
from libcpp.set cimport set
|
||||||
from libc.stdint cimport uint32_t
|
from libc.stdint cimport uint32_t
|
||||||
from murmurhash.mrmr cimport hash64, hash32
|
from murmurhash.mrmr cimport hash64
|
||||||
|
|
||||||
import srsly
|
import srsly
|
||||||
|
|
||||||
|
@ -14,105 +15,13 @@ from .symbols import NAMES as SYMBOLS_BY_INT
|
||||||
from .errors import Errors
|
from .errors import Errors
|
||||||
from . import util
|
from . import util
|
||||||
|
|
||||||
# Not particularly elegant, but this is faster than `isinstance(key, numbers.Integral)`
|
|
||||||
cdef inline bint _try_coerce_to_hash(object key, hash_t* out_hash):
|
|
||||||
try:
|
|
||||||
out_hash[0] = key
|
|
||||||
return True
|
|
||||||
except:
|
|
||||||
return False
|
|
||||||
|
|
||||||
def get_string_id(key):
|
|
||||||
"""Get a string ID, handling the reserved symbols correctly. If the key is
|
|
||||||
already an ID, return it.
|
|
||||||
|
|
||||||
This function optimises for convenience over performance, so shouldn't be
|
|
||||||
used in tight loops.
|
|
||||||
"""
|
|
||||||
cdef hash_t str_hash
|
|
||||||
if isinstance(key, str):
|
|
||||||
if len(key) == 0:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
symbol = SYMBOLS_BY_STR.get(key, None)
|
|
||||||
if symbol is not None:
|
|
||||||
return symbol
|
|
||||||
else:
|
|
||||||
chars = key.encode("utf8")
|
|
||||||
return hash_utf8(chars, len(chars))
|
|
||||||
elif _try_coerce_to_hash(key, &str_hash):
|
|
||||||
# Coerce the integral key to the expected primitive hash type.
|
|
||||||
# This ensures that custom/overloaded "primitive" data types
|
|
||||||
# such as those implemented by numpy are not inadvertently used
|
|
||||||
# downsteam (as these are internally implemented as custom PyObjects
|
|
||||||
# whose comparison operators can incur a significant overhead).
|
|
||||||
return str_hash
|
|
||||||
else:
|
|
||||||
# TODO: Raise an error instead
|
|
||||||
return key
|
|
||||||
|
|
||||||
|
|
||||||
cpdef hash_t hash_string(str string) except 0:
|
|
||||||
chars = string.encode("utf8")
|
|
||||||
return hash_utf8(chars, len(chars))
|
|
||||||
|
|
||||||
|
|
||||||
cdef hash_t hash_utf8(char* utf8_string, int length) nogil:
|
|
||||||
return hash64(utf8_string, length, 1)
|
|
||||||
|
|
||||||
|
|
||||||
cdef uint32_t hash32_utf8(char* utf8_string, int length) nogil:
|
|
||||||
return hash32(utf8_string, length, 1)
|
|
||||||
|
|
||||||
|
|
||||||
cdef str decode_Utf8Str(const Utf8Str* string):
|
|
||||||
cdef int i, length
|
|
||||||
if string.s[0] < sizeof(string.s) and string.s[0] != 0:
|
|
||||||
return string.s[1:string.s[0]+1].decode("utf8")
|
|
||||||
elif string.p[0] < 255:
|
|
||||||
return string.p[1:string.p[0]+1].decode("utf8")
|
|
||||||
else:
|
|
||||||
i = 0
|
|
||||||
length = 0
|
|
||||||
while string.p[i] == 255:
|
|
||||||
i += 1
|
|
||||||
length += 255
|
|
||||||
length += string.p[i]
|
|
||||||
i += 1
|
|
||||||
return string.p[i:length + i].decode("utf8")
|
|
||||||
|
|
||||||
|
|
||||||
cdef Utf8Str* _allocate(Pool mem, const unsigned char* chars, uint32_t length) except *:
|
|
||||||
cdef int n_length_bytes
|
|
||||||
cdef int i
|
|
||||||
cdef Utf8Str* string = <Utf8Str*>mem.alloc(1, sizeof(Utf8Str))
|
|
||||||
cdef uint32_t ulength = length
|
|
||||||
if length < sizeof(string.s):
|
|
||||||
string.s[0] = <unsigned char>length
|
|
||||||
memcpy(&string.s[1], chars, length)
|
|
||||||
return string
|
|
||||||
elif length < 255:
|
|
||||||
string.p = <unsigned char*>mem.alloc(length + 1, sizeof(unsigned char))
|
|
||||||
string.p[0] = length
|
|
||||||
memcpy(&string.p[1], chars, length)
|
|
||||||
return string
|
|
||||||
else:
|
|
||||||
i = 0
|
|
||||||
n_length_bytes = (length // 255) + 1
|
|
||||||
string.p = <unsigned char*>mem.alloc(length + n_length_bytes, sizeof(unsigned char))
|
|
||||||
for i in range(n_length_bytes-1):
|
|
||||||
string.p[i] = 255
|
|
||||||
string.p[n_length_bytes-1] = length % 255
|
|
||||||
memcpy(&string.p[n_length_bytes], chars, length)
|
|
||||||
return string
|
|
||||||
|
|
||||||
|
|
||||||
cdef class StringStore:
|
cdef class StringStore:
|
||||||
"""Look up strings by 64-bit hashes.
|
"""Look up strings by 64-bit hashes. Implicitly handles reserved symbols.
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/stringstore
|
DOCS: https://spacy.io/api/stringstore
|
||||||
"""
|
"""
|
||||||
def __init__(self, strings=None, freeze=False):
|
def __init__(self, strings: Optional[Iterable[str]] = None):
|
||||||
"""Create the StringStore.
|
"""Create the StringStore.
|
||||||
|
|
||||||
strings (iterable): A sequence of unicode strings to add to the store.
|
strings (iterable): A sequence of unicode strings to add to the store.
|
||||||
|
@ -123,128 +32,127 @@ cdef class StringStore:
|
||||||
for string in strings:
|
for string in strings:
|
||||||
self.add(string)
|
self.add(string)
|
||||||
|
|
||||||
def __getitem__(self, object string_or_id):
|
def __getitem__(self, string_or_hash: Union[str, int]) -> Union[str, int]:
|
||||||
"""Retrieve a string from a given hash, or vice versa.
|
"""Retrieve a string from a given hash. If a string
|
||||||
|
is passed as the input, add it to the store and return
|
||||||
|
its hash.
|
||||||
|
|
||||||
string_or_id (bytes, str or uint64): The value to encode.
|
string_or_hash (int / str): The hash value to lookup or the string to store.
|
||||||
Returns (str / uint64): The value to be retrieved.
|
RETURNS (str / int): The stored string or the hash of the newly added string.
|
||||||
"""
|
"""
|
||||||
cdef hash_t str_hash
|
if isinstance(string_or_hash, str):
|
||||||
cdef Utf8Str* utf8str = NULL
|
return self.add(string_or_hash)
|
||||||
|
|
||||||
if isinstance(string_or_id, str):
|
|
||||||
if len(string_or_id) == 0:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
# Return early if the string is found in the symbols LUT.
|
|
||||||
symbol = SYMBOLS_BY_STR.get(string_or_id, None)
|
|
||||||
if symbol is not None:
|
|
||||||
return symbol
|
|
||||||
else:
|
|
||||||
return hash_string(string_or_id)
|
|
||||||
elif isinstance(string_or_id, bytes):
|
|
||||||
return hash_utf8(string_or_id, len(string_or_id))
|
|
||||||
elif _try_coerce_to_hash(string_or_id, &str_hash):
|
|
||||||
if str_hash == 0:
|
|
||||||
return ""
|
|
||||||
elif str_hash in SYMBOLS_BY_INT:
|
|
||||||
return SYMBOLS_BY_INT[str_hash]
|
|
||||||
else:
|
|
||||||
utf8str = <Utf8Str*>self._map.get(str_hash)
|
|
||||||
else:
|
else:
|
||||||
# TODO: Raise an error instead
|
return self._get_interned_str(string_or_hash)
|
||||||
utf8str = <Utf8Str*>self._map.get(string_or_id)
|
|
||||||
|
|
||||||
if utf8str is NULL:
|
def __contains__(self, string_or_hash: Union[str, int]) -> bool:
|
||||||
raise KeyError(Errors.E018.format(hash_value=string_or_id))
|
"""Check whether a string or a hash is in the store.
|
||||||
else:
|
|
||||||
return decode_Utf8Str(utf8str)
|
|
||||||
|
|
||||||
def as_int(self, key):
|
string (str / int): The string/hash to check.
|
||||||
"""If key is an int, return it; otherwise, get the int value."""
|
|
||||||
if not isinstance(key, str):
|
|
||||||
return key
|
|
||||||
else:
|
|
||||||
return self[key]
|
|
||||||
|
|
||||||
def as_string(self, key):
|
|
||||||
"""If key is a string, return it; otherwise, get the string value."""
|
|
||||||
if isinstance(key, str):
|
|
||||||
return key
|
|
||||||
else:
|
|
||||||
return self[key]
|
|
||||||
|
|
||||||
def add(self, string):
|
|
||||||
"""Add a string to the StringStore.
|
|
||||||
|
|
||||||
string (str): The string to add.
|
|
||||||
RETURNS (uint64): The string's hash value.
|
|
||||||
"""
|
|
||||||
cdef hash_t str_hash
|
|
||||||
if isinstance(string, str):
|
|
||||||
if string in SYMBOLS_BY_STR:
|
|
||||||
return SYMBOLS_BY_STR[string]
|
|
||||||
|
|
||||||
string = string.encode("utf8")
|
|
||||||
str_hash = hash_utf8(string, len(string))
|
|
||||||
self._intern_utf8(string, len(string), &str_hash)
|
|
||||||
elif isinstance(string, bytes):
|
|
||||||
if string in SYMBOLS_BY_STR:
|
|
||||||
return SYMBOLS_BY_STR[string]
|
|
||||||
str_hash = hash_utf8(string, len(string))
|
|
||||||
self._intern_utf8(string, len(string), &str_hash)
|
|
||||||
else:
|
|
||||||
raise TypeError(Errors.E017.format(value_type=type(string)))
|
|
||||||
return str_hash
|
|
||||||
|
|
||||||
def __len__(self):
|
|
||||||
"""The number of strings in the store.
|
|
||||||
|
|
||||||
RETURNS (int): The number of strings in the store.
|
|
||||||
"""
|
|
||||||
return self.keys.size()
|
|
||||||
|
|
||||||
def __contains__(self, string_or_id not None):
|
|
||||||
"""Check whether a string or ID is in the store.
|
|
||||||
|
|
||||||
string_or_id (str or int): The string to check.
|
|
||||||
RETURNS (bool): Whether the store contains the string.
|
RETURNS (bool): Whether the store contains the string.
|
||||||
"""
|
"""
|
||||||
cdef hash_t str_hash
|
cdef hash_t str_hash = get_string_id(string_or_hash)
|
||||||
if isinstance(string_or_id, str):
|
|
||||||
if len(string_or_id) == 0:
|
|
||||||
return True
|
|
||||||
elif string_or_id in SYMBOLS_BY_STR:
|
|
||||||
return True
|
|
||||||
str_hash = hash_string(string_or_id)
|
|
||||||
elif _try_coerce_to_hash(string_or_id, &str_hash):
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
# TODO: Raise an error instead
|
|
||||||
return self._map.get(string_or_id) is not NULL
|
|
||||||
|
|
||||||
if str_hash in SYMBOLS_BY_INT:
|
if str_hash in SYMBOLS_BY_INT:
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
return self._map.get(str_hash) is not NULL
|
return self._map.get(str_hash) is not NULL
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self) -> Iterator[str]:
|
||||||
"""Iterate over the strings in the store, in order.
|
"""Iterate over the strings in the store in insertion order.
|
||||||
|
|
||||||
YIELDS (str): A string in the store.
|
RETURNS: An iterable collection of strings.
|
||||||
"""
|
"""
|
||||||
cdef int i
|
return iter(self.keys())
|
||||||
cdef hash_t key
|
|
||||||
for i in range(self.keys.size()):
|
|
||||||
key = self.keys[i]
|
|
||||||
utf8str = <Utf8Str*>self._map.get(key)
|
|
||||||
yield decode_Utf8Str(utf8str)
|
|
||||||
# TODO: Iterate OOV here?
|
|
||||||
|
|
||||||
def __reduce__(self):
|
def __reduce__(self):
|
||||||
strings = list(self)
|
strings = list(self)
|
||||||
return (StringStore, (strings,), None, None, None)
|
return (StringStore, (strings,), None, None, None)
|
||||||
|
|
||||||
|
def __len__(self) -> int:
|
||||||
|
"""The number of strings in the store.
|
||||||
|
|
||||||
|
RETURNS (int): The number of strings in the store.
|
||||||
|
"""
|
||||||
|
return self._keys.size()
|
||||||
|
|
||||||
|
def add(self, string: str) -> int:
|
||||||
|
"""Add a string to the StringStore.
|
||||||
|
|
||||||
|
string (str): The string to add.
|
||||||
|
RETURNS (uint64): The string's hash value.
|
||||||
|
"""
|
||||||
|
if not isinstance(string, str):
|
||||||
|
raise TypeError(Errors.E017.format(value_type=type(string)))
|
||||||
|
|
||||||
|
if string in SYMBOLS_BY_STR:
|
||||||
|
return SYMBOLS_BY_STR[string]
|
||||||
|
else:
|
||||||
|
return self._intern_str(string)
|
||||||
|
|
||||||
|
def as_int(self, string_or_hash: Union[str, int]) -> str:
|
||||||
|
"""If a hash value is passed as the input, return it as-is. If the input
|
||||||
|
is a string, return its corresponding hash.
|
||||||
|
|
||||||
|
string_or_hash (str / int): The string to hash or a hash value.
|
||||||
|
RETURNS (int): The hash of the string or the input hash value.
|
||||||
|
"""
|
||||||
|
if isinstance(string_or_hash, int):
|
||||||
|
return string_or_hash
|
||||||
|
else:
|
||||||
|
return get_string_id(string_or_hash)
|
||||||
|
|
||||||
|
def as_string(self, string_or_hash: Union[str, int]) -> str:
|
||||||
|
"""If a string is passed as the input, return it as-is. If the input
|
||||||
|
is a hash value, return its corresponding string.
|
||||||
|
|
||||||
|
string_or_hash (str / int): The hash value to lookup or a string.
|
||||||
|
RETURNS (str): The stored string or the input string.
|
||||||
|
"""
|
||||||
|
if isinstance(string_or_hash, str):
|
||||||
|
return string_or_hash
|
||||||
|
else:
|
||||||
|
return self._get_interned_str(string_or_hash)
|
||||||
|
|
||||||
|
def items(self) -> List[Tuple[str, int]]:
|
||||||
|
"""Iterate over the stored strings and their hashes in insertion order.
|
||||||
|
|
||||||
|
RETURNS: A list of string-hash pairs.
|
||||||
|
"""
|
||||||
|
# Even though we internally store the hashes as keys and the strings as
|
||||||
|
# values, we invert the order in the public API to keep it consistent with
|
||||||
|
# the implementation of the `__iter__` method (where we wish to iterate over
|
||||||
|
# the strings in the store).
|
||||||
|
cdef int i
|
||||||
|
pairs = [None] * self._keys.size()
|
||||||
|
for i in range(self._keys.size()):
|
||||||
|
str_hash = self._keys[i]
|
||||||
|
utf8str = <Utf8Str*>self._map.get(str_hash)
|
||||||
|
pairs[i] = (self._decode_str_repr(utf8str), str_hash)
|
||||||
|
return pairs
|
||||||
|
|
||||||
|
def keys(self) -> List[str]:
|
||||||
|
"""Iterate over the stored strings in insertion order.
|
||||||
|
|
||||||
|
RETURNS: A list of strings.
|
||||||
|
"""
|
||||||
|
cdef int i
|
||||||
|
strings = [None] * self._keys.size()
|
||||||
|
for i in range(self._keys.size()):
|
||||||
|
utf8str = <Utf8Str*>self._map.get(self._keys[i])
|
||||||
|
strings[i] = self._decode_str_repr(utf8str)
|
||||||
|
return strings
|
||||||
|
|
||||||
|
def values(self) -> List[int]:
|
||||||
|
"""Iterate over the stored strings hashes in insertion order.
|
||||||
|
|
||||||
|
RETURNS: A list of string hashs.
|
||||||
|
"""
|
||||||
|
cdef int i
|
||||||
|
hashes = [None] * self._keys.size()
|
||||||
|
for i in range(self._keys.size()):
|
||||||
|
hashes[i] = self._keys[i]
|
||||||
|
return hashes
|
||||||
|
|
||||||
def to_disk(self, path):
|
def to_disk(self, path):
|
||||||
"""Save the current state to a directory.
|
"""Save the current state to a directory.
|
||||||
|
|
||||||
|
@ -294,24 +202,122 @@ cdef class StringStore:
|
||||||
def _reset_and_load(self, strings):
|
def _reset_and_load(self, strings):
|
||||||
self.mem = Pool()
|
self.mem = Pool()
|
||||||
self._map = PreshMap()
|
self._map = PreshMap()
|
||||||
self.keys.clear()
|
self._keys.clear()
|
||||||
for string in strings:
|
for string in strings:
|
||||||
self.add(string)
|
self.add(string)
|
||||||
|
|
||||||
cdef const Utf8Str* intern_unicode(self, str py_string):
|
def _get_interned_str(self, hash_value: int) -> str:
|
||||||
# 0 means missing, but we don't bother offsetting the index.
|
cdef hash_t str_hash
|
||||||
cdef bytes byte_string = py_string.encode("utf8")
|
if not _try_coerce_to_hash(hash_value, &str_hash):
|
||||||
return self._intern_utf8(byte_string, len(byte_string), NULL)
|
raise TypeError(Errors.E4001.format(expected_types="'int'", received_type=type(hash_value)))
|
||||||
|
|
||||||
@cython.final
|
# Handle reserved symbols and empty strings correctly.
|
||||||
cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length, hash_t* precalculated_hash):
|
if str_hash == 0:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
symbol = SYMBOLS_BY_INT.get(str_hash)
|
||||||
|
if symbol is not None:
|
||||||
|
return symbol
|
||||||
|
|
||||||
|
utf8str = <Utf8Str*>self._map.get(str_hash)
|
||||||
|
if utf8str is NULL:
|
||||||
|
raise KeyError(Errors.E018.format(hash_value=str_hash))
|
||||||
|
else:
|
||||||
|
return self._decode_str_repr(utf8str)
|
||||||
|
|
||||||
|
cdef hash_t _intern_str(self, str string):
|
||||||
# TODO: This function's API/behaviour is an unholy mess...
|
# TODO: This function's API/behaviour is an unholy mess...
|
||||||
# 0 means missing, but we don't bother offsetting the index.
|
# 0 means missing, but we don't bother offsetting the index.
|
||||||
cdef hash_t key = precalculated_hash[0] if precalculated_hash is not NULL else hash_utf8(utf8_string, length)
|
chars = string.encode('utf-8')
|
||||||
|
cdef hash_t key = hash64(<unsigned char*>chars, len(chars), 1)
|
||||||
cdef Utf8Str* value = <Utf8Str*>self._map.get(key)
|
cdef Utf8Str* value = <Utf8Str*>self._map.get(key)
|
||||||
if value is not NULL:
|
if value is not NULL:
|
||||||
return value
|
return key
|
||||||
value = _allocate(self.mem, <unsigned char*>utf8_string, length)
|
|
||||||
|
value = self._allocate_str_repr(<unsigned char*>chars, len(chars))
|
||||||
self._map.set(key, value)
|
self._map.set(key, value)
|
||||||
self.keys.push_back(key)
|
self._keys.push_back(key)
|
||||||
return value
|
return key
|
||||||
|
|
||||||
|
cdef Utf8Str* _allocate_str_repr(self, const unsigned char* chars, uint32_t length) except *:
|
||||||
|
cdef int n_length_bytes
|
||||||
|
cdef int i
|
||||||
|
cdef Utf8Str* string = <Utf8Str*>self.mem.alloc(1, sizeof(Utf8Str))
|
||||||
|
cdef uint32_t ulength = length
|
||||||
|
if length < sizeof(string.s):
|
||||||
|
string.s[0] = <unsigned char>length
|
||||||
|
memcpy(&string.s[1], chars, length)
|
||||||
|
return string
|
||||||
|
elif length < 255:
|
||||||
|
string.p = <unsigned char*>self.mem.alloc(length + 1, sizeof(unsigned char))
|
||||||
|
string.p[0] = length
|
||||||
|
memcpy(&string.p[1], chars, length)
|
||||||
|
return string
|
||||||
|
else:
|
||||||
|
i = 0
|
||||||
|
n_length_bytes = (length // 255) + 1
|
||||||
|
string.p = <unsigned char*>self.mem.alloc(length + n_length_bytes, sizeof(unsigned char))
|
||||||
|
for i in range(n_length_bytes-1):
|
||||||
|
string.p[i] = 255
|
||||||
|
string.p[n_length_bytes-1] = length % 255
|
||||||
|
memcpy(&string.p[n_length_bytes], chars, length)
|
||||||
|
return string
|
||||||
|
|
||||||
|
cdef str _decode_str_repr(self, const Utf8Str* string):
|
||||||
|
cdef int i, length
|
||||||
|
if string.s[0] < sizeof(string.s) and string.s[0] != 0:
|
||||||
|
return string.s[1:string.s[0]+1].decode('utf-8')
|
||||||
|
elif string.p[0] < 255:
|
||||||
|
return string.p[1:string.p[0]+1].decode('utf-8')
|
||||||
|
else:
|
||||||
|
i = 0
|
||||||
|
length = 0
|
||||||
|
while string.p[i] == 255:
|
||||||
|
i += 1
|
||||||
|
length += 255
|
||||||
|
length += string.p[i]
|
||||||
|
i += 1
|
||||||
|
return string.p[i:length + i].decode('utf-8')
|
||||||
|
|
||||||
|
|
||||||
|
cpdef hash_t hash_string(object string) except -1:
|
||||||
|
if not isinstance(string, str):
|
||||||
|
raise TypeError(Errors.E4001.format(expected_types="'str'", received_type=type(string)))
|
||||||
|
|
||||||
|
# Handle reserved symbols and empty strings correctly.
|
||||||
|
if len(string) == 0:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
symbol = SYMBOLS_BY_STR.get(string)
|
||||||
|
if symbol is not None:
|
||||||
|
return symbol
|
||||||
|
|
||||||
|
chars = string.encode('utf-8')
|
||||||
|
return hash64(<unsigned char*>chars, len(chars), 1)
|
||||||
|
|
||||||
|
|
||||||
|
cpdef hash_t get_string_id(object string_or_hash) except -1:
|
||||||
|
cdef hash_t str_hash
|
||||||
|
|
||||||
|
try:
|
||||||
|
return hash_string(string_or_hash)
|
||||||
|
except:
|
||||||
|
if _try_coerce_to_hash(string_or_hash, &str_hash):
|
||||||
|
# Coerce the integral key to the expected primitive hash type.
|
||||||
|
# This ensures that custom/overloaded "primitive" data types
|
||||||
|
# such as those implemented by numpy are not inadvertently used
|
||||||
|
# downsteam (as these are internally implemented as custom PyObjects
|
||||||
|
# whose comparison operators can incur a significant overhead).
|
||||||
|
return str_hash
|
||||||
|
else:
|
||||||
|
raise TypeError(Errors.E4001.format(expected_types="'str','int'", received_type=type(string_or_hash)))
|
||||||
|
|
||||||
|
|
||||||
|
# Not particularly elegant, but this is faster than `isinstance(key, numbers.Integral)`
|
||||||
|
cdef inline bint _try_coerce_to_hash(object key, hash_t* out_hash):
|
||||||
|
try:
|
||||||
|
out_hash[0] = key
|
||||||
|
return True
|
||||||
|
except:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
|
@ -24,6 +24,14 @@ def test_stringstore_from_api_docs(stringstore):
|
||||||
stringstore.add("orange")
|
stringstore.add("orange")
|
||||||
all_strings = [s for s in stringstore]
|
all_strings = [s for s in stringstore]
|
||||||
assert all_strings == ["apple", "orange"]
|
assert all_strings == ["apple", "orange"]
|
||||||
|
assert all_strings == list(stringstore.keys())
|
||||||
|
all_strings_and_hashes = list(stringstore.items())
|
||||||
|
assert all_strings_and_hashes == [
|
||||||
|
("apple", 8566208034543834098),
|
||||||
|
("orange", 2208928596161743350),
|
||||||
|
]
|
||||||
|
all_hashes = list(stringstore.values())
|
||||||
|
assert all_hashes == [8566208034543834098, 2208928596161743350]
|
||||||
banana_hash = stringstore.add("banana")
|
banana_hash = stringstore.add("banana")
|
||||||
assert len(stringstore) == 3
|
assert len(stringstore) == 3
|
||||||
assert banana_hash == 2525716904149915114
|
assert banana_hash == 2525716904149915114
|
||||||
|
@ -31,12 +39,25 @@ def test_stringstore_from_api_docs(stringstore):
|
||||||
assert stringstore["banana"] == banana_hash
|
assert stringstore["banana"] == banana_hash
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("text1,text2,text3", [(b"Hello", b"goodbye", b"hello")])
|
@pytest.mark.parametrize(
|
||||||
def test_stringstore_save_bytes(stringstore, text1, text2, text3):
|
"val_bytes,val_float,val_list,val_text,val_hash",
|
||||||
key = stringstore.add(text1)
|
[(b"Hello", 1.1, ["abc"], "apple", 8566208034543834098)],
|
||||||
assert stringstore[text1] == key
|
)
|
||||||
assert stringstore[text2] != key
|
def test_stringstore_type_checking(
|
||||||
assert stringstore[text3] != key
|
stringstore, val_bytes, val_float, val_list, val_text, val_hash
|
||||||
|
):
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
assert stringstore[val_bytes]
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
stringstore.add(val_float)
|
||||||
|
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
assert val_list not in stringstore
|
||||||
|
|
||||||
|
key = stringstore.add(val_text)
|
||||||
|
assert val_hash == key
|
||||||
|
assert stringstore[val_hash] == val_text
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("text1,text2,text3", [("Hello", "goodbye", "hello")])
|
@pytest.mark.parametrize("text1,text2,text3", [("Hello", "goodbye", "hello")])
|
||||||
|
@ -47,19 +68,19 @@ def test_stringstore_save_unicode(stringstore, text1, text2, text3):
|
||||||
assert stringstore[text3] != key
|
assert stringstore[text3] != key
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("text", [b"A"])
|
@pytest.mark.parametrize("text", ["A"])
|
||||||
def test_stringstore_retrieve_id(stringstore, text):
|
def test_stringstore_retrieve_id(stringstore, text):
|
||||||
key = stringstore.add(text)
|
key = stringstore.add(text)
|
||||||
assert len(stringstore) == 1
|
assert len(stringstore) == 1
|
||||||
assert stringstore[key] == text.decode("utf8")
|
assert stringstore[key] == text
|
||||||
with pytest.raises(KeyError):
|
with pytest.raises(KeyError):
|
||||||
stringstore[20000]
|
stringstore[20000]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("text1,text2", [(b"0123456789", b"A")])
|
@pytest.mark.parametrize("text1,text2", [("0123456789", "A")])
|
||||||
def test_stringstore_med_string(stringstore, text1, text2):
|
def test_stringstore_med_string(stringstore, text1, text2):
|
||||||
store = stringstore.add(text1)
|
store = stringstore.add(text1)
|
||||||
assert stringstore[store] == text1.decode("utf8")
|
assert stringstore[store] == text1
|
||||||
stringstore.add(text2)
|
stringstore.add(text2)
|
||||||
assert stringstore[text1] == store
|
assert stringstore[text1] == store
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,7 @@ from murmurhash.mrmr cimport hash64
|
||||||
|
|
||||||
from .. import Errors
|
from .. import Errors
|
||||||
from ..typedefs cimport hash_t
|
from ..typedefs cimport hash_t
|
||||||
from ..strings import get_string_id
|
from ..strings cimport get_string_id
|
||||||
from ..structs cimport EdgeC, GraphC
|
from ..structs cimport EdgeC, GraphC
|
||||||
from .token import Token
|
from .token import Token
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ from .underscore import is_writable_attr
|
||||||
from ..attrs import intify_attrs
|
from ..attrs import intify_attrs
|
||||||
from ..util import SimpleFrozenDict
|
from ..util import SimpleFrozenDict
|
||||||
from ..errors import Errors
|
from ..errors import Errors
|
||||||
from ..strings import get_string_id
|
from ..strings cimport get_string_id
|
||||||
|
|
||||||
|
|
||||||
cdef class Retokenizer:
|
cdef class Retokenizer:
|
||||||
|
|
|
@ -40,7 +40,8 @@ Get the number of strings in the store.
|
||||||
|
|
||||||
## StringStore.\_\_getitem\_\_ {#getitem tag="method"}
|
## StringStore.\_\_getitem\_\_ {#getitem tag="method"}
|
||||||
|
|
||||||
Retrieve a string from a given hash, or vice versa.
|
Retrieve a string from a given hash. If a string is passed as the input, add it
|
||||||
|
to the store and return its hash.
|
||||||
|
|
||||||
> #### Example
|
> #### Example
|
||||||
>
|
>
|
||||||
|
@ -51,14 +52,14 @@ Retrieve a string from a given hash, or vice versa.
|
||||||
> assert stringstore[apple_hash] == "apple"
|
> assert stringstore[apple_hash] == "apple"
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
| -------------- | ----------------------------------------------- |
|
| ---------------- | ---------------------------------------------------------------------------- |
|
||||||
| `string_or_id` | The value to encode. ~~Union[bytes, str, int]~~ |
|
| `string_or_hash` | The hash value to lookup or the string to store. ~~Union[str, int]~~ |
|
||||||
| **RETURNS** | The value to be retrieved. ~~Union[str, int]~~ |
|
| **RETURNS** | The stored string or the hash of the newly added string. ~~Union[str, int]~~ |
|
||||||
|
|
||||||
## StringStore.\_\_contains\_\_ {#contains tag="method"}
|
## StringStore.\_\_contains\_\_ {#contains tag="method"}
|
||||||
|
|
||||||
Check whether a string is in the store.
|
Check whether a string or a hash is in the store.
|
||||||
|
|
||||||
> #### Example
|
> #### Example
|
||||||
>
|
>
|
||||||
|
@ -68,15 +69,14 @@ Check whether a string is in the store.
|
||||||
> assert not "cherry" in stringstore
|
> assert not "cherry" in stringstore
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
| ----------- | ----------------------------------------------- |
|
| ---------------- | ------------------------------------------------------- |
|
||||||
| `string` | The string to check. ~~str~~ |
|
| `string_or_hash` | The string or hash to check. ~~Union[str, int]~~ |
|
||||||
| **RETURNS** | Whether the store contains the string. ~~bool~~ |
|
| **RETURNS** | Whether the store contains the string or hash. ~~bool~~ |
|
||||||
|
|
||||||
## StringStore.\_\_iter\_\_ {#iter tag="method"}
|
## StringStore.\_\_iter\_\_ {#iter tag="method"}
|
||||||
|
|
||||||
Iterate over the strings in the store, in order. Note that a newly initialized
|
Iterate over the stored strings in insertion order.
|
||||||
store will always include an empty string `""` at position `0`.
|
|
||||||
|
|
||||||
> #### Example
|
> #### Example
|
||||||
>
|
>
|
||||||
|
@ -86,11 +86,59 @@ store will always include an empty string `""` at position `0`.
|
||||||
> assert all_strings == ["apple", "orange"]
|
> assert all_strings == ["apple", "orange"]
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
| ---------- | ------------------------------ |
|
| ----------- | ------------------------------ |
|
||||||
| **YIELDS** | A string in the store. ~~str~~ |
|
| **RETURNS** | A string in the store. ~~str~~ |
|
||||||
|
|
||||||
## StringStore.add {#add tag="method" new="2"}
|
## StringStore.items {#iter tag="method" new="4"}
|
||||||
|
|
||||||
|
Iterate over the stored string-hash pairs in insertion order.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> stringstore = StringStore(["apple", "orange"])
|
||||||
|
> all_strings_and_hashes = stringstore.items()
|
||||||
|
> assert all_strings_and_hashes == [("apple", 8566208034543834098), ("orange", 2208928596161743350)]
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ----------- | ------------------------------------------------------ |
|
||||||
|
| **RETURNS** | A list of string-hash pairs. ~~List[Tuple[str, int]]~~ |
|
||||||
|
|
||||||
|
## StringStore.keys {#iter tag="method" new="4"}
|
||||||
|
|
||||||
|
Iterate over the stored strings in insertion order.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> stringstore = StringStore(["apple", "orange"])
|
||||||
|
> all_strings = stringstore.keys()
|
||||||
|
> assert all_strings == ["apple", "orange"]
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ----------- | -------------------------------- |
|
||||||
|
| **RETURNS** | A list of strings. ~~List[str]~~ |
|
||||||
|
|
||||||
|
## StringStore.values {#iter tag="method" new="4"}
|
||||||
|
|
||||||
|
Iterate over the stored string hashes in insertion order.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> stringstore = StringStore(["apple", "orange"])
|
||||||
|
> all_hashes = stringstore.values()
|
||||||
|
> assert all_hashes == [8566208034543834098, 2208928596161743350]
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ----------- | -------------------------------------- |
|
||||||
|
| **RETURNS** | A list of string hashes. ~~List[int]~~ |
|
||||||
|
|
||||||
|
## StringStore.add {#add tag="method"}
|
||||||
|
|
||||||
Add a string to the `StringStore`.
|
Add a string to the `StringStore`.
|
||||||
|
|
||||||
|
@ -110,7 +158,7 @@ Add a string to the `StringStore`.
|
||||||
| `string` | The string to add. ~~str~~ |
|
| `string` | The string to add. ~~str~~ |
|
||||||
| **RETURNS** | The string's hash value. ~~int~~ |
|
| **RETURNS** | The string's hash value. ~~int~~ |
|
||||||
|
|
||||||
## StringStore.to_disk {#to_disk tag="method" new="2"}
|
## StringStore.to_disk {#to_disk tag="method"}
|
||||||
|
|
||||||
Save the current state to a directory.
|
Save the current state to a directory.
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user