strings: Add items() method, add type hints, remove unused methods, restrict inputs to specific types, reorganize methods

This commit is contained in:
shadeMe 2022-08-17 18:24:40 +02:00 committed by shademe
parent 19ba6eca15
commit d6237880b0
4 changed files with 187 additions and 210 deletions

View File

@ -249,7 +249,7 @@ class Errors(metaclass=ErrorsWithCodes):
E012 = ("Cannot add pattern for zero tokens to matcher.\nKey: {key}") E012 = ("Cannot add pattern for zero tokens to matcher.\nKey: {key}")
E016 = ("MultitaskObjective target should be function or one of: dep, " E016 = ("MultitaskObjective target should be function or one of: dep, "
"tag, ent, dep_tag_offset, ent_tag.") "tag, ent, dep_tag_offset, ent_tag.")
E017 = ("Can only add unicode or bytes. Got type: {value_type}") E017 = ("Can only add 'str' inputs to StringStore. Got type: {value_type}")
E018 = ("Can't retrieve string for hash '{hash_value}'. This usually " E018 = ("Can't retrieve string for hash '{hash_value}'. This usually "
"refers to an issue with the `Vocab` or `StringStore`.") "refers to an issue with the `Vocab` or `StringStore`.")
E019 = ("Can't create transition with unknown action ID: {action}. Action " E019 = ("Can't create transition with unknown action ID: {action}. Action "
@ -941,7 +941,8 @@ class Errors(metaclass=ErrorsWithCodes):
"{value}.") "{value}.")
# New errors added in v4.x # New errors added in v4.x
E1400 = ("Expected 'str' or 'int', but got '{key_type}'") E4000 = ("Expected input to be one of the following types: ({expected_types}), "
"but got '{received_type}'")
# Deprecated model shortcuts, only used in errors and warnings # Deprecated model shortcuts, only used in errors and warnings

View File

@ -1,4 +1,4 @@
from libc.stdint cimport int64_t from libc.stdint cimport int64_t, uint32_t
from libcpp.vector cimport vector from libcpp.vector cimport vector
from libcpp.set cimport set from libcpp.set cimport set
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
@ -7,9 +7,6 @@ from murmurhash.mrmr cimport hash64
from .typedefs cimport attr_t, hash_t from .typedefs cimport attr_t, hash_t
cpdef hash_t hash_string(str string) except 0
ctypedef union Utf8Str: ctypedef union Utf8Str:
unsigned char[8] s unsigned char[8] s
unsigned char* p unsigned char* p
@ -17,9 +14,13 @@ ctypedef union Utf8Str:
cdef class StringStore: cdef class StringStore:
cdef Pool mem cdef Pool mem
cdef vector[hash_t] keys cdef vector[hash_t] keys
cdef public PreshMap _map cdef PreshMap key_map
cdef const Utf8Str* intern_unicode(self, str py_string) cdef hash_t _intern_str(self, str string)
cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length, hash_t* precalculated_hash) cdef Utf8Str* _allocate_str_repr(self, const unsigned char* chars, uint32_t length) except *
cdef str _decode_str_repr(self, const Utf8Str* string)
cpdef hash_t hash_string(object string) except -1
cpdef hash_t get_string_id(object string_or_int) except -1

View File

