Fix memory zones

This commit is contained in:
Matthew Honnibal 2024-09-09 13:49:41 +02:00
parent 59ac7e6bdb
commit a019315534
9 changed files with 76 additions and 71 deletions

View File

@ -1,5 +1,5 @@
from .lex_attrs import LEX_ATTRS
from ...language import BaseDefaults, Language from ...language import BaseDefaults, Language
from .lex_attrs import LEX_ATTRS
from .stop_words import STOP_WORDS from .stop_words import STOP_WORDS

View File

@ -1,6 +1,5 @@
from ...attrs import LIKE_NUM from ...attrs import LIKE_NUM
_num_words = [ _num_words = [
"sifir", "sifir",
"yek", "yek",

View File

@ -5,7 +5,7 @@ import multiprocessing as mp
import random import random
import traceback import traceback
import warnings import warnings
from contextlib import contextmanager from contextlib import ExitStack, contextmanager
from copy import deepcopy from copy import deepcopy
from dataclasses import dataclass from dataclasses import dataclass
from itertools import chain, cycle from itertools import chain, cycle
@ -31,6 +31,7 @@ from typing import (
) )
import srsly import srsly
from cymem.cymem import Pool
from thinc.api import Config, CupyOps, Optimizer, get_current_ops from thinc.api import Config, CupyOps, Optimizer, get_current_ops
from . import about, ty, util from . import about, ty, util
@ -2091,6 +2092,38 @@ class Language:
util.replace_model_node(pipe.model, listener, new_model) # type: ignore[attr-defined] util.replace_model_node(pipe.model, listener, new_model) # type: ignore[attr-defined]
tok2vec.remove_listener(listener, pipe_name) tok2vec.remove_listener(listener, pipe_name)
@contextmanager
def memory_zone(self, mem: Optional[Pool] = None) -> Iterator[Pool]:
    """Begin a block where all resources allocated during the block will
    be freed at the end of it. If a resource was created within the
    memory zone block, accessing it outside the block is invalid.
    Behaviour of this invalid access is undefined. Memory zones should
    not be nested.

    The memory zone is helpful for services that need to process large
    volumes of text with a defined memory budget.

    mem (Optional[Pool]): Memory pool to allocate transient resources
        into. A fresh Pool is created if none is provided.
    YIELDS (Pool): The memory pool used for the zone.

    Example
    -------
    >>> with nlp.memory_zone():
    ...     for doc in nlp.pipe(texts):
    ...         process_my_doc(doc)
    >>> # use_doc(doc) <-- Invalid: doc was allocated in the memory zone
    """
    if mem is None:
        mem = Pool()
    # The ExitStack allows programmatic nested context managers.
    # We don't know how many we need, so it would be awkward to have
    # them as nested blocks. The stack owns the entered contexts and
    # unwinds them (in reverse order) when the `with` block exits.
    with ExitStack() as stack:
        # The vocab always participates in the zone; the tokenizer and
        # pipes opt in by exposing a `memory_zone` context manager.
        stack.enter_context(self.vocab.memory_zone(mem))
        if hasattr(self.tokenizer, "memory_zone"):
            stack.enter_context(self.tokenizer.memory_zone(mem))
        for _, pipe in self.pipeline:
            if hasattr(pipe, "memory_zone"):
                stack.enter_context(pipe.memory_zone(mem))
        yield mem
def to_disk( def to_disk(
self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList() self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
) -> None: ) -> None:

View File

@ -203,7 +203,7 @@ cdef class ArcEagerGold:
def __init__(self, ArcEager moves, StateClass stcls, Example example): def __init__(self, ArcEager moves, StateClass stcls, Example example):
self.mem = Pool() self.mem = Pool()
heads, labels = example.get_aligned_parse(projectivize=True) heads, labels = example.get_aligned_parse(projectivize=True)
labels = [example.x.vocab.strings.add(label) if label is not None else MISSING_DEP for label in labels] labels = [example.x.vocab.strings.add(label, allow_transient=False) if label is not None else MISSING_DEP for label in labels]
sent_starts = _get_aligned_sent_starts(example) sent_starts = _get_aligned_sent_starts(example)
assert len(heads) == len(labels) == len(sent_starts), (len(heads), len(labels), len(sent_starts)) assert len(heads) == len(labels) == len(sent_starts), (len(heads), len(labels), len(sent_starts))
self.c = create_gold_state(self.mem, stcls.c, heads, labels, sent_starts) self.c = create_gold_state(self.mem, stcls.c, heads, labels, sent_starts)

View File

@ -183,7 +183,7 @@ cpdef deprojectivize(Doc doc):
new_label, head_label = label.split(DELIMITER) new_label, head_label = label.split(DELIMITER)
new_head = _find_new_head(doc[i], head_label) new_head = _find_new_head(doc[i], head_label)
doc.c[i].head = new_head.i - i doc.c[i].head = new_head.i - i
doc.c[i].dep = doc.vocab.strings.add(new_label) doc.c[i].dep = doc.vocab.strings.add(new_label, allow_transient=False)
set_children_from_heads(doc.c, 0, doc.length) set_children_from_heads(doc.c, 0, doc.length)
return doc return doc

View File

@ -28,5 +28,4 @@ cdef class StringStore:
cdef const Utf8Str* intern_unicode(self, str py_string, bint allow_transient) cdef const Utf8Str* intern_unicode(self, str py_string, bint allow_transient)
cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length, hash_t* precalculated_hash, bint allow_transient) cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length, hash_t* precalculated_hash, bint allow_transient)
cdef vector[hash_t] _transient_keys cdef vector[hash_t] _transient_keys
cdef PreshMap _transient_map
cdef Pool _non_temp_mem cdef Pool _non_temp_mem

