strings: Add items() method, add type hints, remove unused methods, restrict inputs to specific types, reorganize methods

This commit is contained in:
shadeMe 2022-08-17 18:24:40 +02:00 committed by shademe
parent 19ba6eca15
commit d6237880b0
4 changed files with 187 additions and 210 deletions

View File

@ -249,7 +249,7 @@ class Errors(metaclass=ErrorsWithCodes):
E012 = ("Cannot add pattern for zero tokens to matcher.\nKey: {key}")
E016 = ("MultitaskObjective target should be function or one of: dep, "
"tag, ent, dep_tag_offset, ent_tag.")
E017 = ("Can only add unicode or bytes. Got type: {value_type}")
E017 = ("Can only add 'str' inputs to StringStore. Got type: {value_type}")
E018 = ("Can't retrieve string for hash '{hash_value}'. This usually "
"refers to an issue with the `Vocab` or `StringStore`.")
E019 = ("Can't create transition with unknown action ID: {action}. Action "
@ -939,9 +939,10 @@ class Errors(metaclass=ErrorsWithCodes):
"`{arg2}`={arg2_values} but these arguments are conflicting.")
E1043 = ("Expected None or a value in range [{range_start}, {range_end}] for entity linker threshold, but got "
"{value}.")
# New errors added in v4.x
E1400 = ("Expected 'str' or 'int', but got '{key_type}'")
E4000 = ("Expected input to be one of the following types: ({expected_types}), "
"but got '{received_type}'")
# Deprecated model shortcuts, only used in errors and warnings

View File

@ -1,4 +1,4 @@
from libc.stdint cimport int64_t
from libc.stdint cimport int64_t, uint32_t
from libcpp.vector cimport vector
from libcpp.set cimport set
from cymem.cymem cimport Pool
@ -7,9 +7,6 @@ from murmurhash.mrmr cimport hash64
from .typedefs cimport attr_t, hash_t
cpdef hash_t hash_string(str string) except 0
ctypedef union Utf8Str:
unsigned char[8] s
unsigned char* p
@ -17,9 +14,13 @@ ctypedef union Utf8Str:
cdef class StringStore:
cdef Pool mem
cdef vector[hash_t] keys
cdef public PreshMap _map
cdef PreshMap key_map
cdef const Utf8Str* intern_unicode(self, str py_string)
cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length, hash_t* precalculated_hash)
cdef hash_t _intern_str(self, str string)
cdef Utf8Str* _allocate_str_repr(self, const unsigned char* chars, uint32_t length) except *
cdef str _decode_str_repr(self, const Utf8Str* string)
cpdef hash_t hash_string(object string) except -1
cpdef hash_t get_string_id(object string_or_int) except -1

View File

@ -1,21 +1,15 @@
from typing import Optional, Iterable, Iterator, Union, Any, overload
from typing import Optional, Iterable, Iterator, Union, Any, Tuple
from pathlib import Path
def get_string_id(key: Union[str, int]) -> int: ...
class StringStore:
def __init__(
self, strings: Optional[Iterable[str]] = ..., freeze: bool = ...
) -> None: ...
@overload
def __getitem__(self, string_or_id: Union[bytes, str]) -> int: ...
@overload
def __getitem__(self, string_or_id: int) -> str: ...
def as_int(self, key: Union[bytes, str, int]) -> int: ...
def as_string(self, key: Union[bytes, str, int]) -> str: ...
# `strings` is optional at runtime (the implementation defaults it to None),
# so the stub must declare a default as well.
def __init__(self, strings: Optional[Iterable[str]] = ...) -> None: ...
def __getitem__(self, string_or_hash: Union[str, int]) -> Union[str, int]: ...
def as_int(self, key: Union[str, int]) -> int: ...
def as_string(self, string_or_hash: Union[str, int]) -> str: ...
def add(self, string: str) -> int: ...
# items() yields (hash, string) pairs one at a time, so it returns an
# iterator of tuples rather than a single tuple.
def items(self) -> Iterator[Tuple[int, str]]: ...
def __len__(self) -> int: ...
def __contains__(self, string: str) -> bool: ...
def __contains__(self, string_or_hash: Union[str, int]) -> bool: ...
def __iter__(self) -> Iterator[str]: ...
def __reduce__(self) -> Any: ...
def to_disk(self, path: Union[str, Path]) -> None: ...

