mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-03 13:14:11 +03:00
Fix memory zones
This commit is contained in:
parent
59ac7e6bdb
commit
a019315534
|
@ -1,5 +1,5 @@
|
|||
from .lex_attrs import LEX_ATTRS
|
||||
from ...language import BaseDefaults, Language
|
||||
from .lex_attrs import LEX_ATTRS
|
||||
from .stop_words import STOP_WORDS
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
from ...attrs import LIKE_NUM
|
||||
|
||||
|
||||
_num_words = [
|
||||
"sifir",
|
||||
"yek",
|
||||
|
|
|
@ -5,7 +5,7 @@ import multiprocessing as mp
|
|||
import random
|
||||
import traceback
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from contextlib import ExitStack, contextmanager
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from itertools import chain, cycle
|
||||
|
@ -31,6 +31,7 @@ from typing import (
|
|||
)
|
||||
|
||||
import srsly
|
||||
from cymem.cymem import Pool
|
||||
from thinc.api import Config, CupyOps, Optimizer, get_current_ops
|
||||
|
||||
from . import about, ty, util
|
||||
|
@ -2091,6 +2092,38 @@ class Language:
|
|||
util.replace_model_node(pipe.model, listener, new_model) # type: ignore[attr-defined]
|
||||
tok2vec.remove_listener(listener, pipe_name)
|
||||
|
||||
@contextmanager
|
||||
def memory_zone(self, mem: Optional[Pool]=None) -> Iterator[Pool]:
|
||||
"""Begin a block where all resources allocated during the block will
|
||||
be freed at the end of it. If a resources was created within the
|
||||
memory zone block, accessing it outside the block is invalid.
|
||||
Behaviour of this invalid access is undefined. Memory zones should
|
||||
not be nested.
|
||||
|
||||
The memory zone is helpful for services that need to process large
|
||||
volumes of text with a defined memory budget.
|
||||
|
||||
Example
|
||||
-------
|
||||
>>> with nlp.memory_zone():
|
||||
... for doc in nlp.pipe(texts):
|
||||
... process_my_doc(doc)
|
||||
>>> # use_doc(doc) <-- Invalid: doc was allocated in the memory zone
|
||||
"""
|
||||
if mem is None:
|
||||
mem = Pool()
|
||||
# The ExitStack allows programmatic nested context managers.
|
||||
# We don't know how many we need, so it would be awkward to have
|
||||
# them as nested blocks.
|
||||
with ExitStack() as stack:
|
||||
contexts = [stack.enter_context(self.vocab.memory_zone(mem))]
|
||||
if hasattr(self.tokenizer, "memory_zone"):
|
||||
contexts.append(stack.enter_context(self.tokenizer.memory_zone(mem)))
|
||||
for _, pipe in self.pipeline:
|
||||
if hasattr(pipe, "memory_zone"):
|
||||
contexts.append(stack.enter_context(pipe.memory_zone(mem)))
|
||||
yield mem
|
||||
|
||||
def to_disk(
|
||||
self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
|
||||
) -> None:
|
||||
|
|
|
@ -203,7 +203,7 @@ cdef class ArcEagerGold:
|
|||
def __init__(self, ArcEager moves, StateClass stcls, Example example):
|
||||
self.mem = Pool()
|
||||
heads, labels = example.get_aligned_parse(projectivize=True)
|
||||
labels = [example.x.vocab.strings.add(label) if label is not None else MISSING_DEP for label in labels]
|
||||
labels = [example.x.vocab.strings.add(label, allow_transient=False) if label is not None else MISSING_DEP for label in labels]
|
||||
sent_starts = _get_aligned_sent_starts(example)
|
||||
assert len(heads) == len(labels) == len(sent_starts), (len(heads), len(labels), len(sent_starts))
|
||||
self.c = create_gold_state(self.mem, stcls.c, heads, labels, sent_starts)
|
||||
|
|
|
@ -183,7 +183,7 @@ cpdef deprojectivize(Doc doc):
|
|||
new_label, head_label = label.split(DELIMITER)
|
||||
new_head = _find_new_head(doc[i], head_label)
|
||||
doc.c[i].head = new_head.i - i
|
||||
doc.c[i].dep = doc.vocab.strings.add(new_label)
|
||||
doc.c[i].dep = doc.vocab.strings.add(new_label, allow_transient=False)
|
||||
set_children_from_heads(doc.c, 0, doc.length)
|
||||
return doc
|
||||
|
||||
|
|
|
@ -28,5 +28,4 @@ cdef class StringStore:
|
|||
cdef const Utf8Str* intern_unicode(self, str py_string, bint allow_transient)
|
||||
cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length, hash_t* precalculated_hash, bint allow_transient)
|
||||
cdef vector[hash_t] _transient_keys
|
||||
cdef PreshMap _transient_map
|
||||
cdef Pool _non_temp_mem
|
||||
|
|
|
@ -8,6 +8,7 @@ from typing import Iterator, List, Optional
|
|||
from libc.stdint cimport uint32_t
|
||||
from libc.string cimport memcpy
|
||||
from murmurhash.mrmr cimport hash32, hash64
|
||||
from preshed.maps cimport map_clear
|
||||
|
||||
import srsly
|
||||
|
||||
|
@ -125,10 +126,9 @@ cdef class StringStore:
|
|||
self.mem = Pool()
|
||||
self._non_temp_mem = self.mem
|
||||
self._map = PreshMap()
|
||||
self._transient_map = None
|
||||
if strings is not None:
|
||||
for string in strings:
|
||||
self.add(string)
|
||||
self.add(string, allow_transient=False)
|
||||
|
||||
def __getitem__(self, object string_or_id):
|
||||
"""Retrieve a string from a given hash, or vice versa.
|
||||
|
@ -158,17 +158,17 @@ cdef class StringStore:
|
|||
return SYMBOLS_BY_INT[str_hash]
|
||||
else:
|
||||
utf8str = <Utf8Str*>self._map.get(str_hash)
|
||||
if utf8str is NULL and self._transient_map is not None:
|
||||
utf8str = <Utf8Str*>self._transient_map.get(str_hash)
|
||||
if utf8str is NULL:
|
||||
raise KeyError(Errors.E018.format(hash_value=string_or_id))
|
||||
else:
|
||||
return decode_Utf8Str(utf8str)
|
||||
else:
|
||||
# TODO: Raise an error instead
|
||||
utf8str = <Utf8Str*>self._map.get(string_or_id)
|
||||
if utf8str is NULL and self._transient_map is not None:
|
||||
utf8str = <Utf8Str*>self._transient_map.get(str_hash)
|
||||
if utf8str is NULL:
|
||||
raise KeyError(Errors.E018.format(hash_value=string_or_id))
|
||||
else:
|
||||
return decode_Utf8Str(utf8str)
|
||||
if utf8str is NULL:
|
||||
raise KeyError(Errors.E018.format(hash_value=string_or_id))
|
||||
else:
|
||||
return decode_Utf8Str(utf8str)
|
||||
|
||||
def as_int(self, key):
|
||||
"""If key is an int, return it; otherwise, get the int value."""
|
||||
|
@ -184,16 +184,12 @@ cdef class StringStore:
|
|||
else:
|
||||
return self[key]
|
||||
|
||||
def __reduce__(self):
|
||||
strings = list(self.non_transient_keys())
|
||||
return (StringStore, (strings,), None, None, None)
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""The number of strings in the store.
|
||||
|
||||
RETURNS (int): The number of strings in the store.
|
||||
"""
|
||||
return self._keys.size() + self._transient_keys.size()
|
||||
return self.keys.size() + self._transient_keys.size()
|
||||
|
||||
@contextmanager
|
||||
def memory_zone(self, mem: Optional[Pool] = None) -> Pool:
|
||||
|
@ -209,13 +205,13 @@ cdef class StringStore:
|
|||
if mem is None:
|
||||
mem = Pool()
|
||||
self.mem = mem
|
||||
self._transient_map = PreshMap()
|
||||
yield mem
|
||||
self.mem = self._non_temp_mem
|
||||
self._transient_map = None
|
||||
for key in self._transient_keys:
|
||||
map_clear(self._map.c_map, key)
|
||||
self._transient_keys.clear()
|
||||
self.mem = self._non_temp_mem
|
||||
|
||||
def add(self, string: str, allow_transient: bool = False) -> int:
|
||||
def add(self, string: str, allow_transient: Optional[bool] = None) -> int:
|
||||
"""Add a string to the StringStore.
|
||||
|
||||
string (str): The string to add.
|
||||
|
@ -226,6 +222,8 @@ cdef class StringStore:
|
|||
internally should not.
|
||||
RETURNS (uint64): The string's hash value.
|
||||
"""
|
||||
if allow_transient is None:
|
||||
allow_transient = self.mem is not self._non_temp_mem
|
||||
cdef hash_t str_hash
|
||||
if isinstance(string, str):
|
||||
if string in SYMBOLS_BY_STR:
|
||||
|
@ -273,8 +271,6 @@ cdef class StringStore:
|
|||
# TODO: Raise an error instead
|
||||
if self._map.get(string_or_id) is not NULL:
|
||||
return True
|
||||
elif self._transient_map is not None and self._transient_map.get(string_or_id) is not NULL:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
if str_hash < len(SYMBOLS_BY_INT):
|
||||
|
@ -282,8 +278,6 @@ cdef class StringStore:
|
|||
else:
|
||||
if self._map.get(str_hash) is not NULL:
|
||||
return True
|
||||
elif self._transient_map is not None and self._transient_map.get(string_or_id) is not NULL:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
@ -292,32 +286,21 @@ cdef class StringStore:
|
|||
|
||||
YIELDS (str): A string in the store.
|
||||
"""
|
||||
yield from self.non_transient_keys()
|
||||
yield from self.transient_keys()
|
||||
|
||||
def non_transient_keys(self) -> Iterator[str]:
|
||||
"""Iterate over the stored strings in insertion order.
|
||||
|
||||
RETURNS: A list of strings.
|
||||
"""
|
||||
cdef int i
|
||||
cdef hash_t key
|
||||
for i in range(self.keys.size()):
|
||||
key = self.keys[i]
|
||||
utf8str = <Utf8Str*>self._map.get(key)
|
||||
yield decode_Utf8Str(utf8str)
|
||||
for i in range(self._transient_keys.size()):
|
||||
key = self._transient_keys[i]
|
||||
utf8str = <Utf8Str*>self._map.get(key)
|
||||
yield decode_Utf8Str(utf8str)
|
||||
|
||||
def __reduce__(self):
|
||||
strings = list(self)
|
||||
return (StringStore, (strings,), None, None, None)
|
||||
|
||||
def transient_keys(self) -> Iterator[str]:
|
||||
if self._transient_map is None:
|
||||
return []
|
||||
for i in range(self._transient_keys.size()):
|
||||
utf8str = <Utf8Str*>self._transient_map.get(self._transient_keys[i])
|
||||
yield decode_Utf8Str(utf8str)
|
||||
|
||||
def values(self) -> List[int]:
|
||||
"""Iterate over the stored strings hashes in insertion order.
|
||||
|
||||
|
@ -327,12 +310,9 @@ cdef class StringStore:
|
|||
hashes = [None] * self._keys.size()
|
||||
for i in range(self._keys.size()):
|
||||
hashes[i] = self._keys[i]
|
||||
if self._transient_map is not None:
|
||||
transient_hashes = [None] * self._transient_keys.size()
|
||||
for i in range(self._transient_keys.size()):
|
||||
transient_hashes[i] = self._transient_keys[i]
|
||||
else:
|
||||
transient_hashes = []
|
||||
transient_hashes = [None] * self._transient_keys.size()
|
||||
for i in range(self._transient_keys.size()):
|
||||
transient_hashes[i] = self._transient_keys[i]
|
||||
return hashes + transient_hashes
|
||||
|
||||
def to_disk(self, path):
|
||||
|
@ -383,8 +363,10 @@ cdef class StringStore:
|
|||
|
||||
def _reset_and_load(self, strings):
|
||||
self.mem = Pool()
|
||||
self._non_temp_mem = self.mem
|
||||
self._map = PreshMap()
|
||||
self.keys.clear()
|
||||
self._transient_keys.clear()
|
||||
for string in strings:
|
||||
self.add(string, allow_transient=False)
|
||||
|
||||
|
@ -401,19 +383,10 @@ cdef class StringStore:
|
|||
cdef Utf8Str* value = <Utf8Str*>self._map.get(key)
|
||||
if value is not NULL:
|
||||
return value
|
||||
if allow_transient and self._transient_map is not None:
|
||||
# If we've already allocated a transient string, and now we
|
||||
# want to intern it permanently, we'll end up with the string
|
||||
# in both places. That seems fine -- I don't see why we need
|
||||
# to remove it from the transient map.
|
||||
value = <Utf8Str*>self._transient_map.get(key)
|
||||
if value is not NULL:
|
||||
return value
|
||||
value = _allocate(self.mem, <unsigned char*>utf8_string, length)
|
||||
if allow_transient and self._transient_map is not None:
|
||||
self._transient_map.set(key, value)
|
||||
self._map.set(key, value)
|
||||
if allow_transient and self.mem is not self._non_temp_mem:
|
||||
self._transient_keys.push_back(key)
|
||||
else:
|
||||
self._map.set(key, value)
|
||||
self.keys.push_back(key)
|
||||
return value
|
||||
|
|
|
@ -517,12 +517,8 @@ cdef class Tokenizer:
|
|||
if n <= 0:
|
||||
# avoid mem alloc of zero length
|
||||
return 0
|
||||
# Historically this check was mostly used to avoid caching
|
||||
# chunks that had tokens owned by the Doc. Now that that's
|
||||
# not a thing, I don't think we need this?
|
||||
for i in range(n):
|
||||
if self.vocab._by_orth.get(tokens[i].lex.orth) == NULL:
|
||||
return 0
|
||||
if self.vocab.in_memory_zone:
|
||||
return 0
|
||||
# See #1250
|
||||
if has_special[0]:
|
||||
return 0
|
||||
|
|
|
@ -5,6 +5,7 @@ from typing import Iterator, Optional
|
|||
import numpy
|
||||
import srsly
|
||||
from thinc.api import get_array_module, get_current_ops
|
||||
from preshed.maps cimport map_clear
|
||||
|
||||
from .attrs cimport LANG, ORTH
|
||||
from .lexeme cimport EMPTY_LEXEME, OOV_RANK, Lexeme
|
||||
|
@ -104,7 +105,7 @@ cdef class Vocab:
|
|||
def vectors(self, vectors):
|
||||
if hasattr(vectors, "strings"):
|
||||
for s in vectors.strings:
|
||||
self.strings.add(s)
|
||||
self.strings.add(s, allow_transient=False)
|
||||
self._vectors = vectors
|
||||
self._vectors.strings = self.strings
|
||||
|
||||
|
@ -115,6 +116,10 @@ cdef class Vocab:
|
|||
langfunc = self.lex_attr_getters.get(LANG, None)
|
||||
return langfunc("_") if langfunc else ""
|
||||
|
||||
@property
|
||||
def in_memory_zone(self) -> bool:
|
||||
return self.mem is not self._non_temp_mem
|
||||
|
||||
def __len__(self):
|
||||
"""The current number of lexemes stored.
|
||||
|
||||
|
@ -218,7 +223,7 @@ cdef class Vocab:
|
|||
# this size heuristic.
|
||||
mem = self.mem
|
||||
lex = <LexemeC*>mem.alloc(1, sizeof(LexemeC))
|
||||
lex.orth = self.strings.add(string)
|
||||
lex.orth = self.strings.add(string, allow_transient=True)
|
||||
lex.length = len(string)
|
||||
if self.vectors is not None and hasattr(self.vectors, "key2row"):
|
||||
lex.id = self.vectors.key2row.get(lex.orth, OOV_RANK)
|
||||
|
@ -239,13 +244,13 @@ cdef class Vocab:
|
|||
cdef int _add_lex_to_vocab(self, hash_t key, const LexemeC* lex, bint is_transient) except -1:
|
||||
self._by_orth.set(lex.orth, <void*>lex)
|
||||
self.length += 1
|
||||
if is_transient:
|
||||
if is_transient and self.in_memory_zone:
|
||||
self._transient_orths.push_back(lex.orth)
|
||||
|
||||
def _clear_transient_orths(self):
|
||||
"""Remove transient lexemes from the index (generally at the end of the memory zone)"""
|
||||
for orth in self._transient_orths:
|
||||
self._by_orth.pop(orth)
|
||||
map_clear(self._by_orth.c_map, orth)
|
||||
self._transient_orths.clear()
|
||||
|
||||
def __contains__(self, key):
|
||||
|
|
Loading…
Reference in New Issue
Block a user