Fix memory zones

This commit is contained in:
Matthew Honnibal 2024-09-09 13:49:41 +02:00
parent 59ac7e6bdb
commit a019315534
9 changed files with 76 additions and 71 deletions

View File

@ -1,5 +1,5 @@
from .lex_attrs import LEX_ATTRS
from ...language import BaseDefaults, Language
from .lex_attrs import LEX_ATTRS
from .stop_words import STOP_WORDS

View File

@ -1,6 +1,5 @@
from ...attrs import LIKE_NUM
_num_words = [
"sifir",
"yek",

View File

@ -5,7 +5,7 @@ import multiprocessing as mp
import random
import traceback
import warnings
from contextlib import contextmanager
from contextlib import ExitStack, contextmanager
from copy import deepcopy
from dataclasses import dataclass
from itertools import chain, cycle
@ -31,6 +31,7 @@ from typing import (
)
import srsly
from cymem.cymem import Pool
from thinc.api import Config, CupyOps, Optimizer, get_current_ops
from . import about, ty, util
@ -2091,6 +2092,38 @@ class Language:
util.replace_model_node(pipe.model, listener, new_model) # type: ignore[attr-defined]
tok2vec.remove_listener(listener, pipe_name)
@contextmanager
def memory_zone(self, mem: Optional["Pool"] = None) -> Iterator["Pool"]:
    """Begin a block where all resources allocated during the block will
    be freed at the end of it. If a resource was created within the
    memory zone block, accessing it outside the block is invalid.
    Behaviour of this invalid access is undefined. Memory zones should
    not be nested.

    The memory zone is helpful for services that need to process large
    volumes of text with a defined memory budget.

    mem (Optional[Pool]): An existing pool to allocate into. If None,
        a fresh Pool is created for the duration of the zone.
    YIELDS (Pool): The memory pool that the zone allocates into.

    Example
    -------
    >>> with nlp.memory_zone():
    ...     for doc in nlp.pipe(texts):
    ...         process_my_doc(doc)
    >>> # use_doc(doc) <-- Invalid: doc was allocated in the memory zone
    """
    if mem is None:
        mem = Pool()
    # The ExitStack allows programmatic nested context managers.
    # We don't know how many we need, so it would be awkward to have
    # them as nested blocks.
    with ExitStack() as stack:
        # Vocab is entered first, so the stack exits it last.
        stack.enter_context(self.vocab.memory_zone(mem))
        # The tokenizer and pipeline components opt in by exposing a
        # memory_zone() context manager of their own.
        if hasattr(self.tokenizer, "memory_zone"):
            stack.enter_context(self.tokenizer.memory_zone(mem))
        for _, pipe in self.pipeline:
            if hasattr(pipe, "memory_zone"):
                stack.enter_context(pipe.memory_zone(mem))
        yield mem
def to_disk(
self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
) -> None:

View File

@ -203,7 +203,7 @@ cdef class ArcEagerGold:
def __init__(self, ArcEager moves, StateClass stcls, Example example):
self.mem = Pool()
heads, labels = example.get_aligned_parse(projectivize=True)
labels = [example.x.vocab.strings.add(label) if label is not None else MISSING_DEP for label in labels]
labels = [example.x.vocab.strings.add(label, allow_transient=False) if label is not None else MISSING_DEP for label in labels]
sent_starts = _get_aligned_sent_starts(example)
assert len(heads) == len(labels) == len(sent_starts), (len(heads), len(labels), len(sent_starts))
self.c = create_gold_state(self.mem, stcls.c, heads, labels, sent_starts)

View File

@ -183,7 +183,7 @@ cpdef deprojectivize(Doc doc):
new_label, head_label = label.split(DELIMITER)
new_head = _find_new_head(doc[i], head_label)
doc.c[i].head = new_head.i - i
doc.c[i].dep = doc.vocab.strings.add(new_label)
doc.c[i].dep = doc.vocab.strings.add(new_label, allow_transient=False)
set_children_from_heads(doc.c, 0, doc.length)
return doc

View File

@ -28,5 +28,4 @@ cdef class StringStore:
cdef const Utf8Str* intern_unicode(self, str py_string, bint allow_transient)
cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length, hash_t* precalculated_hash, bint allow_transient)
cdef vector[hash_t] _transient_keys
cdef PreshMap _transient_map
cdef Pool _non_temp_mem

View File

