spaCy/spacy/strings.pyx
Matthew Honnibal 5d0d2de955 Support 'memory zones' for user memory management
Add a context manager nlp.memory_zone(), which will begin
memory_zone() blocks on the vocab, string store, and potentially
other components.

Once the memory_zone() block expires, spaCy will free any shared
resources that were allocated for the text-processing that occurred
within the memory_zone. If you create Doc objects within a memory
zone, it's invalid to access them once the memory zone is expired.

The purpose of this is that spaCy creates and stores Lexeme objects
in the Vocab that can be shared between multiple Doc objects. It also
interns strings. Normally, spaCy can't know when all Doc objects using
a Lexeme are out-of-scope, so new Lexemes accumulate in the vocab,
causing memory pressure.

Memory zones solve this problem by telling spaCy "okay none of the
documents allocated within this block will be accessed again". This
lets spaCy free all new Lexeme objects and other data that were
created during the block.

The mechanism is general, so memory_zone() context managers can be
added to other components that could benefit from them, e.g. pipeline
components.

I experimented with adding memory zone support to the tokenizer as well,
for its cache. However, this seems unnecessarily complicated. It makes
more sense to just stick a limit on the cache size. This lets spaCy
benefit more from the cache's efficiency advantage, because
we can maintain a (bounded) cache even if only small batches of
documents are being processed.
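
A minimal usage sketch (the model name and the `texts` iterable below are
illustrative placeholders, not part of this commit):

    import spacy

    nlp = spacy.load("en_core_web_sm")
    texts = ["First document.", "Second document."]
    with nlp.memory_zone():
        for doc in nlp.pipe(texts):
            print([token.text for token in doc])  # Docs are safe to use inside the zone
    # Accessing those Doc objects here, after the zone has expired, is invalid.
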
2024-09-08 13:06:54 +02:00


# cython: infer_types=True
# cython: profile=False
from typing import Iterable, Iterator, List, Optional, Tuple, Union
from contextlib import contextmanager

from libc.stdint cimport uint32_t
from libc.string cimport memcpy
from murmurhash.mrmr cimport hash64

import srsly

from .typedefs cimport hash_t

from . import util
from .errors import Errors
from .symbols import IDS as SYMBOLS_BY_STR
from .symbols import NAMES as SYMBOLS_BY_INT
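
# Pool, PreshMap and the Utf8Str union used below are declared in the companion
# strings.pxd file, which Cython includes automatically when compiling this module.
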
cdef class StringStore:
"""Look up strings by 64-bit hashes. Implicitly handles reserved symbols.
DOCS: https://spacy.io/api/stringstore
"""
def __init__(self, strings: Optional[Iterable[str]] = None):
"""Create the StringStore.
strings (iterable): A sequence of unicode strings to add to the store.
"""
self.mem = Pool()
self._non_temp_mem = self.mem
self._map = PreshMap()
self._transient_map = None
if strings is not None:
for string in strings:
self.add(string)
def __getitem__(self, string_or_hash: Union[str, int]) -> Union[str, int]:
"""Retrieve a string from a given hash. If a string
is passed as the input, add it to the store and return
its hash.
string_or_hash (int / str): The hash value to lookup or the string to store.
RETURNS (str / int): The stored string or the hash of the newly added string.
"""
if isinstance(string_or_hash, str):
return self.add(string_or_hash)
else:
return self._get_interned_str(string_or_hash)
def __contains__(self, string_or_hash: Union[str, int]) -> bool:
"""Check whether a string or a hash is in the store.
        string_or_hash (str / int): The string/hash to check.
RETURNS (bool): Whether the store contains the string.
"""
cdef hash_t str_hash = get_string_id(string_or_hash)
if str_hash in SYMBOLS_BY_INT:
return True
else:
return self._map.get(str_hash) is not NULL
def __iter__(self) -> Iterator[str]:
"""Iterate over the strings in the store in insertion order.
RETURNS: An iterable collection of strings.
"""
return iter(self.keys())
def __reduce__(self):
strings = list(self.non_transient_keys())
return (StringStore, (strings,), None, None, None)
def __len__(self) -> int:
"""The number of strings in the store.
RETURNS (int): The number of strings in the store.
"""
return self._keys.size() + self._transient_keys.size()
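    # Typical use: `with strings.memory_zone(): ...` (the name `strings` here is
    # illustrative). Strings added with allow_transient=True inside the block go
    # into a transient map and are dropped again when the block exits.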
@contextmanager
    def memory_zone(self, mem: Optional[Pool] = None) -> Pool:
"""Begin a block where all resources allocated during the block will
        be freed at the end of it. If a resource was created within the
memory zone block, accessing it outside the block is invalid.
Behaviour of this invalid access is undefined. Memory zones should
not be nested.
The memory zone is helpful for services that need to process large
volumes of text with a defined memory budget.
"""
if mem is None:
mem = Pool()
self.mem = mem
self._transient_map = PreshMap()
yield mem
self.mem = self._non_temp_mem
self._transient_map = None
self._transient_keys.clear()
def add(self, string: str, allow_transient: bool = False) -> int:
"""Add a string to the StringStore.
string (str): The string to add.
allow_transient (bool): Allow the string to be stored in the 'transient'
map, which will be flushed at the end of the memory zone. Strings
encountered during arbitrary text processing should be added
with allow_transient=True, while labels and other strings used
internally should not.
RETURNS (uint64): The string's hash value.
"""
if not isinstance(string, str):
raise TypeError(Errors.E017.format(value_type=type(string)))
if string in SYMBOLS_BY_STR:
return SYMBOLS_BY_STR[string]
else:
return self._intern_str(string, allow_transient)
    def as_int(self, string_or_hash: Union[str, int]) -> int:
"""If a hash value is passed as the input, return it as-is. If the input
is a string, return its corresponding hash.
string_or_hash (str / int): The string to hash or a hash value.
RETURNS (int): The hash of the string or the input hash value.
"""
if isinstance(string_or_hash, int):
return string_or_hash
else:
return get_string_id(string_or_hash)
def as_string(self, string_or_hash: Union[str, int]) -> str:
"""If a string is passed as the input, return it as-is. If the input
is a hash value, return its corresponding string.
string_or_hash (str / int): The hash value to lookup or a string.
RETURNS (str): The stored string or the input string.
"""
if isinstance(string_or_hash, str):
return string_or_hash
else:
return self._get_interned_str(string_or_hash)
def items(self) -> List[Tuple[str, int]]:
"""Iterate over the stored strings and their hashes in insertion order.
RETURNS: A list of string-hash pairs.
"""
# Even though we internally store the hashes as keys and the strings as
# values, we invert the order in the public API to keep it consistent with
# the implementation of the `__iter__` method (where we wish to iterate over
# the strings in the store).
cdef int i
pairs = [None] * self._keys.size()
for i in range(self._keys.size()):
str_hash = self._keys[i]
utf8str = <Utf8Str*>self._map.get(str_hash)
pairs[i] = (self._decode_str_repr(utf8str), str_hash)
return pairs
def keys(self) -> List[str]:
"""Iterate over the stored strings in insertion order.
RETURNS: A list of strings.
"""
return self.non_transient_keys() + self.transient_keys()
def non_transient_keys(self) -> List[str]:
"""Iterate over the stored strings in insertion order.
RETURNS: A list of strings.
"""
cdef int i
strings = [None] * self._keys.size()
for i in range(self._keys.size()):
utf8str = <Utf8Str*>self._map.get(self._keys[i])
strings[i] = self._decode_str_repr(utf8str)
return strings
def transient_keys(self) -> List[str]:
if self._transient_map is None:
return []
transient_strings = [None] * self._transient_keys.size()
for i in range(self._transient_keys.size()):
utf8str = <Utf8Str*>self._transient_map.get(self._transient_keys[i])
transient_strings[i] = self._decode_str_repr(utf8str)
return transient_strings
def values(self) -> List[int]:
"""Iterate over the stored strings hashes in insertion order.
RETURNS: A list of string hashs.
"""
cdef int i
hashes = [None] * self._keys.size()
for i in range(self._keys.size()):
hashes[i] = self._keys[i]
if self._transient_map is not None:
transient_hashes = [None] * self._transient_keys.size()
for i in range(self._transient_keys.size()):
transient_hashes[i] = self._transient_keys[i]
else:
transient_hashes = []
return hashes + transient_hashes
def to_disk(self, path):
"""Save the current state to a directory.
path (str / Path): A path to a directory, which will be created if
it doesn't exist. Paths may be either strings or Path-like objects.
"""
path = util.ensure_path(path)
strings = sorted(self)
srsly.write_json(path, strings)
def from_disk(self, path):
"""Loads state from a directory. Modifies the object in place and
returns it.
path (str / Path): A path to a directory. Paths may be either
strings or `Path`-like objects.
RETURNS (StringStore): The modified `StringStore` object.
"""
path = util.ensure_path(path)
strings = srsly.read_json(path)
prev = list(self)
self._reset_and_load(strings)
for word in prev:
self.add(word, allow_transient=False)
return self
def to_bytes(self, **kwargs):
"""Serialize the current state to a binary string.
RETURNS (bytes): The serialized form of the `StringStore` object.
"""
return srsly.json_dumps(sorted(self))
def from_bytes(self, bytes_data, **kwargs):
"""Load state from a binary string.
bytes_data (bytes): The data to load from.
RETURNS (StringStore): The `StringStore` object.
"""
strings = srsly.json_loads(bytes_data)
prev = list(self)
self._reset_and_load(strings)
for word in prev:
self.add(word, allow_transient=False)
return self
def _reset_and_load(self, strings):
self.mem = Pool()
self._map = PreshMap()
self._keys.clear()
for string in strings:
self.add(string, allow_transient=False)
def _get_interned_str(self, hash_value: int) -> str:
cdef hash_t str_hash
if not _try_coerce_to_hash(hash_value, &str_hash):
raise TypeError(Errors.E4001.format(expected_types="'int'", received_type=type(hash_value)))
# Handle reserved symbols and empty strings correctly.
if str_hash == 0:
return ""
symbol = SYMBOLS_BY_INT.get(str_hash)
if symbol is not None:
return symbol
utf8str = <Utf8Str*>self._map.get(str_hash)
if utf8str is NULL and self._transient_map is not None:
utf8str = <Utf8Str*>self._transient_map.get(str_hash)
if utf8str is NULL:
raise KeyError(Errors.E018.format(hash_value=str_hash))
else:
return self._decode_str_repr(utf8str)
cdef hash_t _intern_str(self, str string, bint transient):
# TODO: This function's API/behaviour is an unholy mess...
# 0 means missing, but we don't bother offsetting the index.
chars = string.encode('utf-8')
cdef hash_t key = hash64(<unsigned char*>chars, len(chars), 1)
cdef Utf8Str* value = <Utf8Str*>self._map.get(key)
if value is not NULL:
return key
if transient and self._transient_map is not None:
# If we've already allocated a transient string, and now we
# want to intern it permanently, we'll end up with the string
# in both places. That seems fine -- I don't see why we need
# to remove it from the transient map.
value = <Utf8Str*>self._transient_map.get(key)
if value is not NULL:
return key
value = self._allocate_str_repr(<unsigned char*>chars, len(chars))
if transient and self._transient_map is not None:
self._transient_map.set(key, value)
self._transient_keys.push_back(key)
else:
self._map.set(key, value)
self._keys.push_back(key)
return key
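    # Utf8Str (declared in strings.pxd) is a small union: very short strings are
    # stored inline in the fixed buffer `s` behind a one-byte length prefix;
    # strings under 255 bytes are heap-allocated via `p` with a single length
    # byte; longer strings prefix the data with a run of 255-valued bytes plus
    # the remainder, which _decode_str_repr reads back to recover the length.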
cdef Utf8Str* _allocate_str_repr(self, const unsigned char* chars, uint32_t length) except *:
cdef int n_length_bytes
cdef int i
cdef Utf8Str* string = <Utf8Str*>self.mem.alloc(1, sizeof(Utf8Str))
if length < sizeof(string.s):
string.s[0] = <unsigned char>length
memcpy(&string.s[1], chars, length)
return string
elif length < 255:
string.p = <unsigned char*>self.mem.alloc(length + 1, sizeof(unsigned char))
string.p[0] = length
memcpy(&string.p[1], chars, length)
return string
else:
i = 0
n_length_bytes = (length // 255) + 1
string.p = <unsigned char*>self.mem.alloc(length + n_length_bytes, sizeof(unsigned char))
for i in range(n_length_bytes-1):
string.p[i] = 255
string.p[n_length_bytes-1] = length % 255
memcpy(&string.p[n_length_bytes], chars, length)
return string
cdef str _decode_str_repr(self, const Utf8Str* string):
cdef int i, length
if string.s[0] < sizeof(string.s) and string.s[0] != 0:
return string.s[1:string.s[0]+1].decode('utf-8')
elif string.p[0] < 255:
return string.p[1:string.p[0]+1].decode('utf-8')
else:
i = 0
length = 0
while string.p[i] == 255:
i += 1
length += 255
length += string.p[i]
i += 1
return string.p[i:length + i].decode('utf-8')
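
# Module-level helpers. They mirror the StringStore's special cases: the empty
# string hashes to 0 and reserved symbols map to their fixed IDs from .symbols,
# so hash values stay consistent with what the store itself returns.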
cpdef hash_t hash_string(object string) except -1:
if not isinstance(string, str):
raise TypeError(Errors.E4001.format(expected_types="'str'", received_type=type(string)))
# Handle reserved symbols and empty strings correctly.
if len(string) == 0:
return 0
symbol = SYMBOLS_BY_STR.get(string)
if symbol is not None:
return symbol
chars = string.encode('utf-8')
return hash64(<unsigned char*>chars, len(chars), 1)
cpdef hash_t get_string_id(object string_or_hash) except -1:
cdef hash_t str_hash
try:
return hash_string(string_or_hash)
except: # no-cython-lint
if _try_coerce_to_hash(string_or_hash, &str_hash):
# Coerce the integral key to the expected primitive hash type.
# This ensures that custom/overloaded "primitive" data types
# such as those implemented by numpy are not inadvertently used
            # downstream (as these are internally implemented as custom PyObjects
# whose comparison operators can incur a significant overhead).
return str_hash
else:
raise TypeError(Errors.E4001.format(expected_types="'str','int'", received_type=type(string_or_hash)))
# Not particularly elegant, but this is faster than `isinstance(key, numbers.Integral)`
cdef inline bint _try_coerce_to_hash(object key, hash_t* out_hash):
try:
out_hash[0] = key
return True
except: # no-cython-lint
return False