commit a019315534
parent 59ac7e6bdb

    Fix memory zones
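For context, the new API is intended to be used like this — a minimal sketch based on the docstring added in `Language.memory_zone` below; the model name, `texts`, and the per-doc processing are placeholders:

import spacy

nlp = spacy.load("en_core_web_sm")  # any pipeline; placeholder model name
texts = ["First document.", "Second document."]

# Vocab/StringStore allocations made while the zone is active are freed
# when the block exits, so Docs created inside the zone must not be
# used outside it.
with nlp.memory_zone():
    for doc in nlp.pipe(texts):
        print(doc[0].text)  # stand-in for process_my_doc(doc)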
@@ -1,5 +1,5 @@
-from .lex_attrs import LEX_ATTRS
 from ...language import BaseDefaults, Language
+from .lex_attrs import LEX_ATTRS
 from .stop_words import STOP_WORDS
 
@@ -1,6 +1,5 @@
 from ...attrs import LIKE_NUM
 
-
 _num_words = [
     "sifir",
     "yek",
@@ -5,7 +5,7 @@ import multiprocessing as mp
 import random
 import traceback
 import warnings
-from contextlib import contextmanager
+from contextlib import ExitStack, contextmanager
 from copy import deepcopy
 from dataclasses import dataclass
 from itertools import chain, cycle
@@ -31,6 +31,7 @@ from typing import (
 )
 
 import srsly
+from cymem.cymem import Pool
 from thinc.api import Config, CupyOps, Optimizer, get_current_ops
 
 from . import about, ty, util
@@ -2091,6 +2092,38 @@ class Language:
                     util.replace_model_node(pipe.model, listener, new_model)  # type: ignore[attr-defined]
                 tok2vec.remove_listener(listener, pipe_name)
 
+    @contextmanager
+    def memory_zone(self, mem: Optional[Pool] = None) -> Iterator[Pool]:
+        """Begin a block where all resources allocated during the block will
+        be freed at the end of it. If a resource was created within the
+        memory zone block, accessing it outside the block is invalid.
+        Behaviour of this invalid access is undefined. Memory zones should
+        not be nested.
+
+        The memory zone is helpful for services that need to process large
+        volumes of text with a defined memory budget.
+
+        Example
+        -------
+        >>> with nlp.memory_zone():
+        ...     for doc in nlp.pipe(texts):
+        ...         process_my_doc(doc)
+        >>> # use_doc(doc) <-- Invalid: doc was allocated in the memory zone
+        """
+        if mem is None:
+            mem = Pool()
+        # The ExitStack allows programmatic nested context managers.
+        # We don't know how many we need, so it would be awkward to have
+        # them as nested blocks.
+        with ExitStack() as stack:
+            contexts = [stack.enter_context(self.vocab.memory_zone(mem))]
+            if hasattr(self.tokenizer, "memory_zone"):
+                contexts.append(stack.enter_context(self.tokenizer.memory_zone(mem)))
+            for _, pipe in self.pipeline:
+                if hasattr(pipe, "memory_zone"):
+                    contexts.append(stack.enter_context(pipe.memory_zone(mem)))
+            yield mem
+
     def to_disk(
         self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
     ) -> None:
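The ExitStack here composes a number of context managers that is only known at runtime. A self-contained sketch of the pattern, with invented component names for illustration:

from contextlib import ExitStack, contextmanager

@contextmanager
def zone(name):
    print(f"enter {name}")
    yield name
    print(f"exit {name}")

# Enter however many context managers we discover at runtime;
# ExitStack unwinds them in reverse order when the block exits.
with ExitStack() as stack:
    entered = [stack.enter_context(zone(n)) for n in ("vocab", "tokenizer", "ner")]
    print("working with", entered)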
@@ -203,7 +203,7 @@ cdef class ArcEagerGold:
     def __init__(self, ArcEager moves, StateClass stcls, Example example):
        self.mem = Pool()
        heads, labels = example.get_aligned_parse(projectivize=True)
-        labels = [example.x.vocab.strings.add(label) if label is not None else MISSING_DEP for label in labels]
+        labels = [example.x.vocab.strings.add(label, allow_transient=False) if label is not None else MISSING_DEP for label in labels]
        sent_starts = _get_aligned_sent_starts(example)
        assert len(heads) == len(labels) == len(sent_starts), (len(heads), len(labels), len(sent_starts))
        self.c = create_gold_state(self.mem, stcls.c, heads, labels, sent_starts)
@@ -183,7 +183,7 @@ cpdef deprojectivize(Doc doc):
             new_label, head_label = label.split(DELIMITER)
             new_head = _find_new_head(doc[i], head_label)
             doc.c[i].head = new_head.i - i
-            doc.c[i].dep = doc.vocab.strings.add(new_label)
+            doc.c[i].dep = doc.vocab.strings.add(new_label, allow_transient=False)
     set_children_from_heads(doc.c, 0, doc.length)
     return doc
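Both of these call sites pin strings with `allow_transient=False`: dependency labels are written into state that outlives any single call (the gold training state, the `Doc` itself), so they must survive even when a memory zone happens to be active.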
@@ -28,5 +28,4 @@ cdef class StringStore:
     cdef const Utf8Str* intern_unicode(self, str py_string, bint allow_transient)
     cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length, hash_t* precalculated_hash, bint allow_transient)
     cdef vector[hash_t] _transient_keys
-    cdef PreshMap _transient_map
     cdef Pool _non_temp_mem
@@ -8,6 +8,7 @@ from typing import Iterator, List, Optional
 from libc.stdint cimport uint32_t
 from libc.string cimport memcpy
 from murmurhash.mrmr cimport hash32, hash64
+from preshed.maps cimport map_clear
 
 import srsly
 
@@ -125,10 +126,9 @@ cdef class StringStore:
         self.mem = Pool()
         self._non_temp_mem = self.mem
         self._map = PreshMap()
-        self._transient_map = None
         if strings is not None:
             for string in strings:
-                self.add(string)
+                self.add(string, allow_transient=False)
 
     def __getitem__(self, object string_or_id):
         """Retrieve a string from a given hash, or vice versa.
@@ -158,17 +158,17 @@ cdef class StringStore:
                 return SYMBOLS_BY_INT[str_hash]
             else:
                 utf8str = <Utf8Str*>self._map.get(str_hash)
-                if utf8str is NULL and self._transient_map is not None:
-                    utf8str = <Utf8Str*>self._transient_map.get(str_hash)
+                if utf8str is NULL:
+                    raise KeyError(Errors.E018.format(hash_value=string_or_id))
+                else:
+                    return decode_Utf8Str(utf8str)
         else:
             # TODO: Raise an error instead
             utf8str = <Utf8Str*>self._map.get(string_or_id)
-            if utf8str is NULL and self._transient_map is not None:
-                utf8str = <Utf8Str*>self._transient_map.get(str_hash)
             if utf8str is NULL:
                 raise KeyError(Errors.E018.format(hash_value=string_or_id))
             else:
                 return decode_Utf8Str(utf8str)
 
     def as_int(self, key):
         """If key is an int, return it; otherwise, get the int value."""
@@ -184,16 +184,12 @@ cdef class StringStore:
         else:
             return self[key]
 
-    def __reduce__(self):
-        strings = list(self.non_transient_keys())
-        return (StringStore, (strings,), None, None, None)
-
     def __len__(self) -> int:
         """The number of strings in the store.
 
         RETURNS (int): The number of strings in the store.
         """
-        return self._keys.size() + self._transient_keys.size()
+        return self.keys.size() + self._transient_keys.size()
 
     @contextmanager
     def memory_zone(self, mem: Optional[Pool] = None) -> Pool:
@@ -209,13 +205,13 @@ cdef class StringStore:
         if mem is None:
             mem = Pool()
         self.mem = mem
-        self._transient_map = PreshMap()
         yield mem
-        self.mem = self._non_temp_mem
-        self._transient_map = None
+        for key in self._transient_keys:
+            map_clear(self._map.c_map, key)
         self._transient_keys.clear()
+        self.mem = self._non_temp_mem
 
-    def add(self, string: str, allow_transient: bool = False) -> int:
+    def add(self, string: str, allow_transient: Optional[bool] = None) -> int:
         """Add a string to the StringStore.
 
         string (str): The string to add.
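As a rough pure-Python analogue of the bookkeeping in this hunk — a dict and list stand in for `PreshMap` and the C++ vector, and `ToyStringStore` is invented for illustration:

from contextlib import contextmanager

class ToyStringStore:
    def __init__(self):
        self._map = {}             # hash -> string (PreshMap in the real code)
        self._keys = []            # permanently interned keys
        self._transient_keys = []  # keys evicted when the zone exits
        self.in_zone = False

    def add(self, string, allow_transient=None):
        if allow_transient is None:
            allow_transient = self.in_zone  # infer from active zone
        key = hash(string)
        if key not in self._map:
            self._map[key] = string
            (self._transient_keys if allow_transient else self._keys).append(key)
        return key

    @contextmanager
    def memory_zone(self):
        self.in_zone = True
        yield
        for key in self._transient_keys:  # evict transient entries
            del self._map[key]
        self._transient_keys.clear()
        self.in_zone = False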
@@ -226,6 +222,8 @@ cdef class StringStore:
             internally should not.
         RETURNS (uint64): The string's hash value.
         """
+        if allow_transient is None:
+            allow_transient = self.mem is not self._non_temp_mem
         cdef hash_t str_hash
         if isinstance(string, str):
             if string in SYMBOLS_BY_STR:
@@ -273,8 +271,6 @@ cdef class StringStore:
             # TODO: Raise an error instead
             if self._map.get(string_or_id) is not NULL:
                 return True
-            elif self._transient_map is not None and self._transient_map.get(string_or_id) is not NULL:
-                return True
             else:
                 return False
         if str_hash < len(SYMBOLS_BY_INT):
@@ -282,8 +278,6 @@ cdef class StringStore:
         else:
             if self._map.get(str_hash) is not NULL:
                 return True
-            elif self._transient_map is not None and self._transient_map.get(string_or_id) is not NULL:
-                return True
             else:
                 return False
 
@@ -292,32 +286,21 @@ cdef class StringStore:
 
         YIELDS (str): A string in the store.
         """
-        yield from self.non_transient_keys()
-        yield from self.transient_keys()
-
-    def non_transient_keys(self) -> Iterator[str]:
-        """Iterate over the stored strings in insertion order.
-
-        RETURNS: A list of strings.
-        """
         cdef int i
         cdef hash_t key
         for i in range(self.keys.size()):
             key = self.keys[i]
             utf8str = <Utf8Str*>self._map.get(key)
             yield decode_Utf8Str(utf8str)
+        for i in range(self._transient_keys.size()):
+            key = self._transient_keys[i]
+            utf8str = <Utf8Str*>self._map.get(key)
+            yield decode_Utf8Str(utf8str)
 
     def __reduce__(self):
         strings = list(self)
         return (StringStore, (strings,), None, None, None)
 
-    def transient_keys(self) -> Iterator[str]:
-        if self._transient_map is None:
-            return []
-        for i in range(self._transient_keys.size()):
-            utf8str = <Utf8Str*>self._transient_map.get(self._transient_keys[i])
-            yield decode_Utf8Str(utf8str)
-
     def values(self) -> List[int]:
         """Iterate over the stored strings hashes in insertion order.
 
@@ -327,12 +310,9 @@ cdef class StringStore:
         hashes = [None] * self._keys.size()
         for i in range(self._keys.size()):
             hashes[i] = self._keys[i]
-        if self._transient_map is not None:
-            transient_hashes = [None] * self._transient_keys.size()
-            for i in range(self._transient_keys.size()):
-                transient_hashes[i] = self._transient_keys[i]
-        else:
-            transient_hashes = []
+        transient_hashes = [None] * self._transient_keys.size()
+        for i in range(self._transient_keys.size()):
+            transient_hashes[i] = self._transient_keys[i]
         return hashes + transient_hashes
 
     def to_disk(self, path):
@@ -383,8 +363,10 @@ cdef class StringStore:
 
     def _reset_and_load(self, strings):
         self.mem = Pool()
+        self._non_temp_mem = self.mem
         self._map = PreshMap()
         self.keys.clear()
+        self._transient_keys.clear()
         for string in strings:
             self.add(string, allow_transient=False)
 
@@ -401,19 +383,10 @@ cdef class StringStore:
         cdef Utf8Str* value = <Utf8Str*>self._map.get(key)
         if value is not NULL:
             return value
-        if allow_transient and self._transient_map is not None:
-            # If we've already allocated a transient string, and now we
-            # want to intern it permanently, we'll end up with the string
-            # in both places. That seems fine -- I don't see why we need
-            # to remove it from the transient map.
-            value = <Utf8Str*>self._transient_map.get(key)
-            if value is not NULL:
-                return value
         value = _allocate(self.mem, <unsigned char*>utf8_string, length)
-        if allow_transient and self._transient_map is not None:
-            self._transient_map.set(key, value)
+        self._map.set(key, value)
+        if allow_transient and self.mem is not self._non_temp_mem:
             self._transient_keys.push_back(key)
         else:
-            self._map.set(key, value)
             self.keys.push_back(key)
         return value
@@ -517,12 +517,8 @@ cdef class Tokenizer:
         if n <= 0:
             # avoid mem alloc of zero length
             return 0
-        # Historically this check was mostly used to avoid caching
-        # chunks that had tokens owned by the Doc. Now that that's
-        # not a thing, I don't think we need this?
-        for i in range(n):
-            if self.vocab._by_orth.get(tokens[i].lex.orth) == NULL:
-                return 0
+        if self.vocab.in_memory_zone:
+            return 0
         # See #1250
         if has_special[0]:
             return 0
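The tokenizer cache is simply disabled while a memory zone is active: a cached chunk could otherwise hold pointers to transient lexemes that are freed when the zone exits.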
@@ -5,6 +5,7 @@ from typing import Iterator, Optional
 import numpy
 import srsly
 from thinc.api import get_array_module, get_current_ops
+from preshed.maps cimport map_clear
 
 from .attrs cimport LANG, ORTH
 from .lexeme cimport EMPTY_LEXEME, OOV_RANK, Lexeme
@@ -104,7 +105,7 @@ cdef class Vocab:
     def vectors(self, vectors):
         if hasattr(vectors, "strings"):
             for s in vectors.strings:
-                self.strings.add(s)
+                self.strings.add(s, allow_transient=False)
         self._vectors = vectors
         self._vectors.strings = self.strings
@@ -115,6 +116,10 @@ cdef class Vocab:
         langfunc = self.lex_attr_getters.get(LANG, None)
         return langfunc("_") if langfunc else ""
 
+    @property
+    def in_memory_zone(self) -> bool:
+        return self.mem is not self._non_temp_mem
+
     def __len__(self):
         """The current number of lexemes stored.
 
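A quick illustration of the new property, assuming `Vocab.memory_zone` (added in the parent commit) swaps `self.mem` the same way `StringStore.memory_zone` does:

import spacy

nlp = spacy.blank("en")
assert nlp.vocab.in_memory_zone is False
with nlp.memory_zone():
    # vocab.mem now points at the zone's temporary Pool
    assert nlp.vocab.in_memory_zone is True
assert nlp.vocab.in_memory_zone is False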
@@ -218,7 +223,7 @@ cdef class Vocab:
         # this size heuristic.
         mem = self.mem
         lex = <LexemeC*>mem.alloc(1, sizeof(LexemeC))
-        lex.orth = self.strings.add(string)
+        lex.orth = self.strings.add(string, allow_transient=True)
         lex.length = len(string)
         if self.vectors is not None and hasattr(self.vectors, "key2row"):
             lex.id = self.vectors.key2row.get(lex.orth, OOV_RANK)
@@ -239,13 +244,13 @@ cdef class Vocab:
     cdef int _add_lex_to_vocab(self, hash_t key, const LexemeC* lex, bint is_transient) except -1:
         self._by_orth.set(lex.orth, <void*>lex)
         self.length += 1
-        if is_transient:
+        if is_transient and self.in_memory_zone:
             self._transient_orths.push_back(lex.orth)
 
     def _clear_transient_orths(self):
         """Remove transient lexemes from the index (generally at the end of the memory zone)"""
         for orth in self._transient_orths:
-            self._by_orth.pop(orth)
+            map_clear(self._by_orth.c_map, orth)
         self._transient_orths.clear()
 
     def __contains__(self, key):