@ -1,21 +1,15 @@
from typing import Optional, Iterable, Iterator, Union, Any, overload from typing import Optional, Iterable, Iterator, Union, Any, Tuple
from pathlib import Path from pathlib import Path
def get_string_id(key: Union[str, int]) -> int: ...
class StringStore: class StringStore:
def __init__( def __init__(self, strings: Optional[Iterable[str]]) -> None: ...
self, strings: Optional[Iterable[str]] = ..., freeze: bool = ... def __getitem__(self, string_or_hash: Union[str, int]) -> Union[str, int]: ...
) -> None: ... def as_int(self, key: Union[str, int]) -> int: ...
@overload def as_string(self, string_or_hash: Union[str, int]) -> str: ...
def __getitem__(self, string_or_id: Union[bytes, str]) -> int: ...
@overload
def __getitem__(self, string_or_id: int) -> str: ...
def as_int(self, key: Union[bytes, str, int]) -> int: ...
def as_string(self, key: Union[bytes, str, int]) -> str: ...
def add(self, string: str) -> int: ... def add(self, string: str) -> int: ...
def items(self) -> Tuple[int, str]: ...
def __len__(self) -> int: ... def __len__(self) -> int: ...
def __contains__(self, string: str) -> bool: ... def __contains__(self, string_or_hash: Union[str, int]) -> bool: ...
def __iter__(self) -> Iterator[str]: ... def __iter__(self) -> Iterator[str]: ...
def __reduce__(self) -> Any: ... def __reduce__(self) -> Any: ...
def to_disk(self, path: Union[str, Path]) -> None: ... def to_disk(self, path: Union[str, Path]) -> None: ...

View File

