diff --git a/spacy/lang/kmr/__init__.py b/spacy/lang/kmr/__init__.py index 124321a8e..eee9e69d0 100644 --- a/spacy/lang/kmr/__init__.py +++ b/spacy/lang/kmr/__init__.py @@ -1,5 +1,5 @@ -from .lex_attrs import LEX_ATTRS from ...language import BaseDefaults, Language +from .lex_attrs import LEX_ATTRS from .stop_words import STOP_WORDS diff --git a/spacy/lang/kmr/lex_attrs.py b/spacy/lang/kmr/lex_attrs.py index 8927ef141..6b8020410 100644 --- a/spacy/lang/kmr/lex_attrs.py +++ b/spacy/lang/kmr/lex_attrs.py @@ -1,6 +1,5 @@ from ...attrs import LIKE_NUM - _num_words = [ "sifir", "yek", diff --git a/spacy/language.py b/spacy/language.py index 18d20c939..57b851481 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -5,7 +5,7 @@ import multiprocessing as mp import random import traceback import warnings -from contextlib import contextmanager +from contextlib import ExitStack, contextmanager from copy import deepcopy from dataclasses import dataclass from itertools import chain, cycle @@ -31,6 +31,7 @@ from typing import ( ) import srsly +from cymem.cymem import Pool from thinc.api import Config, CupyOps, Optimizer, get_current_ops from . import about, ty, util @@ -2091,6 +2092,38 @@ class Language: util.replace_model_node(pipe.model, listener, new_model) # type: ignore[attr-defined] tok2vec.remove_listener(listener, pipe_name) + @contextmanager + def memory_zone(self, mem: Optional[Pool]=None) -> Iterator[Pool]: + """Begin a block where all resources allocated during the block will + be freed at the end of it. If a resource was created within the + memory zone block, accessing it outside the block is invalid. + Behaviour of this invalid access is undefined. Memory zones should + not be nested. + + The memory zone is helpful for services that need to process large + volumes of text with a defined memory budget. + + Example + ------- + >>> with nlp.memory_zone(): + ... for doc in nlp.pipe(texts): + ... 
process_my_doc(doc) + >>> # use_doc(doc) <-- Invalid: doc was allocated in the memory zone + """ + if mem is None: + mem = Pool() + # The ExitStack allows programmatic nested context managers. + # We don't know how many we need, so it would be awkward to have + # them as nested blocks. + with ExitStack() as stack: + contexts = [stack.enter_context(self.vocab.memory_zone(mem))] + if hasattr(self.tokenizer, "memory_zone"): + contexts.append(stack.enter_context(self.tokenizer.memory_zone(mem))) + for _, pipe in self.pipeline: + if hasattr(pipe, "memory_zone"): + contexts.append(stack.enter_context(pipe.memory_zone(mem))) + yield mem + def to_disk( self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList() ) -> None: diff --git a/spacy/pipeline/_parser_internals/arc_eager.pyx b/spacy/pipeline/_parser_internals/arc_eager.pyx index e13754944..bedaaf9fe 100644 --- a/spacy/pipeline/_parser_internals/arc_eager.pyx +++ b/spacy/pipeline/_parser_internals/arc_eager.pyx @@ -203,7 +203,7 @@ cdef class ArcEagerGold: def __init__(self, ArcEager moves, StateClass stcls, Example example): self.mem = Pool() heads, labels = example.get_aligned_parse(projectivize=True) - labels = [example.x.vocab.strings.add(label) if label is not None else MISSING_DEP for label in labels] + labels = [example.x.vocab.strings.add(label, allow_transient=False) if label is not None else MISSING_DEP for label in labels] sent_starts = _get_aligned_sent_starts(example) assert len(heads) == len(labels) == len(sent_starts), (len(heads), len(labels), len(sent_starts)) self.c = create_gold_state(self.mem, stcls.c, heads, labels, sent_starts) diff --git a/spacy/pipeline/_parser_internals/nonproj.pyx b/spacy/pipeline/_parser_internals/nonproj.pyx index 7de19851e..9e3a21b81 100644 --- a/spacy/pipeline/_parser_internals/nonproj.pyx +++ b/spacy/pipeline/_parser_internals/nonproj.pyx @@ -183,7 +183,7 @@ cpdef deprojectivize(Doc doc): new_label, head_label = label.split(DELIMITER) new_head = 
_find_new_head(doc[i], head_label) doc.c[i].head = new_head.i - i - doc.c[i].dep = doc.vocab.strings.add(new_label) + doc.c[i].dep = doc.vocab.strings.add(new_label, allow_transient=False) set_children_from_heads(doc.c, 0, doc.length) return doc diff --git a/spacy/strings.pxd b/spacy/strings.pxd index bd5e0f135..b01585858 100644 --- a/spacy/strings.pxd +++ b/spacy/strings.pxd @@ -28,5 +28,4 @@ cdef class StringStore: cdef const Utf8Str* intern_unicode(self, str py_string, bint allow_transient) cdef const Utf8Str* _intern_utf8(self, char* utf8_string, int length, hash_t* precalculated_hash, bint allow_transient) cdef vector[hash_t] _transient_keys - cdef PreshMap _transient_map cdef Pool _non_temp_mem diff --git a/spacy/strings.pyx b/spacy/strings.pyx index 5e0bd90c6..b0f6cf5aa 100644 --- a/spacy/strings.pyx +++ b/spacy/strings.pyx @@ -8,6 +8,7 @@ from typing import Iterator, List, Optional from libc.stdint cimport uint32_t from libc.string cimport memcpy from murmurhash.mrmr cimport hash32, hash64 +from preshed.maps cimport map_clear import srsly @@ -125,10 +126,9 @@ cdef class StringStore: self.mem = Pool() self._non_temp_mem = self.mem self._map = PreshMap() - self._transient_map = None if strings is not None: for string in strings: - self.add(string) + self.add(string, allow_transient=False) def __getitem__(self, object string_or_id): """Retrieve a string from a given hash, or vice versa. 
@@ -158,17 +158,17 @@ cdef class StringStore: return SYMBOLS_BY_INT[str_hash] else: utf8str = self._map.get(str_hash) - if utf8str is NULL and self._transient_map is not None: - utf8str = self._transient_map.get(str_hash) + if utf8str is NULL: + raise KeyError(Errors.E018.format(hash_value=string_or_id)) + else: + return decode_Utf8Str(utf8str) else: # TODO: Raise an error instead utf8str = self._map.get(string_or_id) - if utf8str is NULL and self._transient_map is not None: - utf8str = self._transient_map.get(str_hash) - if utf8str is NULL: - raise KeyError(Errors.E018.format(hash_value=string_or_id)) - else: - return decode_Utf8Str(utf8str) + if utf8str is NULL: + raise KeyError(Errors.E018.format(hash_value=string_or_id)) + else: + return decode_Utf8Str(utf8str) def as_int(self, key): """If key is an int, return it; otherwise, get the int value.""" @@ -184,16 +184,12 @@ cdef class StringStore: else: return self[key] - def __reduce__(self): - strings = list(self.non_transient_keys()) - return (StringStore, (strings,), None, None, None) - def __len__(self) -> int: """The number of strings in the store. RETURNS (int): The number of strings in the store. """ - return self._keys.size() + self._transient_keys.size() + return self.keys.size() + self._transient_keys.size() @contextmanager def memory_zone(self, mem: Optional[Pool] = None) -> Pool: @@ -209,13 +205,13 @@ cdef class StringStore: if mem is None: mem = Pool() self.mem = mem - self._transient_map = PreshMap() yield mem - self.mem = self._non_temp_mem - self._transient_map = None + for key in self._transient_keys: + map_clear(self._map.c_map, key) self._transient_keys.clear() + self.mem = self._non_temp_mem - def add(self, string: str, allow_transient: bool = False) -> int: + def add(self, string: str, allow_transient: Optional[bool] = None) -> int: """Add a string to the StringStore. string (str): The string to add. @@ -226,6 +222,8 @@ cdef class StringStore: internally should not. 
RETURNS (uint64): The string's hash value. """ + if allow_transient is None: + allow_transient = self.mem is not self._non_temp_mem cdef hash_t str_hash if isinstance(string, str): if string in SYMBOLS_BY_STR: @@ -273,8 +271,6 @@ cdef class StringStore: # TODO: Raise an error instead if self._map.get(string_or_id) is not NULL: return True - elif self._transient_map is not None and self._transient_map.get(string_or_id) is not NULL: - return True else: return False if str_hash < len(SYMBOLS_BY_INT): @@ -282,8 +278,6 @@ cdef class StringStore: else: if self._map.get(str_hash) is not NULL: return True - elif self._transient_map is not None and self._transient_map.get(string_or_id) is not NULL: - return True else: return False @@ -292,32 +286,21 @@ cdef class StringStore: YIELDS (str): A string in the store. """ - yield from self.non_transient_keys() - yield from self.transient_keys() - - def non_transient_keys(self) -> Iterator[str]: - """Iterate over the stored strings in insertion order. - - RETURNS: A list of strings. - """ cdef int i cdef hash_t key for i in range(self.keys.size()): key = self.keys[i] utf8str = self._map.get(key) yield decode_Utf8Str(utf8str) + for i in range(self._transient_keys.size()): + key = self._transient_keys[i] + utf8str = self._map.get(key) + yield decode_Utf8Str(utf8str) def __reduce__(self): strings = list(self) return (StringStore, (strings,), None, None, None) - def transient_keys(self) -> Iterator[str]: - if self._transient_map is None: - return [] - for i in range(self._transient_keys.size()): - utf8str = self._transient_map.get(self._transient_keys[i]) - yield decode_Utf8Str(utf8str) - def values(self) -> List[int]: """Iterate over the stored strings hashes in insertion order. 
@@ -327,12 +310,9 @@ cdef class StringStore: hashes = [None] * self._keys.size() for i in range(self._keys.size()): hashes[i] = self._keys[i] - if self._transient_map is not None: - transient_hashes = [None] * self._transient_keys.size() - for i in range(self._transient_keys.size()): - transient_hashes[i] = self._transient_keys[i] - else: - transient_hashes = [] + transient_hashes = [None] * self._transient_keys.size() + for i in range(self._transient_keys.size()): + transient_hashes[i] = self._transient_keys[i] return hashes + transient_hashes def to_disk(self, path): @@ -383,8 +363,10 @@ cdef class StringStore: def _reset_and_load(self, strings): self.mem = Pool() + self._non_temp_mem = self.mem self._map = PreshMap() self.keys.clear() + self._transient_keys.clear() for string in strings: self.add(string, allow_transient=False) @@ -401,19 +383,10 @@ cdef class StringStore: cdef Utf8Str* value = self._map.get(key) if value is not NULL: return value - if allow_transient and self._transient_map is not None: - # If we've already allocated a transient string, and now we - # want to intern it permanently, we'll end up with the string - # in both places. That seems fine -- I don't see why we need - # to remove it from the transient map. 
- value = self._transient_map.get(key) - if value is not NULL: - return value value = _allocate(self.mem, utf8_string, length) - if allow_transient and self._transient_map is not None: - self._transient_map.set(key, value) + self._map.set(key, value) + if allow_transient and self.mem is not self._non_temp_mem: self._transient_keys.push_back(key) else: - self._map.set(key, value) self.keys.push_back(key) return value diff --git a/spacy/tokenizer.pyx b/spacy/tokenizer.pyx index 93b7f63ac..6ca170dd4 100644 --- a/spacy/tokenizer.pyx +++ b/spacy/tokenizer.pyx @@ -517,12 +517,8 @@ cdef class Tokenizer: if n <= 0: # avoid mem alloc of zero length return 0 - # Historically this check was mostly used to avoid caching - # chunks that had tokens owned by the Doc. Now that that's - # not a thing, I don't think we need this? - for i in range(n): - if self.vocab._by_orth.get(tokens[i].lex.orth) == NULL: - return 0 + if self.vocab.in_memory_zone: + return 0 # See #1250 if has_special[0]: return 0 diff --git a/spacy/vocab.pyx b/spacy/vocab.pyx index 97ba5d68c..11043c17a 100644 --- a/spacy/vocab.pyx +++ b/spacy/vocab.pyx @@ -5,6 +5,7 @@ from typing import Iterator, Optional import numpy import srsly from thinc.api import get_array_module, get_current_ops +from preshed.maps cimport map_clear from .attrs cimport LANG, ORTH from .lexeme cimport EMPTY_LEXEME, OOV_RANK, Lexeme @@ -104,7 +105,7 @@ cdef class Vocab: def vectors(self, vectors): if hasattr(vectors, "strings"): for s in vectors.strings: - self.strings.add(s) + self.strings.add(s, allow_transient=False) self._vectors = vectors self._vectors.strings = self.strings @@ -115,6 +116,10 @@ cdef class Vocab: langfunc = self.lex_attr_getters.get(LANG, None) return langfunc("_") if langfunc else "" + @property + def in_memory_zone(self) -> bool: + return self.mem is not self._non_temp_mem + def __len__(self): """The current number of lexemes stored. @@ -218,7 +223,7 @@ cdef class Vocab: # this size heuristic. 
mem = self.mem lex = mem.alloc(1, sizeof(LexemeC)) - lex.orth = self.strings.add(string) + lex.orth = self.strings.add(string, allow_transient=True) lex.length = len(string) if self.vectors is not None and hasattr(self.vectors, "key2row"): lex.id = self.vectors.key2row.get(lex.orth, OOV_RANK) @@ -239,13 +244,13 @@ cdef class Vocab: cdef int _add_lex_to_vocab(self, hash_t key, const LexemeC* lex, bint is_transient) except -1: self._by_orth.set(lex.orth, lex) self.length += 1 - if is_transient: + if is_transient and self.in_memory_zone: self._transient_orths.push_back(lex.orth) def _clear_transient_orths(self): """Remove transient lexemes from the index (generally at the end of the memory zone)""" for orth in self._transient_orths: - self._by_orth.pop(orth) + map_clear(self._by_orth.c_map, orth) self._transient_orths.clear() def __contains__(self, key):