View File

@ -1,4 +1,5 @@
# cython: infer_types=True
from typing import Optional, Union, Iterable, Iterator, Tuple
cimport cython
from libc.string cimport memcpy
from libcpp.set cimport set
@ -15,176 +16,115 @@ from .errors import Errors
from . import util
def get_string_id(key):
    """Resolve a key to its string ID, taking reserved symbols into account.

    Integral keys are returned as-is (coerced to the primitive hash type).
    Convenience is favoured over speed here, so avoid calling this in
    tight loops.

    key (str / int): The string to hash or an existing ID.
    RETURNS (int): The ID for the key.
    """
    cdef hash_t str_hash
    if not isinstance(key, str):
        if not _try_coerce_to_hash(key, &str_hash):
            raise KeyError(Errors.E1400.format(key_type=type(key)))
        # Hand back the coerced primitive hash instead of the original
        # object: custom "primitive" integer types (e.g. numpy's) are
        # PyObjects whose comparison operators can be costly downstream.
        return str_hash
    # The empty string always maps to ID 0.
    if len(key) == 0:
        return 0
    # Reserved symbols carry fixed IDs and are never hashed.
    reserved = SYMBOLS_BY_STR.get(key, None)
    if reserved is not None:
        return reserved
    encoded = key.encode('utf-8')
    return _hash_utf8(encoded, len(encoded))
cpdef hash_t hash_string(str string) except 0:
    # Encode to UTF-8 first; the hash is computed over the raw bytes.
    utf8_bytes = string.encode('utf-8')
    return _hash_utf8(utf8_bytes, len(utf8_bytes))