View File

@ -8,6 +8,7 @@ from typing import Iterator, List, Optional
from libc.stdint cimport uint32_t from libc.stdint cimport uint32_t
from libc.string cimport memcpy from libc.string cimport memcpy
from murmurhash.mrmr cimport hash32, hash64 from murmurhash.mrmr cimport hash32, hash64
from preshed.maps cimport map_clear
import srsly import srsly
@ -125,10 +126,9 @@ cdef class StringStore:
self.mem = Pool() self.mem = Pool()
self._non_temp_mem = self.mem self._non_temp_mem = self.mem
self._map = PreshMap() self._map = PreshMap()
self._transient_map = None
if strings is not None: if strings is not None:
for string in strings: for string in strings:
self.add(string) self.add(string, allow_transient=False)
def __getitem__(self, object string_or_id): def __getitem__(self, object string_or_id):
"""Retrieve a string from a given hash, or vice versa. """Retrieve a string from a given hash, or vice versa.
@ -158,17 +158,17 @@ cdef class StringStore:
return SYMBOLS_BY_INT[str_hash] return SYMBOLS_BY_INT[str_hash]
else: else:
utf8str = <Utf8Str*>self._map.get(str_hash) utf8str = <Utf8Str*>self._map.get(str_hash)
if utf8str is NULL and self._transient_map is not None: if utf8str is NULL:
utf8str = <Utf8Str*>self._transient_map.get(str_hash) raise KeyError(Errors.E018.format(hash_value=string_or_id))
else:
return decode_Utf8Str(utf8str)
else: else:
# TODO: Raise an error instead # TODO: Raise an error instead
utf8str = <Utf8Str*>self._map.get(string_or_id) utf8str = <Utf8Str*>self._map.get(string_or_id)
if utf8str is NULL and self._transient_map is not None: if utf8str is NULL:
utf8str = <Utf8Str*>self._transient_map.get(str_hash) raise KeyError(Errors.E018.format(hash_value=string_or_id))
if utf8str is NULL: else:
raise KeyError(Errors.E018.format(hash_value=string_or_id)) return decode_Utf8Str(utf8str)
else:
return decode_Utf8Str(utf8str)
def as_int(self, key): def as_int(self, key):
"""If key is an int, return it; otherwise, get the int value.""" """If key is an int, return it; otherwise, get the int value."""
@ -184,16 +184,12 @@ cdef class StringStore:
else: else:
return self[key] return self[key]
def __reduce__(self):
strings = list(self.non_transient_keys())
return (StringStore, (strings,), None, None, None)
def __len__(self) -> int: def __len__(self) -> int:
"""The number of strings in the store. """The number of strings in the store.
RETURNS (int): The number of strings in the store. RETURNS (int): The number of strings in the store.
""" """
return self._keys.size() + self._transient_keys.size() return self.keys.size() + self._transient_keys.size()
@contextmanager @contextmanager
def memory_zone(self, mem: Optional[Pool] = None) -> Pool: def memory_zone(self, mem: Optional[Pool] = None) -> Pool:
@ -209,13 +205,13 @@ cdef class StringStore:
if mem is None: if mem is None:
mem = Pool() mem = Pool()
self.mem = mem self.mem = mem
self._transient_map = PreshMap()
yield mem yield mem
self.mem = self._non_temp_mem for key in self._transient_keys:
self._transient_map = None map_clear(self._map.c_map, key)
self._transient_keys.clear() self._transient_keys.clear()
self.mem = self._non_temp_mem
def add(self, string: str, allow_transient: bool = False) -> int: def add(self, string: str, allow_transient: Optional[bool] = None) -> int:
"""Add a string to the StringStore. """Add a string to the StringStore.
string (str): The string to add. string (str): The string to add.
@ -226,6 +222,8 @@ cdef class StringStore:
internally should not. internally should not.
RETURNS (uint64): The string's hash value. RETURNS (uint64): The string's hash value.
""" """
if allow_transient is None:
allow_transient = self.mem is not self._non_temp_mem
cdef hash_t str_hash cdef hash_t str_hash
if isinstance(string, str): if isinstance(string, str):
if string in SYMBOLS_BY_STR: if string in SYMBOLS_BY_STR:
@ -273,8 +271,6 @@ cdef class StringStore:
# TODO: Raise an error instead # TODO: Raise an error instead
if self._map.get(string_or_id) is not NULL: if self._map.get(string_or_id) is not NULL:
return True return True
elif self._transient_map is not None and self._transient_map.get(string_or_id) is not NULL:
return True
else: else:
return False return False
if str_hash < len(SYMBOLS_BY_INT): if str_hash < len(SYMBOLS_BY_INT):
@ -282,8 +278,6 @@ cdef class StringStore:
else: else:
if self._map.get(str_hash) is not NULL: if self._map.get(str_hash) is not NULL:
return True return True
elif self._transient_map is not None and self._transient_map.get(string_or_id) is not NULL:
return True
else: else:
return False return False
@ -292,32 +286,21 @@ cdef class StringStore:
YIELDS (str): A string in the store. YIELDS (str): A string in the store.
""" """
yield from self.non_transient_keys()
yield from self.transient_keys()
def non_transient_keys(self) -> Iterator[str]:
"""Iterate over the stored strings in insertion order.
RETURNS: A list of strings.
"""
cdef int i cdef int i
cdef hash_t key cdef hash_t key
for i in range(self.keys.size()): for i in range(self.keys.size()):
key = self.keys[i] key = self.keys[i]
utf8str = <Utf8Str*>self._map.get(key) utf8str = <Utf8Str*>self._map.get(key)
yield decode_Utf8Str(utf8str) yield decode_Utf8Str(utf8str)
for i in range(self._transient_keys.size()):
key = self._transient_keys[i]
utf8str = <Utf8Str*>self._map.get(key)
yield decode_Utf8Str(utf8str)
def __reduce__(self): def __reduce__(self):
strings = list(self) strings = list(self)
return (StringStore, (strings,), None, None, None) return (StringStore, (strings,), None, None, None)
def transient_keys(self) -> Iterator[str]:
if self._transient_map is None:
return []
for i in range(self._transient_keys.size()):
utf8str = <Utf8Str*>self._transient_map.get(self._transient_keys[i])
yield decode_Utf8Str(utf8str)
def values(self) -> List[int]: def values(self) -> List[int]:
"""Iterate over the stored strings hashes in insertion order. """Iterate over the stored strings hashes in insertion order.
@ -327,12 +310,9 @@ cdef class StringStore:
hashes = [None] * self._keys.size() hashes = [None] * self._keys.size()
for i in range(self._keys.size()): for i in range(self._keys.size()):
hashes[i] = self._keys[i] hashes[i] = self._keys[i]
if self._transient_map is not None: transient_hashes = [None] * self._transient_keys.size()
transient_hashes = [None] * self._transient_keys.size() for i in range(self._transient_keys.size()):
for i in range(self._transient_keys.size()): transient_hashes[i] = self._transient_keys[i]
transient_hashes[i] = self._transient_keys[i]
else:
transient_hashes = []
return hashes + transient_hashes return hashes + transient_hashes
def to_disk(self, path): def to_disk(self, path):
@ -383,8 +363,10 @@ cdef class StringStore:
def _reset_and_load(self, strings): def _reset_and_load(self, strings):
self.mem = Pool() self.mem = Pool()
self._non_temp_mem = self.mem
self._map = PreshMap() self._map = PreshMap()
self.keys.clear() self.keys.clear()
self._transient_keys.clear()
for string in strings: for string in strings:
self.add(string, allow_transient=False) self.add(string, allow_transient=False)
@ -401,19 +383,10 @@ cdef class StringStore:
cdef Utf8Str* value = <Utf8Str*>self._map.get(key) cdef Utf8Str* value = <Utf8Str*>self._map.get(key)
if value is not NULL: if value is not NULL:
return value return value
if allow_transient and self._transient_map is not None:
# If we've already allocated a transient string, and now we
# want to intern it permanently, we'll end up with the string
# in both places. That seems fine -- I don't see why we need
# to remove it from the transient map.
value = <Utf8Str*>self._transient_map.get(key)
if value is not NULL:
return value
value = _allocate(self.mem, <unsigned char*>utf8_string, length) value = _allocate(self.mem, <unsigned char*>utf8_string, length)
if allow_transient and self._transient_map is not None: self._map.set(key, value)
self._transient_map.set(key, value) if allow_transient and self.mem is not self._non_temp_mem:
self._transient_keys.push_back(key) self._transient_keys.push_back(key)
else: else:
self._map.set(key, value)
self.keys.push_back(key) self.keys.push_back(key)
return value return value