@ -8,6 +8,7 @@ from typing import Iterator, List, Optional
from libc.stdint cimport uint32_t
from libc.string cimport memcpy
from murmurhash.mrmr cimport hash32, hash64
from preshed.maps cimport map_clear
import srsly
@ -125,10 +126,9 @@ cdef class StringStore:
self.mem = Pool()
self._non_temp_mem = self.mem
self._map = PreshMap()
self._transient_map = None
if strings is not None:
for string in strings:
self.add(string)
self.add(string, allow_transient=False)
def __getitem__(self, object string_or_id):
"""Retrieve a string from a given hash, or vice versa.
@ -158,17 +158,17 @@ cdef class StringStore:
return SYMBOLS_BY_INT[str_hash]
else:
utf8str = <Utf8Str*>self._map.get(str_hash)
if utf8str is NULL and self._transient_map is not None:
utf8str = <Utf8Str*>self._transient_map.get(str_hash)
if utf8str is NULL:
raise KeyError(Errors.E018.format(hash_value=string_or_id))
else:
return decode_Utf8Str(utf8str)
else:
# TODO: Raise an error instead
utf8str = <Utf8Str*>self._map.get(string_or_id)
if utf8str is NULL and self._transient_map is not None:
utf8str = <Utf8Str*>self._transient_map.get(str_hash)
if utf8str is NULL:
raise KeyError(Errors.E018.format(hash_value=string_or_id))
else:
return decode_Utf8Str(utf8str)
if utf8str is NULL:
raise KeyError(Errors.E018.format(hash_value=string_or_id))
else:
return decode_Utf8Str(utf8str)
def as_int(self, key):
"""If key is an int, return it; otherwise, get the int value."""
@ -184,16 +184,12 @@ cdef class StringStore:
else:
return self[key]
def __reduce__(self):
strings = list(self.non_transient_keys())
return (StringStore, (strings,), None, None, None)
def __len__(self) -> int:
"""The number of strings in the store.
RETURNS (int): The number of strings in the store.
"""
return self._keys.size() + self._transient_keys.size()
return self.keys.size() + self._transient_keys.size()
@contextmanager
def memory_zone(self, mem: Optional[Pool] = None) -> Pool:
@ -209,13 +205,13 @@ cdef class StringStore:
if mem is None:
mem = Pool()
self.mem = mem
self._transient_map = PreshMap()
yield mem
self.mem = self._non_temp_mem
self._transient_map = None
for key in self._transient_keys:
map_clear(self._map.c_map, key)
self._transient_keys.clear()
self.mem = self._non_temp_mem
def add(self, string: str, allow_transient: bool = False) -> int:
def add(self, string: str, allow_transient: Optional[bool] = None) -> int:
"""Add a string to the StringStore.
string (str): The string to add.
@ -226,6 +222,8 @@ cdef class StringStore:
internally should not.
RETURNS (uint64): The string's hash value.
"""
if allow_transient is None:
allow_transient = self.mem is not self._non_temp_mem
cdef hash_t str_hash
if isinstance(string, str):
if string in SYMBOLS_BY_STR:
@ -273,8 +271,6 @@ cdef class StringStore:
# TODO: Raise an error instead
if self._map.get(string_or_id) is not NULL:
return True
elif self._transient_map is not None and self._transient_map.get(string_or_id) is not NULL:
return True
else:
return False
if str_hash < len(SYMBOLS_BY_INT):
@ -282,8 +278,6 @@ cdef class StringStore:
else:
if self._map.get(str_hash) is not NULL:
return True
elif self._transient_map is not None and self._transient_map.get(string_or_id) is not NULL:
return True
else:
return False
@ -292,32 +286,21 @@ cdef class StringStore:
YIELDS (str): A string in the store.
"""
yield from self.non_transient_keys()
yield from self.transient_keys()
def non_transient_keys(self) -> Iterator[str]:
"""Iterate over the stored strings in insertion order.
RETURNS: A list of strings.
"""
cdef int i
cdef hash_t key
for i in range(self.keys.size()):
key = self.keys[i]
utf8str = <Utf8Str*>self._map.get(key)
yield decode_Utf8Str(utf8str)
for i in range(self._transient_keys.size()):
key = self._transient_keys[i]
utf8str = <Utf8Str*>self._map.get(key)
yield decode_Utf8Str(utf8str)
def __reduce__(self):
    # Pickle protocol: serialise as (callable, args) so that unpickling
    # calls StringStore(strings), re-interning every string currently
    # visible through iteration. The trailing Nones fill the unused
    # state/listitems/dictitems slots of the __reduce__ tuple.
    strings = list(self)
    return (StringStore, (strings,), None, None, None)