cdef class StringStore:
"""Look up strings by 64-bit hashes.
"""Look up strings by 64-bit hashes. Implicitly handles reserved symbols.
DOCS: https://spacy.io/api/stringstore
"""
def __init__(self, strings=None, freeze=False):
def __init__(self, strings: Optional[Iterable[str]]=None):
"""Create the StringStore.
strings (iterable): A sequence of unicode strings to add to the store.
"""
self.mem = Pool()
self._map = PreshMap()
self.key_map = PreshMap()
if strings is not None:
for string in strings:
self.add(string)
def __getitem__(self, object string_or_id):
"""Retrieve a string from a given hash, or vice versa.
def __getitem__(self, string_or_hash: Union[str, int]) -> Union[str, int]:
"""Retrieve a string from a given hash. If a string
is passed as the input, add it to the store and return
its hash.
string_or_id (bytes, str or uint64): The value to encode.
Returns (str / uint64): The value to be retrieved.
string_or_hash (str, int): The hash value to lookup or the string to store.
RETURNS (str, int): The stored string or the hash of the newly added string.
"""
cdef hash_t str_hash
cdef Utf8Str* utf8str = NULL
if isinstance(string_or_id, str):
if len(string_or_id) == 0:
return 0
# Return early if the string is found in the symbols LUT.
symbol = SYMBOLS_BY_STR.get(string_or_id, None)
if symbol is not None:
return symbol
else:
return hash_string(string_or_id)
elif isinstance(string_or_id, bytes):
return _hash_utf8(string_or_id, len(string_or_id))
elif _try_coerce_to_hash(string_or_id, &str_hash):
if str_hash == 0:
return ""
elif str_hash < len(SYMBOLS_BY_INT):
return SYMBOLS_BY_INT[str_hash]
else:
utf8str = <Utf8Str*>self._map.get(str_hash)
if isinstance(string_or_hash, str):
return self.add(string_or_hash)
else:
raise KeyError(Errors.E1400.format(key_type=type(key)))
return self._retrieve_interned_str(string_or_hash)
if utf8str is NULL:
raise KeyError(Errors.E018.format(hash_value=string_or_id))
else:
return _decode_Utf8Str(utf8str)
def __contains__(self, string_or_hash: Union[str, int]) -> bool:
"""Check whether a string or a hash is in the store.
def as_int(self, key):
"""If key is an int, return it; otherwise, get the int value."""
if not isinstance(key, str):
return key
else:
return self[key]
def as_string(self, key):
"""If key is a string, return it; otherwise, get the string value."""
if isinstance(key, str):
return key
else:
return self[key]
def add(self, string):
"""Add a string to the StringStore.
string (str): The string to add.
RETURNS (uint64): The string's hash value.
string (str, int): The string/hash to check.
RETURNS (bool): Whether the store contains the string.
"""
cdef hash_t str_hash
if isinstance(string, str):
if string in SYMBOLS_BY_STR:
return SYMBOLS_BY_STR[string]
string = string.encode('utf-8')
str_hash = _hash_utf8(string, len(string))
self._intern_utf8(string, len(string), &str_hash)
elif isinstance(string, bytes):
if string in SYMBOLS_BY_STR:
return SYMBOLS_BY_STR[string]
str_hash = _hash_utf8(string, len(string))
self._intern_utf8(string, len(string), &str_hash)
cdef hash_t str_hash = get_string_id(string_or_hash)
if str_hash < len(SYMBOLS_BY_INT):
return True
else:
raise TypeError(Errors.E017.format(value_type=type(string)))
return str_hash
return self.key_map.get(str_hash) is not NULL
def __len__(self):
def __iter__(self) -> Iterator[str]:
    """Iterate over the strings in the store, in order.

    YIELDS (str): A string in the store.
    """
    # items() produces (hash, string) pairs; only the string is exposed here.
    for _, string in self.items():
        yield string
def __reduce__(self):
    # Pickle support: the store is rebuilt by passing the list of its
    # strings back to the StringStore constructor.
    return (StringStore, (list(self),), None, None, None)
def __len__(self) -> int:
"""The number of strings in the store.
RETURNS (int): The number of strings in the store.
"""
return self.keys.size()
def __contains__(self, string_or_id not None):
"""Check whether a string or ID is in the store.
def add(self, string: str) -> int:
"""Add a string to the StringStore.
string_or_id (str or int): The string to check.
RETURNS (bool): Whether the store contains the string.
string (str): The string to add.
RETURNS (uint64): The string's hash value.
"""
cdef hash_t str_hash
if isinstance(string_or_id, str):
if len(string_or_id) == 0:
return True
elif string_or_id in SYMBOLS_BY_STR:
return True
str_hash = hash_string(string_or_id)
elif _try_coerce_to_hash(string_or_id, &str_hash):
pass
if not isinstance(string, str):
raise TypeError(Errors.E017.format(value_type=type(string)))
if string in SYMBOLS_BY_STR:
return SYMBOLS_BY_STR[string]
else:
raise KeyError(Errors.E1400.format(key_type=type(string_or_id)))
return self._intern_str(string)
if str_hash < len(SYMBOLS_BY_INT):
return True
def as_int(self, string_or_hash: Union[str, int]) -> str:
"""If a hash value is passed as the input, return it as-is. If the input
is a string, return its corresponding hash.
string_or_hash (str, int): The string to hash or a hash value.
RETURNS (int): The hash of the string or the input hash value.
"""
if isinstance(string_or_hash, int):
return string_or_hash
else:
return self._map.get(str_hash) is not NULL
return get_string_id(string_or_hash)
def __iter__(self):
"""Iterate over the strings in the store, in order.
def as_string(self, string_or_hash: Union[str, int]) -> str:
"""If a string is passed as the input, return it as-is. If the input
is a hash value, return its corresponding string.
YIELDS (str): A string in the store.
string_or_hash (str, int): The hash value to lookup or a string.
RETURNS (str): The stored string or the input string.
"""
if isinstance(string_or_hash, str):
return string_or_hash
else:
return self._retrieve_interned_str(string_or_hash)
def items(self) -> Tuple[int, str]:
"""Iterate over the stored strings and their hashes in order.
YIELDS (int, str): A hash, string pair.
"""
cdef int i
cdef hash_t key
for i in range(self.keys.size()):
key = self.keys[i]
utf8str = <Utf8Str*>self._map.get(key)
yield _decode_Utf8Str(utf8str)
# TODO: Iterate OOV here?
def __reduce__(self):
strings = list(self)
return (StringStore, (strings,), None, None, None)
utf8str = <Utf8Str*>self.key_map.get(key)
yield (key, self._decode_str_repr(utf8str))
def to_disk(self, path):
"""Save the current state to a directory.
@ -234,49 +174,114 @@ cdef class StringStore:
def _reset_and_load(self, strings):
self.mem = Pool()
self._map = PreshMap()
self.key_map = PreshMap()
self.keys.clear()
for string in strings:
self.add(string)
cdef const Utf8Str* intern_unicode(self, str py_string):
# 0 means missing, but we don't bother offsetting the index.
cdef bytes byte_string = py_string.encode('utf-8')
return self._intern_utf8(byte_string, len(byte_string), NULL)
def _retrieve_interned_str(self, hash_value: int) -> str:
cdef hash_t str_hash
if not _try_coerce_to_hash(hash_value, &str_hash):
raise TypeError(Errors.E4000.format(expected_types="'str','int'", received_type=type(hash_value)))
@cython.final
cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length, hash_t* precalculated_hash):
# Handle reserved symbols and empty strings correctly.
if str_hash == 0:
return ""
elif str_hash < len(SYMBOLS_BY_INT):
return SYMBOLS_BY_INT[str_hash]
utf8str = <Utf8Str*>self.key_map.get(str_hash)
if utf8str is NULL:
raise KeyError(Errors.E018.format(hash_value=str_hash))
else:
return self._decode_str_repr(utf8str)
cdef hash_t _intern_str(self, str string):
# TODO: This function's API/behaviour is an unholy mess...
# 0 means missing, but we don't bother offsetting the index.
cdef hash_t key = precalculated_hash[0] if precalculated_hash is not NULL else _hash_utf8(utf8_string, length)
cdef Utf8Str* value = <Utf8Str*>self._map.get(key)
chars = string.encode('utf-8')
cdef hash_t key = hash64(<unsigned char*>chars, len(chars), 1)
cdef Utf8Str* value = <Utf8Str*>self.key_map.get(key)
if value is not NULL:
return value
value = _allocate(self.mem, <unsigned char*>utf8_string, length)
self._map.set(key, value)
return key
value = self._allocate_str_repr(<unsigned char*>chars, len(chars))
self.key_map.set(key, value)
self.keys.push_back(key)
return value
return key
cdef Utf8Str* _allocate_str_repr(self, const unsigned char* chars, uint32_t length) except *:
    # Build the Utf8Str representation of `length` bytes from `chars`,
    # allocating from this store's memory pool (`self.mem`).
    #
    # One of three layouts is chosen based on `length`:
    #   * shorter than the inline buffer `s`: a single length byte followed
    #     by the data, stored directly in `s`;
    #   * shorter than 255: a separately allocated buffer `p` holding a
    #     single length byte followed by the data;
    #   * otherwise: a separately allocated buffer `p` whose length prefix
    #     is a run of bytes equal to 255 terminated by `length % 255`,
    #     followed by the data.
    cdef int n_length_bytes
    cdef int i
    cdef Utf8Str* string = <Utf8Str*>self.mem.alloc(1, sizeof(Utf8Str))
    if length < sizeof(string.s):
        string.s[0] = <unsigned char>length
        memcpy(&string.s[1], chars, length)
        return string
    elif length < 255:
        string.p = <unsigned char*>self.mem.alloc(length + 1, sizeof(unsigned char))
        string.p[0] = length
        memcpy(&string.p[1], chars, length)
        return string
    else:
        # Every prefix byte except the last is 255; the last one carries the
        # remainder, which is always below 255 and so ends the run.
        n_length_bytes = (length // 255) + 1
        string.p = <unsigned char*>self.mem.alloc(length + n_length_bytes, sizeof(unsigned char))
        for i in range(n_length_bytes-1):
            string.p[i] = 255
        string.p[n_length_bytes-1] = length % 255
        memcpy(&string.p[n_length_bytes], chars, length)
        return string
# Hash `length` bytes starting at `utf8_string` via hash64. The literal 1 is
# passed as hash64's third argument (presumably the seed -- confirm against
# murmurhash). Declared `nogil`.
cdef hash_t _hash_utf8(char* utf8_string, int length) nogil:
    return hash64(utf8_string, length, 1)
cdef str _decode_Utf8Str(const Utf8Str* string):
cdef int i, length
if string.s[0] < sizeof(string.s) and string.s[0] != 0:
return string.s[1:string.s[0]+1].decode('utf-8')
elif string.p[0] < 255:
return string.p[1:string.p[0]+1].decode('utf-8')
else:
i = 0
length = 0
while string.p[i] == 255:
cdef str _decode_str_repr(self, const Utf8Str* string):
cdef int i, length
if string.s[0] < sizeof(string.s) and string.s[0] != 0:
return string.s[1:string.s[0]+1].decode('utf-8')
elif string.p[0] < 255:
return string.p[1:string.p[0]+1].decode('utf-8')
else:
i = 0
length = 0
while string.p[i] == 255:
i += 1
length += 255
length += string.p[i]
i += 1
length += 255
length += string.p[i]
i += 1
return string.p[i:length + i].decode('utf-8')
return string.p[i:length + i].decode('utf-8')
cpdef hash_t hash_string(object string) except -1:
    """Compute the hash of a string, honouring reserved symbols.

    string (str): The string to hash.
    RETURNS (int): 0 for the empty string, the fixed ID for a reserved
        symbol, otherwise the hash of the UTF-8 encoded string.
    RAISES (TypeError): If the input is not a 'str'.
    """
    if not isinstance(string, str):
        raise TypeError(Errors.E4000.format(expected_types="'str'", received_type=type(string)))

    # The empty string always maps to 0.
    if not string:
        return 0
    # Reserved symbols have fixed IDs and are never hashed.
    reserved_id = SYMBOLS_BY_STR.get(string)
    if reserved_id is not None:
        return reserved_id
    encoded = string.encode('utf-8')
    return hash64(<unsigned char*>encoded, len(encoded), 1)
cpdef hash_t get_string_id(object string_or_hash) except -1:
    """Get the ID for a string or pass an existing hash value through.

    string_or_hash (str / int): The string to hash or a hash value.
    RETURNS (int): The hash of the string, or the input coerced to the
        primitive hash type.
    RAISES (TypeError): If the input is neither a 'str' nor coercible to
        a hash value.
    """
    cdef hash_t str_hash

    try:
        # Strings are tried first; hash_string() signals a non-str input
        # with a TypeError.
        return hash_string(string_or_hash)
    except TypeError:
        # Only the "not a str" case is handled here. Any other exception
        # (e.g. an encoding failure or KeyboardInterrupt) propagates
        # instead of being silently reinterpreted as a type error.
        if _try_coerce_to_hash(string_or_hash, &str_hash):
            # Coerce the integral key to the expected primitive hash type.
            # This ensures that custom/overloaded "primitive" data types
            # such as those implemented by numpy are not inadvertently used
            # downstream (as these are internally implemented as custom
            # PyObjects whose comparison operators can incur a significant
            # overhead).
            return str_hash
        else:
            raise TypeError(Errors.E4000.format(expected_types="'str','int'", received_type=type(string_or_hash)))
# Not particularly elegant, but this is faster than `isinstance(key, numbers.Integral)`
@ -287,27 +292,3 @@ cdef inline bint _try_coerce_to_hash(object key, hash_t* out_hash):
except:
return False
cdef Utf8Str* _allocate(Pool mem, const unsigned char* chars, uint32_t length) except *:
    # Build the Utf8Str representation of `length` bytes from `chars`,
    # allocating from the pool `mem`.
    #
    # One of three layouts is chosen based on `length`:
    #   * shorter than the inline buffer `s`: a single length byte followed
    #     by the data, stored directly in `s`;
    #   * shorter than 255: a separately allocated buffer `p` holding a
    #     single length byte followed by the data;
    #   * otherwise: a separately allocated buffer `p` whose length prefix
    #     is a run of bytes equal to 255 terminated by `length % 255`,
    #     followed by the data.
    cdef int n_length_bytes
    cdef int i
    cdef Utf8Str* string = <Utf8Str*>mem.alloc(1, sizeof(Utf8Str))
    if length < sizeof(string.s):
        string.s[0] = <unsigned char>length
        memcpy(&string.s[1], chars, length)
        return string
    elif length < 255:
        string.p = <unsigned char*>mem.alloc(length + 1, sizeof(unsigned char))
        string.p[0] = length
        memcpy(&string.p[1], chars, length)
        return string
    else:
        # Every prefix byte except the last is 255; the last one carries the
        # remainder, which is always below 255 and so ends the run.
        n_length_bytes = (length // 255) + 1
        string.p = <unsigned char*>mem.alloc(length + n_length_bytes, sizeof(unsigned char))
        for i in range(n_length_bytes-1):
            string.p[i] = 255
        string.p[n_length_bytes-1] = length % 255
        memcpy(&string.p[n_length_bytes], chars, length)
        return string