View File

@ -517,12 +517,8 @@ cdef class Tokenizer:
if n <= 0: if n <= 0:
# avoid mem alloc of zero length # avoid mem alloc of zero length
return 0 return 0
# Historically this check was mostly used to avoid caching if self.vocab.in_memory_zone:
# chunks that had tokens owned by the Doc. Now that that's return 0
# not a thing, I don't think we need this?
for i in range(n):
if self.vocab._by_orth.get(tokens[i].lex.orth) == NULL:
return 0
# See #1250 # See #1250
if has_special[0]: if has_special[0]:
return 0 return 0

View File

@ -5,6 +5,7 @@ from typing import Iterator, Optional
import numpy import numpy
import srsly import srsly
from thinc.api import get_array_module, get_current_ops from thinc.api import get_array_module, get_current_ops
from preshed.maps cimport map_clear
from .attrs cimport LANG, ORTH from .attrs cimport LANG, ORTH
from .lexeme cimport EMPTY_LEXEME, OOV_RANK, Lexeme from .lexeme cimport EMPTY_LEXEME, OOV_RANK, Lexeme
@ -104,7 +105,7 @@ cdef class Vocab:
def vectors(self, vectors): def vectors(self, vectors):
if hasattr(vectors, "strings"): if hasattr(vectors, "strings"):
for s in vectors.strings: for s in vectors.strings:
self.strings.add(s) self.strings.add(s, allow_transient=False)
self._vectors = vectors self._vectors = vectors
self._vectors.strings = self.strings self._vectors.strings = self.strings
@ -115,6 +116,10 @@ cdef class Vocab:
langfunc = self.lex_attr_getters.get(LANG, None) langfunc = self.lex_attr_getters.get(LANG, None)
return langfunc("_") if langfunc else "" return langfunc("_") if langfunc else ""
@property
def in_memory_zone(self) -> bool:
return self.mem is not self._non_temp_mem
def __len__(self): def __len__(self):
"""The current number of lexemes stored. """The current number of lexemes stored.
@ -218,7 +223,7 @@ cdef class Vocab:
# this size heuristic. # this size heuristic.
mem = self.mem mem = self.mem
lex = <LexemeC*>mem.alloc(1, sizeof(LexemeC)) lex = <LexemeC*>mem.alloc(1, sizeof(LexemeC))
lex.orth = self.strings.add(string) lex.orth = self.strings.add(string, allow_transient=True)
lex.length = len(string) lex.length = len(string)
if self.vectors is not None and hasattr(self.vectors, "key2row"): if self.vectors is not None and hasattr(self.vectors, "key2row"):
lex.id = self.vectors.key2row.get(lex.orth, OOV_RANK) lex.id = self.vectors.key2row.get(lex.orth, OOV_RANK)
@ -239,13 +244,13 @@ cdef class Vocab:
cdef int _add_lex_to_vocab(self, hash_t key, const LexemeC* lex, bint is_transient) except -1: cdef int _add_lex_to_vocab(self, hash_t key, const LexemeC* lex, bint is_transient) except -1:
self._by_orth.set(lex.orth, <void*>lex) self._by_orth.set(lex.orth, <void*>lex)
self.length += 1 self.length += 1
if is_transient: if is_transient and self.in_memory_zone:
self._transient_orths.push_back(lex.orth) self._transient_orths.push_back(lex.orth)
def _clear_transient_orths(self): def _clear_transient_orths(self):
"""Remove transient lexemes from the index (generally at the end of the memory zone)""" """Remove transient lexemes from the index (generally at the end of the memory zone)"""
for orth in self._transient_orths: for orth in self._transient_orths:
self._by_orth.pop(orth) map_clear(self._by_orth.c_map, orth)
self._transient_orths.clear() self._transient_orths.clear()
def __contains__(self, key): def __contains__(self, key):