mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-15 11:00:34 +03:00
strings
: Add items()
method, add type hints, remove unused methods, restrict inputs to specific types, reorganize methods
This commit is contained in:
parent
19ba6eca15
commit
d6237880b0
|
@ -249,7 +249,7 @@ class Errors(metaclass=ErrorsWithCodes):
|
|||
E012 = ("Cannot add pattern for zero tokens to matcher.\nKey: {key}")
|
||||
E016 = ("MultitaskObjective target should be function or one of: dep, "
|
||||
"tag, ent, dep_tag_offset, ent_tag.")
|
||||
E017 = ("Can only add unicode or bytes. Got type: {value_type}")
|
||||
E017 = ("Can only add 'str' inputs to StringStore. Got type: {value_type}")
|
||||
E018 = ("Can't retrieve string for hash '{hash_value}'. This usually "
|
||||
"refers to an issue with the `Vocab` or `StringStore`.")
|
||||
E019 = ("Can't create transition with unknown action ID: {action}. Action "
|
||||
|
@ -941,7 +941,8 @@ class Errors(metaclass=ErrorsWithCodes):
|
|||
"{value}.")
|
||||
|
||||
# New errors added in v4.x
|
||||
E1400 = ("Expected 'str' or 'int', but got '{key_type}'")
|
||||
E4000 = ("Expected input to be one of the following types: ({expected_types}), "
|
||||
"but got '{received_type}'")
|
||||
|
||||
|
||||
# Deprecated model shortcuts, only used in errors and warnings
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from libc.stdint cimport int64_t
|
||||
from libc.stdint cimport int64_t, uint32_t
|
||||
from libcpp.vector cimport vector
|
||||
from libcpp.set cimport set
|
||||
from cymem.cymem cimport Pool
|
||||
|
@ -7,9 +7,6 @@ from murmurhash.mrmr cimport hash64
|
|||
|
||||
from .typedefs cimport attr_t, hash_t
|
||||
|
||||
|
||||
cpdef hash_t hash_string(str string) except 0
|
||||
|
||||
ctypedef union Utf8Str:
|
||||
unsigned char[8] s
|
||||
unsigned char* p
|
||||
|
@ -17,9 +14,13 @@ ctypedef union Utf8Str:
|
|||
|
||||
cdef class StringStore:
|
||||
cdef Pool mem
|
||||
|
||||
cdef vector[hash_t] keys
|
||||
cdef public PreshMap _map
|
||||
cdef PreshMap key_map
|
||||
|
||||
cdef const Utf8Str* intern_unicode(self, str py_string)
|
||||
cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length, hash_t* precalculated_hash)
|
||||
cdef hash_t _intern_str(self, str string)
|
||||
cdef Utf8Str* _allocate_str_repr(self, const unsigned char* chars, uint32_t length) except *
|
||||
cdef str _decode_str_repr(self, const Utf8Str* string)
|
||||
|
||||
|
||||
cpdef hash_t hash_string(object string) except -1
|
||||
cpdef hash_t get_string_id(object string_or_int) except -1
|
||||
|
|
|
@ -1,21 +1,15 @@
|
|||
from typing import Optional, Iterable, Iterator, Union, Any, overload
|
||||
from typing import Optional, Iterable, Iterator, Union, Any, Tuple
|
||||
from pathlib import Path
|
||||
|
||||
def get_string_id(key: Union[str, int]) -> int: ...
|
||||
|
||||
class StringStore:
|
||||
def __init__(
|
||||
self, strings: Optional[Iterable[str]] = ..., freeze: bool = ...
|
||||
) -> None: ...
|
||||
@overload
|
||||
def __getitem__(self, string_or_id: Union[bytes, str]) -> int: ...
|
||||
@overload
|
||||
def __getitem__(self, string_or_id: int) -> str: ...
|
||||
def as_int(self, key: Union[bytes, str, int]) -> int: ...
|
||||
def as_string(self, key: Union[bytes, str, int]) -> str: ...
|
||||
def __init__(self, strings: Optional[Iterable[str]]) -> None: ...
|
||||
def __getitem__(self, string_or_hash: Union[str, int]) -> Union[str, int]: ...
|
||||
def as_int(self, key: Union[str, int]) -> int: ...
|
||||
def as_string(self, string_or_hash: Union[str, int]) -> str: ...
|
||||
def add(self, string: str) -> int: ...
|
||||
def items(self) -> Tuple[int, str]: ...
|
||||
def __len__(self) -> int: ...
|
||||
def __contains__(self, string: str) -> bool: ...
|
||||
def __contains__(self, string_or_hash: Union[str, int]) -> bool: ...
|
||||
def __iter__(self) -> Iterator[str]: ...
|
||||
def __reduce__(self) -> Any: ...
|
||||
def to_disk(self, path: Union[str, Path]) -> None: ...
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
# cython: infer_types=True
|
||||
from typing import Optional, Union, Iterable, Tuple
|
||||
cimport cython
|
||||
from libc.string cimport memcpy
|
||||
from libcpp.set cimport set
|
||||
|
@ -15,176 +16,115 @@ from .errors import Errors
|
|||
from . import util
|
||||
|
||||
|
||||
|
||||
def get_string_id(key):
|
||||
"""Get a string ID, handling the reserved symbols correctly. If the key is
|
||||
already an ID, return it.
|
||||
|
||||
This function optimises for convenience over performance, so shouldn't be
|
||||
used in tight loops.
|
||||
"""
|
||||
cdef hash_t str_hash
|
||||
if isinstance(key, str):
|
||||
if len(key) == 0:
|
||||
return 0
|
||||
|
||||
symbol = SYMBOLS_BY_STR.get(key, None)
|
||||
if symbol is not None:
|
||||
return symbol
|
||||
else:
|
||||
chars = key.encode('utf-8')
|
||||
return _hash_utf8(chars, len(chars))
|
||||
elif _try_coerce_to_hash(key, &str_hash):
|
||||
# Coerce the integral key to the expected primitive hash type.
|
||||
# This ensures that custom/overloaded "primitive" data types
|
||||
# such as those implemented by numpy are not inadvertently used
|
||||
# downsteam (as these are internally implemented as custom PyObjects
|
||||
# whose comparison operators can incur a significant overhead).
|
||||
return str_hash
|
||||
else:
|
||||
raise KeyError(Errors.E1400.format(key_type=type(key)))
|
||||
|
||||
|
||||
cpdef hash_t hash_string(str string) except 0:
|
||||
chars = string.encode('utf-8')
|
||||
return _hash_utf8(chars, len(chars))
|
||||
|
||||
|
||||
cdef class StringStore:
|
||||
"""Look up strings by 64-bit hashes.
|
||||
"""Look up strings by 64-bit hashes. Implicitly handles reserved symbols.
|
||||
|
||||
DOCS: https://spacy.io/api/stringstore
|
||||
"""
|
||||
def __init__(self, strings=None, freeze=False):
|
||||
def __init__(self, strings: Optional[Iterable[str]]=None):
|
||||
"""Create the StringStore.
|
||||
|
||||
strings (iterable): A sequence of unicode strings to add to the store.
|
||||
"""
|
||||
self.mem = Pool()
|
||||
self._map = PreshMap()
|
||||
self.key_map = PreshMap()
|
||||
if strings is not None:
|
||||
for string in strings:
|
||||
self.add(string)
|
||||
|
||||
def __getitem__(self, object string_or_id):
|
||||
"""Retrieve a string from a given hash, or vice versa.
|
||||
def __getitem__(self, string_or_hash: Union[str, int]) -> Union[str, int]:
|
||||
"""Retrieve a string from a given hash. If a string
|
||||
is passed as the input, add it to the store and return
|
||||
its hash.
|
||||
|
||||
string_or_id (bytes, str or uint64): The value to encode.
|
||||
Returns (str / uint64): The value to be retrieved.
|
||||
string_or_hash (str, int): The hash value to lookup or the string to store.
|
||||
RETURNS (str, int): The stored string or the hash of the newly added string.
|
||||
"""
|
||||
cdef hash_t str_hash
|
||||
cdef Utf8Str* utf8str = NULL
|
||||
|
||||
if isinstance(string_or_id, str):
|
||||
if len(string_or_id) == 0:
|
||||
return 0
|
||||
|
||||
# Return early if the string is found in the symbols LUT.
|
||||
symbol = SYMBOLS_BY_STR.get(string_or_id, None)
|
||||
if symbol is not None:
|
||||
return symbol
|
||||
if isinstance(string_or_hash, str):
|
||||
return self.add(string_or_hash)
|
||||
else:
|
||||
return hash_string(string_or_id)
|
||||
elif isinstance(string_or_id, bytes):
|
||||
return _hash_utf8(string_or_id, len(string_or_id))
|
||||
elif _try_coerce_to_hash(string_or_id, &str_hash):
|
||||
if str_hash == 0:
|
||||
return ""
|
||||
elif str_hash < len(SYMBOLS_BY_INT):
|
||||
return SYMBOLS_BY_INT[str_hash]
|
||||
else:
|
||||
utf8str = <Utf8Str*>self._map.get(str_hash)
|
||||
else:
|
||||
raise KeyError(Errors.E1400.format(key_type=type(key)))
|
||||
return self._retrieve_interned_str(string_or_hash)
|
||||
|
||||
if utf8str is NULL:
|
||||
raise KeyError(Errors.E018.format(hash_value=string_or_id))
|
||||
else:
|
||||
return _decode_Utf8Str(utf8str)
|
||||
def __contains__(self, string_or_hash: Union[str, int]) -> bool:
|
||||
"""Check whether a string or a hash is in the store.
|
||||
|
||||
def as_int(self, key):
|
||||
"""If key is an int, return it; otherwise, get the int value."""
|
||||
if not isinstance(key, str):
|
||||
return key
|
||||
else:
|
||||
return self[key]
|
||||
|
||||
def as_string(self, key):
|
||||
"""If key is a string, return it; otherwise, get the string value."""
|
||||
if isinstance(key, str):
|
||||
return key
|
||||
else:
|
||||
return self[key]
|
||||
|
||||
def add(self, string):
|
||||
"""Add a string to the StringStore.
|
||||
|
||||
string (str): The string to add.
|
||||
RETURNS (uint64): The string's hash value.
|
||||
string (str, int): The string/hash to check.
|
||||
RETURNS (bool): Whether the store contains the string.
|
||||
"""
|
||||
cdef hash_t str_hash
|
||||
if isinstance(string, str):
|
||||
if string in SYMBOLS_BY_STR:
|
||||
return SYMBOLS_BY_STR[string]
|
||||
|
||||
string = string.encode('utf-8')
|
||||
str_hash = _hash_utf8(string, len(string))
|
||||
self._intern_utf8(string, len(string), &str_hash)
|
||||
elif isinstance(string, bytes):
|
||||
if string in SYMBOLS_BY_STR:
|
||||
return SYMBOLS_BY_STR[string]
|
||||
str_hash = _hash_utf8(string, len(string))
|
||||
self._intern_utf8(string, len(string), &str_hash)
|
||||
cdef hash_t str_hash = get_string_id(string_or_hash)
|
||||
if str_hash < len(SYMBOLS_BY_INT):
|
||||
return True
|
||||
else:
|
||||
raise TypeError(Errors.E017.format(value_type=type(string)))
|
||||
return str_hash
|
||||
return self.key_map.get(str_hash) is not NULL
|
||||
|
||||
def __len__(self):
|
||||
def __iter__(self) -> str:
|
||||
"""Iterate over the strings in the store, in order.
|
||||
|
||||
YIELDS (str): A string in the store.
|
||||
"""
|
||||
for _, string in self.items():
|
||||
yield string
|
||||
|
||||
def __reduce__(self):
|
||||
strings = list(self)
|
||||
return (StringStore, (strings,), None, None, None)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""The number of strings in the store.
|
||||
|
||||
RETURNS (int): The number of strings in the store.
|
||||
"""
|
||||
return self.keys.size()
|
||||
|
||||
def __contains__(self, string_or_id not None):
|
||||
"""Check whether a string or ID is in the store.
|
||||
def add(self, string: str) -> int:
|
||||
"""Add a string to the StringStore.
|
||||
|
||||
string_or_id (str or int): The string to check.
|
||||
RETURNS (bool): Whether the store contains the string.
|
||||
string (str): The string to add.
|
||||
RETURNS (uint64): The string's hash value.
|
||||
"""
|
||||
cdef hash_t str_hash
|
||||
if isinstance(string_or_id, str):
|
||||
if len(string_or_id) == 0:
|
||||
return True
|
||||
elif string_or_id in SYMBOLS_BY_STR:
|
||||
return True
|
||||
str_hash = hash_string(string_or_id)
|
||||
elif _try_coerce_to_hash(string_or_id, &str_hash):
|
||||
pass
|
||||
if not isinstance(string, str):
|
||||
raise TypeError(Errors.E017.format(value_type=type(string)))
|
||||
|
||||
if string in SYMBOLS_BY_STR:
|
||||
return SYMBOLS_BY_STR[string]
|
||||
else:
|
||||
raise KeyError(Errors.E1400.format(key_type=type(string_or_id)))
|
||||
return self._intern_str(string)
|
||||
|
||||
if str_hash < len(SYMBOLS_BY_INT):
|
||||
return True
|
||||
def as_int(self, string_or_hash: Union[str, int]) -> str:
|
||||
"""If a hash value is passed as the input, return it as-is. If the input
|
||||
is a string, return its corresponding hash.
|
||||
|
||||
string_or_hash (str, int): The string to hash or a hash value.
|
||||
RETURNS (int): The hash of the string or the input hash value.
|
||||
"""
|
||||
if isinstance(string_or_hash, int):
|
||||
return string_or_hash
|
||||
else:
|
||||
return self._map.get(str_hash) is not NULL
|
||||
return get_string_id(string_or_hash)
|
||||
|
||||
def __iter__(self):
|
||||
"""Iterate over the strings in the store, in order.
|
||||
def as_string(self, string_or_hash: Union[str, int]) -> str:
|
||||
"""If a string is passed as the input, return it as-is. If the input
|
||||
is a hash value, return its corresponding string.
|
||||
|
||||
YIELDS (str): A string in the store.
|
||||
string_or_hash (str, int): The hash value to lookup or a string.
|
||||
RETURNS (str): The stored string or the input string.
|
||||
"""
|
||||
if isinstance(string_or_hash, str):
|
||||
return string_or_hash
|
||||
else:
|
||||
return self._retrieve_interned_str(string_or_hash)
|
||||
|
||||
def items(self) -> Tuple[int, str]:
|
||||
"""Iterate over the stored strings and their hashes in order.
|
||||
|
||||
YIELDS (int, str): A hash, string pair.
|
||||
"""
|
||||
cdef int i
|
||||
cdef hash_t key
|
||||
for i in range(self.keys.size()):
|
||||
key = self.keys[i]
|
||||
utf8str = <Utf8Str*>self._map.get(key)
|
||||
yield _decode_Utf8Str(utf8str)
|
||||
# TODO: Iterate OOV here?
|
||||
|
||||
def __reduce__(self):
|
||||
strings = list(self)
|
||||
return (StringStore, (strings,), None, None, None)
|
||||
utf8str = <Utf8Str*>self.key_map.get(key)
|
||||
yield (key, self._decode_str_repr(utf8str))
|
||||
|
||||
def to_disk(self, path):
|
||||
"""Save the current state to a directory.
|
||||
|
@ -234,35 +174,67 @@ cdef class StringStore:
|
|||
|
||||
def _reset_and_load(self, strings):
|
||||
self.mem = Pool()
|
||||
self._map = PreshMap()
|
||||
self.key_map = PreshMap()
|
||||
self.keys.clear()
|
||||
for string in strings:
|
||||
self.add(string)
|
||||
|
||||
cdef const Utf8Str* intern_unicode(self, str py_string):
|
||||
# 0 means missing, but we don't bother offsetting the index.
|
||||
cdef bytes byte_string = py_string.encode('utf-8')
|
||||
return self._intern_utf8(byte_string, len(byte_string), NULL)
|
||||
def _retrieve_interned_str(self, hash_value: int) -> str:
|
||||
cdef hash_t str_hash
|
||||
if not _try_coerce_to_hash(hash_value, &str_hash):
|
||||
raise TypeError(Errors.E4000.format(expected_types="'str','int'", received_type=type(hash_value)))
|
||||
|
||||
@cython.final
|
||||
cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length, hash_t* precalculated_hash):
|
||||
# Handle reserved symbols and empty strings correctly.
|
||||
if str_hash == 0:
|
||||
return ""
|
||||
elif str_hash < len(SYMBOLS_BY_INT):
|
||||
return SYMBOLS_BY_INT[str_hash]
|
||||
|
||||
utf8str = <Utf8Str*>self.key_map.get(str_hash)
|
||||
if utf8str is NULL:
|
||||
raise KeyError(Errors.E018.format(hash_value=str_hash))
|
||||
else:
|
||||
return self._decode_str_repr(utf8str)
|
||||
|
||||
cdef hash_t _intern_str(self, str string):
|
||||
# TODO: This function's API/behaviour is an unholy mess...
|
||||
# 0 means missing, but we don't bother offsetting the index.
|
||||
cdef hash_t key = precalculated_hash[0] if precalculated_hash is not NULL else _hash_utf8(utf8_string, length)
|
||||
cdef Utf8Str* value = <Utf8Str*>self._map.get(key)
|
||||
chars = string.encode('utf-8')
|
||||
cdef hash_t key = hash64(<unsigned char*>chars, len(chars), 1)
|
||||
cdef Utf8Str* value = <Utf8Str*>self.key_map.get(key)
|
||||
if value is not NULL:
|
||||
return value
|
||||
value = _allocate(self.mem, <unsigned char*>utf8_string, length)
|
||||
self._map.set(key, value)
|
||||
return key
|
||||
|
||||
value = self._allocate_str_repr(<unsigned char*>chars, len(chars))
|
||||
self.key_map.set(key, value)
|
||||
self.keys.push_back(key)
|
||||
return value
|
||||
return key
|
||||
|
||||
cdef Utf8Str* _allocate_str_repr(self, const unsigned char* chars, uint32_t length) except *:
|
||||
cdef int n_length_bytes
|
||||
cdef int i
|
||||
cdef Utf8Str* string = <Utf8Str*>self.mem.alloc(1, sizeof(Utf8Str))
|
||||
cdef uint32_t ulength = length
|
||||
if length < sizeof(string.s):
|
||||
string.s[0] = <unsigned char>length
|
||||
memcpy(&string.s[1], chars, length)
|
||||
return string
|
||||
elif length < 255:
|
||||
string.p = <unsigned char*>self.mem.alloc(length + 1, sizeof(unsigned char))
|
||||
string.p[0] = length
|
||||
memcpy(&string.p[1], chars, length)
|
||||
return string
|
||||
else:
|
||||
i = 0
|
||||
n_length_bytes = (length // 255) + 1
|
||||
string.p = <unsigned char*>self.mem.alloc(length + n_length_bytes, sizeof(unsigned char))
|
||||
for i in range(n_length_bytes-1):
|
||||
string.p[i] = 255
|
||||
string.p[n_length_bytes-1] = length % 255
|
||||
memcpy(&string.p[n_length_bytes], chars, length)
|
||||
return string
|
||||
|
||||
cdef hash_t _hash_utf8(char* utf8_string, int length) nogil:
|
||||
return hash64(utf8_string, length, 1)
|
||||
|
||||
|
||||
cdef str _decode_Utf8Str(const Utf8Str* string):
|
||||
cdef str _decode_str_repr(self, const Utf8Str* string):
|
||||
cdef int i, length
|
||||
if string.s[0] < sizeof(string.s) and string.s[0] != 0:
|
||||
return string.s[1:string.s[0]+1].decode('utf-8')
|
||||
|
@ -279,6 +251,39 @@ cdef str _decode_Utf8Str(const Utf8Str* string):
|
|||
return string.p[i:length + i].decode('utf-8')
|
||||
|
||||
|
||||
cpdef hash_t hash_string(object string) except -1:
|
||||
if not isinstance(string, str):
|
||||
raise TypeError(Errors.E4000.format(expected_types="'str'", received_type=type(string)))
|
||||
|
||||
# Handle reserved symbols and empty strings correctly.
|
||||
if len(string) == 0:
|
||||
return 0
|
||||
|
||||
symbol = SYMBOLS_BY_STR.get(string, None)
|
||||
if symbol is not None:
|
||||
return symbol
|
||||
|
||||
chars = string.encode('utf-8')
|
||||
return hash64(<unsigned char*>chars, len(chars), 1)
|
||||
|
||||
|
||||
cpdef hash_t get_string_id(object string_or_hash) except -1:
|
||||
cdef hash_t str_hash
|
||||
|
||||
try:
|
||||
return hash_string(string_or_hash)
|
||||
except:
|
||||
if _try_coerce_to_hash(string_or_hash, &str_hash):
|
||||
# Coerce the integral key to the expected primitive hash type.
|
||||
# This ensures that custom/overloaded "primitive" data types
|
||||
# such as those implemented by numpy are not inadvertently used
|
||||
# downsteam (as these are internally implemented as custom PyObjects
|
||||
# whose comparison operators can incur a significant overhead).
|
||||
return str_hash
|
||||
else:
|
||||
raise TypeError(Errors.E4000.format(expected_types="'str','int'", received_type=type(string_or_hash)))
|
||||
|
||||
|
||||
# Not particularly elegant, but this is faster than `isinstance(key, numbers.Integral)`
|
||||
cdef inline bint _try_coerce_to_hash(object key, hash_t* out_hash):
|
||||
try:
|
||||
|
@ -287,27 +292,3 @@ cdef inline bint _try_coerce_to_hash(object key, hash_t* out_hash):
|
|||
except:
|
||||
return False
|
||||
|
||||
|
||||
cdef Utf8Str* _allocate(Pool mem, const unsigned char* chars, uint32_t length) except *:
|
||||
cdef int n_length_bytes
|
||||
cdef int i
|
||||
cdef Utf8Str* string = <Utf8Str*>mem.alloc(1, sizeof(Utf8Str))
|
||||
cdef uint32_t ulength = length
|
||||
if length < sizeof(string.s):
|
||||
string.s[0] = <unsigned char>length
|
||||
memcpy(&string.s[1], chars, length)
|
||||
return string
|
||||
elif length < 255:
|
||||
string.p = <unsigned char*>mem.alloc(length + 1, sizeof(unsigned char))
|
||||
string.p[0] = length
|
||||
memcpy(&string.p[1], chars, length)
|
||||
return string
|
||||
else:
|
||||
i = 0
|
||||
n_length_bytes = (length // 255) + 1
|
||||
string.p = <unsigned char*>mem.alloc(length + n_length_bytes, sizeof(unsigned char))
|
||||
for i in range(n_length_bytes-1):
|
||||
string.p[i] = 255
|
||||
string.p[n_length_bytes-1] = length % 255
|
||||
memcpy(&string.p[n_length_bytes], chars, length)
|
||||
return string
|
||||
|
|
Loading…
Reference in New Issue
Block a user