@ -1,4 +1,5 @@
# cython: infer_types=True # cython: infer_types=True
from typing import Optional, Union, Iterable, Tuple
cimport cython cimport cython
from libc.string cimport memcpy from libc.string cimport memcpy
from libcpp.set cimport set from libcpp.set cimport set
@ -15,176 +16,115 @@ from .errors import Errors
from . import util from . import util
def get_string_id(key):
"""Get a string ID, handling the reserved symbols correctly. If the key is
already an ID, return it.
This function optimises for convenience over performance, so shouldn't be
used in tight loops.
"""
cdef hash_t str_hash
if isinstance(key, str):
if len(key) == 0:
return 0
symbol = SYMBOLS_BY_STR.get(key, None)
if symbol is not None:
return symbol
else:
chars = key.encode('utf-8')
return _hash_utf8(chars, len(chars))
elif _try_coerce_to_hash(key, &str_hash):
# Coerce the integral key to the expected primitive hash type.
# This ensures that custom/overloaded "primitive" data types
# such as those implemented by numpy are not inadvertently used
# downsteam (as these are internally implemented as custom PyObjects
# whose comparison operators can incur a significant overhead).
return str_hash
else:
raise KeyError(Errors.E1400.format(key_type=type(key)))
cpdef hash_t hash_string(str string) except 0:
chars = string.encode('utf-8')
return _hash_utf8(chars, len(chars))
cdef class StringStore: cdef class StringStore:
"""Look up strings by 64-bit hashes. """Look up strings by 64-bit hashes. Implicitly handles reserved symbols.
DOCS: https://spacy.io/api/stringstore DOCS: https://spacy.io/api/stringstore
""" """
def __init__(self, strings=None, freeze=False): def __init__(self, strings: Optional[Iterable[str]]=None):
"""Create the StringStore. """Create the StringStore.
strings (iterable): A sequence of unicode strings to add to the store. strings (iterable): A sequence of unicode strings to add to the store.
""" """
self.mem = Pool() self.mem = Pool()
self._map = PreshMap() self.key_map = PreshMap()
if strings is not None: if strings is not None:
for string in strings: for string in strings:
self.add(string) self.add(string)
def __getitem__(self, object string_or_id): def __getitem__(self, string_or_hash: Union[str, int]) -> Union[str, int]:
"""Retrieve a string from a given hash, or vice versa. """Retrieve a string from a given hash. If a string
is passed as the input, add it to the store and return
its hash.
string_or_id (bytes, str or uint64): The value to encode. string_or_hash (str, int): The hash value to lookup or the string to store.
Returns (str / uint64): The value to be retrieved. RETURNS (str, int): The stored string or the hash of the newly added string.
""" """
cdef hash_t str_hash if isinstance(string_or_hash, str):
cdef Utf8Str* utf8str = NULL return self.add(string_or_hash)
if isinstance(string_or_id, str):
if len(string_or_id) == 0:
return 0
# Return early if the string is found in the symbols LUT.
symbol = SYMBOLS_BY_STR.get(string_or_id, None)
if symbol is not None:
return symbol
else: else:
return hash_string(string_or_id) return self._retrieve_interned_str(string_or_hash)
elif isinstance(string_or_id, bytes):
return _hash_utf8(string_or_id, len(string_or_id))
elif _try_coerce_to_hash(string_or_id, &str_hash):
if str_hash == 0:
return ""
elif str_hash < len(SYMBOLS_BY_INT):
return SYMBOLS_BY_INT[str_hash]
else:
utf8str = <Utf8Str*>self._map.get(str_hash)
else:
raise KeyError(Errors.E1400.format(key_type=type(key)))
if utf8str is NULL: def __contains__(self, string_or_hash: Union[str, int]) -> bool:
raise KeyError(Errors.E018.format(hash_value=string_or_id)) """Check whether a string or a hash is in the store.
else:
return _decode_Utf8Str(utf8str)
def as_int(self, key): string (str, int): The string/hash to check.
"""If key is an int, return it; otherwise, get the int value.""" RETURNS (bool): Whether the store contains the string.
if not isinstance(key, str):
return key
else:
return self[key]
def as_string(self, key):
"""If key is a string, return it; otherwise, get the string value."""
if isinstance(key, str):
return key
else:
return self[key]
def add(self, string):
"""Add a string to the StringStore.
string (str): The string to add.
RETURNS (uint64): The string's hash value.
""" """
cdef hash_t str_hash cdef hash_t str_hash = get_string_id(string_or_hash)
if isinstance(string, str): if str_hash < len(SYMBOLS_BY_INT):
if string in SYMBOLS_BY_STR: return True
return SYMBOLS_BY_STR[string]
string = string.encode('utf-8')
str_hash = _hash_utf8(string, len(string))
self._intern_utf8(string, len(string), &str_hash)
elif isinstance(string, bytes):
if string in SYMBOLS_BY_STR:
return SYMBOLS_BY_STR[string]
str_hash = _hash_utf8(string, len(string))
self._intern_utf8(string, len(string), &str_hash)
else: else:
raise TypeError(Errors.E017.format(value_type=type(string))) return self.key_map.get(str_hash) is not NULL
return str_hash
def __len__(self): def __iter__(self) -> str:
"""Iterate over the strings in the store, in order.
YIELDS (str): A string in the store.
"""
for _, string in self.items():
yield string
def __reduce__(self):
strings = list(self)
return (StringStore, (strings,), None, None, None)
def __len__(self) -> int:
"""The number of strings in the store. """The number of strings in the store.
RETURNS (int): The number of strings in the store. RETURNS (int): The number of strings in the store.
""" """
return self.keys.size() return self.keys.size()
def __contains__(self, string_or_id not None): def add(self, string: str) -> int:
"""Check whether a string or ID is in the store. """Add a string to the StringStore.
string_or_id (str or int): The string to check. string (str): The string to add.
RETURNS (bool): Whether the store contains the string. RETURNS (uint64): The string's hash value.
""" """
cdef hash_t str_hash if not isinstance(string, str):
if isinstance(string_or_id, str): raise TypeError(Errors.E017.format(value_type=type(string)))
if len(string_or_id) == 0:
return True if string in SYMBOLS_BY_STR:
elif string_or_id in SYMBOLS_BY_STR: return SYMBOLS_BY_STR[string]
return True
str_hash = hash_string(string_or_id)
elif _try_coerce_to_hash(string_or_id, &str_hash):
pass
else: else:
raise KeyError(Errors.E1400.format(key_type=type(string_or_id))) return self._intern_str(string)
if str_hash < len(SYMBOLS_BY_INT): def as_int(self, string_or_hash: Union[str, int]) -> str:
return True """If a hash value is passed as the input, return it as-is. If the input
is a string, return its corresponding hash.
string_or_hash (str, int): The string to hash or a hash value.
RETURNS (int): The hash of the string or the input hash value.
"""
if isinstance(string_or_hash, int):
return string_or_hash
else: else:
return self._map.get(str_hash) is not NULL return get_string_id(string_or_hash)
def __iter__(self): def as_string(self, string_or_hash: Union[str, int]) -> str:
"""Iterate over the strings in the store, in order. """If a string is passed as the input, return it as-is. If the input
is a hash value, return its corresponding string.
YIELDS (str): A string in the store. string_or_hash (str, int): The hash value to lookup or a string.
RETURNS (str): The stored string or the input string.
"""
if isinstance(string_or_hash, str):
return string_or_hash
else:
return self._retrieve_interned_str(string_or_hash)
def items(self) -> Tuple[int, str]:
"""Iterate over the stored strings and their hashes in order.
YIELDS (int, str): A hash, string pair.
""" """
cdef int i cdef int i
cdef hash_t key cdef hash_t key
for i in range(self.keys.size()): for i in range(self.keys.size()):
key = self.keys[i] key = self.keys[i]
utf8str = <Utf8Str*>self._map.get(key) utf8str = <Utf8Str*>self.key_map.get(key)
yield _decode_Utf8Str(utf8str) yield (key, self._decode_str_repr(utf8str))
# TODO: Iterate OOV here?
def __reduce__(self):
strings = list(self)
return (StringStore, (strings,), None, None, None)
def to_disk(self, path): def to_disk(self, path):
"""Save the current state to a directory. """Save the current state to a directory.
@ -234,35 +174,67 @@ cdef class StringStore:
def _reset_and_load(self, strings): def _reset_and_load(self, strings):
self.mem = Pool() self.mem = Pool()
self._map = PreshMap() self.key_map = PreshMap()
self.keys.clear() self.keys.clear()
for string in strings: for string in strings:
self.add(string) self.add(string)
cdef const Utf8Str* intern_unicode(self, str py_string): def _retrieve_interned_str(self, hash_value: int) -> str:
# 0 means missing, but we don't bother offsetting the index. cdef hash_t str_hash
cdef bytes byte_string = py_string.encode('utf-8') if not _try_coerce_to_hash(hash_value, &str_hash):
return self._intern_utf8(byte_string, len(byte_string), NULL) raise TypeError(Errors.E4000.format(expected_types="'str','int'", received_type=type(hash_value)))
@cython.final # Handle reserved symbols and empty strings correctly.
cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length, hash_t* precalculated_hash): if str_hash == 0:
return ""
elif str_hash < len(SYMBOLS_BY_INT):
return SYMBOLS_BY_INT[str_hash]
utf8str = <Utf8Str*>self.key_map.get(str_hash)
if utf8str is NULL:
raise KeyError(Errors.E018.format(hash_value=str_hash))
else:
return self._decode_str_repr(utf8str)
cdef hash_t _intern_str(self, str string):
# TODO: This function's API/behaviour is an unholy mess... # TODO: This function's API/behaviour is an unholy mess...
# 0 means missing, but we don't bother offsetting the index. # 0 means missing, but we don't bother offsetting the index.
cdef hash_t key = precalculated_hash[0] if precalculated_hash is not NULL else _hash_utf8(utf8_string, length) chars = string.encode('utf-8')
cdef Utf8Str* value = <Utf8Str*>self._map.get(key) cdef hash_t key = hash64(<unsigned char*>chars, len(chars), 1)
cdef Utf8Str* value = <Utf8Str*>self.key_map.get(key)
if value is not NULL: if value is not NULL:
return value return key
value = _allocate(self.mem, <unsigned char*>utf8_string, length)
self._map.set(key, value) value = self._allocate_str_repr(<unsigned char*>chars, len(chars))
self.key_map.set(key, value)
self.keys.push_back(key) self.keys.push_back(key)
return value return key
cdef Utf8Str* _allocate_str_repr(self, const unsigned char* chars, uint32_t length) except *:
cdef int n_length_bytes
cdef int i
cdef Utf8Str* string = <Utf8Str*>self.mem.alloc(1, sizeof(Utf8Str))
cdef uint32_t ulength = length
if length < sizeof(string.s):
string.s[0] = <unsigned char>length
memcpy(&string.s[1], chars, length)
return string
elif length < 255:
string.p = <unsigned char*>self.mem.alloc(length + 1, sizeof(unsigned char))
string.p[0] = length
memcpy(&string.p[1], chars, length)
return string
else:
i = 0
n_length_bytes = (length // 255) + 1
string.p = <unsigned char*>self.mem.alloc(length + n_length_bytes, sizeof(unsigned char))
for i in range(n_length_bytes-1):
string.p[i] = 255
string.p[n_length_bytes-1] = length % 255
memcpy(&string.p[n_length_bytes], chars, length)
return string
cdef hash_t _hash_utf8(char* utf8_string, int length) nogil: cdef str _decode_str_repr(self, const Utf8Str* string):
return hash64(utf8_string, length, 1)
cdef str _decode_Utf8Str(const Utf8Str* string):
cdef int i, length cdef int i, length
if string.s[0] < sizeof(string.s) and string.s[0] != 0: if string.s[0] < sizeof(string.s) and string.s[0] != 0:
return string.s[1:string.s[0]+1].decode('utf-8') return string.s[1:string.s[0]+1].decode('utf-8')
@ -279,6 +251,39 @@ cdef str _decode_Utf8Str(const Utf8Str* string):
return string.p[i:length + i].decode('utf-8') return string.p[i:length + i].decode('utf-8')
cpdef hash_t hash_string(object string) except -1:
if not isinstance(string, str):
raise TypeError(Errors.E4000.format(expected_types="'str'", received_type=type(string)))
# Handle reserved symbols and empty strings correctly.
if len(string) == 0:
return 0
symbol = SYMBOLS_BY_STR.get(string, None)
if symbol is not None:
return symbol
chars = string.encode('utf-8')
return hash64(<unsigned char*>chars, len(chars), 1)
cpdef hash_t get_string_id(object string_or_hash) except -1:
cdef hash_t str_hash
try:
return hash_string(string_or_hash)
except:
if _try_coerce_to_hash(string_or_hash, &str_hash):
# Coerce the integral key to the expected primitive hash type.
# This ensures that custom/overloaded "primitive" data types
# such as those implemented by numpy are not inadvertently used
# downsteam (as these are internally implemented as custom PyObjects
# whose comparison operators can incur a significant overhead).
return str_hash
else:
raise TypeError(Errors.E4000.format(expected_types="'str','int'", received_type=type(string_or_hash)))
# Not particularly elegant, but this is faster than `isinstance(key, numbers.Integral)` # Not particularly elegant, but this is faster than `isinstance(key, numbers.Integral)`
cdef inline bint _try_coerce_to_hash(object key, hash_t* out_hash): cdef inline bint _try_coerce_to_hash(object key, hash_t* out_hash):
try: try:
@ -287,27 +292,3 @@ cdef inline bint _try_coerce_to_hash(object key, hash_t* out_hash):
except: except:
return False return False
cdef Utf8Str* _allocate(Pool mem, const unsigned char* chars, uint32_t length) except *:
cdef int n_length_bytes
cdef int i
cdef Utf8Str* string = <Utf8Str*>mem.alloc(1, sizeof(Utf8Str))
cdef uint32_t ulength = length
if length < sizeof(string.s):
string.s[0] = <unsigned char>length
memcpy(&string.s[1], chars, length)
return string
elif length < 255:
string.p = <unsigned char*>mem.alloc(length + 1, sizeof(unsigned char))
string.p[0] = length
memcpy(&string.p[1], chars, length)
return string
else:
i = 0
n_length_bytes = (length // 255) + 1
string.p = <unsigned char*>mem.alloc(length + n_length_bytes, sizeof(unsigned char))
for i in range(n_length_bytes-1):
string.p[i] = 255
string.p[n_length_bytes-1] = length % 255
memcpy(&string.p[n_length_bytes], chars, length)
return string