def transient_keys(self) -> Iterator[str]:
if self._transient_map is None:
return []
for i in range(self._transient_keys.size()):
utf8str = <Utf8Str*>self._transient_map.get(self._transient_keys[i])
yield decode_Utf8Str(utf8str)
def values(self) -> List[int]:
"""Iterate over the stored strings hashes in insertion order.
@ -327,12 +310,9 @@ cdef class StringStore:
hashes = [None] * self._keys.size()
for i in range(self._keys.size()):
hashes[i] = self._keys[i]
if self._transient_map is not None:
transient_hashes = [None] * self._transient_keys.size()
for i in range(self._transient_keys.size()):
transient_hashes[i] = self._transient_keys[i]
else:
transient_hashes = []
transient_hashes = [None] * self._transient_keys.size()
for i in range(self._transient_keys.size()):
transient_hashes[i] = self._transient_keys[i]
return hashes + transient_hashes
def to_disk(self, path):
@ -383,8 +363,10 @@ cdef class StringStore:
def _reset_and_load(self, strings):
self.mem = Pool()
self._non_temp_mem = self.mem
self._map = PreshMap()
self.keys.clear()
self._transient_keys.clear()
for string in strings:
self.add(string, allow_transient=False)
@ -401,19 +383,10 @@ cdef class StringStore:
cdef Utf8Str* value = <Utf8Str*>self._map.get(key)
if value is not NULL:
return value
if allow_transient and self._transient_map is not None:
# If we've already allocated a transient string, and now we
# want to intern it permanently, we'll end up with the string
# in both places. That seems fine -- I don't see why we need
# to remove it from the transient map.
value = <Utf8Str*>self._transient_map.get(key)
if value is not NULL:
return value
value = _allocate(self.mem, <unsigned char*>utf8_string, length)
if allow_transient and self._transient_map is not None:
self._transient_map.set(key, value)
self._map.set(key, value)
if allow_transient and self.mem is not self._non_temp_mem:
self._transient_keys.push_back(key)
else:
self._map.set(key, value)
self.keys.push_back(key)
return value

View File

@ -517,12 +517,8 @@ cdef class Tokenizer:
if n <= 0:
# avoid mem alloc of zero length
return 0
# Historically this check was mostly used to avoid caching
# chunks that had tokens owned by the Doc. Now that that's
# not a thing, I don't think we need this?
for i in range(n):
if self.vocab._by_orth.get(tokens[i].lex.orth) == NULL:
return 0
if self.vocab.in_memory_zone:
return 0
# See #1250
if has_special[0]:
return 0

View File

@ -5,6 +5,7 @@ from typing import Iterator, Optional
import numpy
import srsly
from thinc.api import get_array_module, get_current_ops
from preshed.maps cimport map_clear
from .attrs cimport LANG, ORTH
from .lexeme cimport EMPTY_LEXEME, OOV_RANK, Lexeme
@ -104,7 +105,7 @@ cdef class Vocab:
def vectors(self, vectors):
if hasattr(vectors, "strings"):
for s in vectors.strings:
self.strings.add(s)
self.strings.add(s, allow_transient=False)
self._vectors = vectors
self._vectors.strings = self.strings
@ -115,6 +116,10 @@ cdef class Vocab:
langfunc = self.lex_attr_getters.get(LANG, None)
return langfunc("_") if langfunc else ""
@property
def in_memory_zone(self) -> bool:
    """Whether a transient memory zone is currently active.

    Uses pool identity, not equality: memory_zone() swaps self.mem for
    a temporary Pool and restores the permanent one (_non_temp_mem) on
    exit, so `self.mem is not self._non_temp_mem` holds exactly while a
    zone is open.
    """
    return self.mem is not self._non_temp_mem
def __len__(self):
"""The current number of lexemes stored.
@ -218,7 +223,7 @@ cdef class Vocab:
# this size heuristic.
mem = self.mem
lex = <LexemeC*>mem.alloc(1, sizeof(LexemeC))
lex.orth = self.strings.add(string)
lex.orth = self.strings.add(string, allow_transient=True)
lex.length = len(string)
if self.vectors is not None and hasattr(self.vectors, "key2row"):
lex.id = self.vectors.key2row.get(lex.orth, OOV_RANK)
@ -239,13 +244,13 @@ cdef class Vocab:
cdef int _add_lex_to_vocab(self, hash_t key, const LexemeC* lex, bint is_transient) except -1:
self._by_orth.set(lex.orth, <void*>lex)
self.length += 1
if is_transient:
if is_transient and self.in_memory_zone:
self._transient_orths.push_back(lex.orth)
def _clear_transient_orths(self):
"""Remove transient lexemes from the index (generally at the end of the memory zone)"""
for orth in self._transient_orths:
self._by_orth.pop(orth)
map_clear(self._by_orth.c_map, orth)
self._transient_orths.clear()
def __contains__(self, key):