Fix tokenizer cache flushing (#7836)

* Fix tokenizer cache flushing

Fix/simplify tokenizer init detection in order to fix cache flushing
when properties are modified.

* Remove init reloading logic

* Remove logic disabling `_reload_special_cases` on init
  * Setting `rules` last in `__init__` (as before) means that setting
    other properties doesn't reload any special cases
  * Reset `rules` first in `from_bytes` so that setting other properties
    during deserialization doesn't reload any special cases
    unnecessarily
* Reset all properties in `Tokenizer.from_bytes` to allow any settings
  to be `None`

* Also reset special matcher when special cache is flushed

* Remove duplicate special case validation

* Add test for special cases flushing

* Extend test for tokenizer deserialization of None values
This commit is contained in:
Adriane Boyd 2021-04-22 10:14:57 +02:00 committed by GitHub
parent cfad7e21d5
commit f4339f9bff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 51 additions and 36 deletions

View File

@ -26,10 +26,14 @@ def test_serialize_custom_tokenizer(en_vocab, en_tokenizer):
assert tokenizer.rules != {} assert tokenizer.rules != {}
assert tokenizer.token_match is not None assert tokenizer.token_match is not None
assert tokenizer.url_match is not None assert tokenizer.url_match is not None
assert tokenizer.prefix_search is not None
assert tokenizer.infix_finditer is not None
tokenizer.from_bytes(tokenizer_bytes) tokenizer.from_bytes(tokenizer_bytes)
assert tokenizer.rules == {} assert tokenizer.rules == {}
assert tokenizer.token_match is None assert tokenizer.token_match is None
assert tokenizer.url_match is None assert tokenizer.url_match is None
assert tokenizer.prefix_search is None
assert tokenizer.infix_finditer is None
tokenizer = Tokenizer(en_vocab, rules={"ABC.": [{"ORTH": "ABC"}, {"ORTH": "."}]}) tokenizer = Tokenizer(en_vocab, rules={"ABC.": [{"ORTH": "ABC"}, {"ORTH": "."}]})
tokenizer.rules = {} tokenizer.rules = {}

View File

@ -1,4 +1,5 @@
import pytest import pytest
import re
from spacy.vocab import Vocab from spacy.vocab import Vocab
from spacy.tokenizer import Tokenizer from spacy.tokenizer import Tokenizer
from spacy.util import ensure_path from spacy.util import ensure_path
@ -186,3 +187,31 @@ def test_tokenizer_special_cases_spaces(tokenizer):
assert [t.text for t in tokenizer("a b c")] == ["a", "b", "c"] assert [t.text for t in tokenizer("a b c")] == ["a", "b", "c"]
tokenizer.add_special_case("a b c", [{"ORTH": "a b c"}]) tokenizer.add_special_case("a b c", [{"ORTH": "a b c"}])
assert [t.text for t in tokenizer("a b c")] == ["a b c"] assert [t.text for t in tokenizer("a b c")] == ["a b c"]
def test_tokenizer_flush_cache(en_vocab):
suffix_re = re.compile(r"[\.]$")
tokenizer = Tokenizer(
en_vocab,
suffix_search=suffix_re.search,
)
assert [t.text for t in tokenizer("a.")] == ["a", "."]
tokenizer.suffix_search = None
assert [t.text for t in tokenizer("a.")] == ["a."]
def test_tokenizer_flush_specials(en_vocab):
suffix_re = re.compile(r"[\.]$")
rules = {"a a": [{"ORTH": "a a"}]}
tokenizer1 = Tokenizer(
en_vocab,
suffix_search=suffix_re.search,
rules=rules,
)
tokenizer2 = Tokenizer(
en_vocab,
suffix_search=suffix_re.search,
)
assert [t.text for t in tokenizer1("a a.")] == ["a a", "."]
tokenizer1.rules = {}
assert [t.text for t in tokenizer1("a a.")] == ["a", "a", "."]

View File

@ -23,8 +23,8 @@ cdef class Tokenizer:
cdef object _infix_finditer cdef object _infix_finditer
cdef object _rules cdef object _rules
cdef PhraseMatcher _special_matcher cdef PhraseMatcher _special_matcher
cdef int _property_init_count cdef int _property_init_count # TODO: unused, remove in v3.1
cdef int _property_init_max cdef int _property_init_max # TODO: unused, remove in v3.1
cdef Doc _tokenize_affixes(self, unicode string, bint with_special_cases) cdef Doc _tokenize_affixes(self, unicode string, bint with_special_cases)
cdef int _apply_special_cases(self, Doc doc) except -1 cdef int _apply_special_cases(self, Doc doc) except -1

View File

@ -69,8 +69,6 @@ cdef class Tokenizer:
self._rules = {} self._rules = {}
self._special_matcher = PhraseMatcher(self.vocab) self._special_matcher = PhraseMatcher(self.vocab)
self._load_special_cases(rules) self._load_special_cases(rules)
self._property_init_count = 0
self._property_init_max = 4
property token_match: property token_match:
def __get__(self): def __get__(self):
@ -79,8 +77,6 @@ cdef class Tokenizer:
def __set__(self, token_match): def __set__(self, token_match):
self._token_match = token_match self._token_match = token_match
self._reload_special_cases() self._reload_special_cases()
if self._property_init_count <= self._property_init_max:
self._property_init_count += 1
property url_match: property url_match:
def __get__(self): def __get__(self):
@ -88,7 +84,7 @@ cdef class Tokenizer:
def __set__(self, url_match): def __set__(self, url_match):
self._url_match = url_match self._url_match = url_match
self._flush_cache() self._reload_special_cases()
property prefix_search: property prefix_search:
def __get__(self): def __get__(self):
@ -97,8 +93,6 @@ cdef class Tokenizer:
def __set__(self, prefix_search): def __set__(self, prefix_search):
self._prefix_search = prefix_search self._prefix_search = prefix_search
self._reload_special_cases() self._reload_special_cases()
if self._property_init_count <= self._property_init_max:
self._property_init_count += 1
property suffix_search: property suffix_search:
def __get__(self): def __get__(self):
@ -107,8 +101,6 @@ cdef class Tokenizer:
def __set__(self, suffix_search): def __set__(self, suffix_search):
self._suffix_search = suffix_search self._suffix_search = suffix_search
self._reload_special_cases() self._reload_special_cases()
if self._property_init_count <= self._property_init_max:
self._property_init_count += 1
property infix_finditer: property infix_finditer:
def __get__(self): def __get__(self):
@ -117,8 +109,6 @@ cdef class Tokenizer:
def __set__(self, infix_finditer): def __set__(self, infix_finditer):
self._infix_finditer = infix_finditer self._infix_finditer = infix_finditer
self._reload_special_cases() self._reload_special_cases()
if self._property_init_count <= self._property_init_max:
self._property_init_count += 1
property rules: property rules:
def __get__(self): def __get__(self):
@ -126,7 +116,7 @@ cdef class Tokenizer:
def __set__(self, rules): def __set__(self, rules):
self._rules = {} self._rules = {}
self._reset_cache([key for key in self._cache]) self._flush_cache()
self._flush_specials() self._flush_specials()
self._cache = PreshMap() self._cache = PreshMap()
self._specials = PreshMap() self._specials = PreshMap()
@ -226,6 +216,7 @@ cdef class Tokenizer:
self.mem.free(cached) self.mem.free(cached)
def _flush_specials(self): def _flush_specials(self):
self._special_matcher = PhraseMatcher(self.vocab)
for k in self._specials: for k in self._specials:
cached = <_Cached*>self._specials.get(k) cached = <_Cached*>self._specials.get(k)
del self._specials[k] del self._specials[k]
@ -568,7 +559,6 @@ cdef class Tokenizer:
"""Add special-case tokenization rules.""" """Add special-case tokenization rules."""
if special_cases is not None: if special_cases is not None:
for chunk, substrings in sorted(special_cases.items()): for chunk, substrings in sorted(special_cases.items()):
self._validate_special_case(chunk, substrings)
self.add_special_case(chunk, substrings) self.add_special_case(chunk, substrings)
def _validate_special_case(self, chunk, substrings): def _validate_special_case(self, chunk, substrings):
@ -616,16 +606,9 @@ cdef class Tokenizer:
self._special_matcher.add(string, None, self._tokenize_affixes(string, False)) self._special_matcher.add(string, None, self._tokenize_affixes(string, False))
def _reload_special_cases(self): def _reload_special_cases(self):
try: self._flush_cache()
self._property_init_count self._flush_specials()
except AttributeError: self._load_special_cases(self._rules)
return
# only reload if all 4 of prefix, suffix, infix, token_match have
# have been initialized
if self.vocab is not None and self._property_init_count >= self._property_init_max:
self._flush_cache()
self._flush_specials()
self._load_special_cases(self._rules)
def explain(self, text): def explain(self, text):
"""A debugging tokenizer that provides information about which """A debugging tokenizer that provides information about which
@ -811,6 +794,15 @@ cdef class Tokenizer:
"url_match": lambda b: data.setdefault("url_match", b), "url_match": lambda b: data.setdefault("url_match", b),
"exceptions": lambda b: data.setdefault("rules", b) "exceptions": lambda b: data.setdefault("rules", b)
} }
# reset all properties and flush all caches (through rules),
# reset rules first so that _reload_special_cases is trivial/fast as
# the other properties are reset
self.rules = {}
self.prefix_search = None
self.suffix_search = None
self.infix_finditer = None
self.token_match = None
self.url_match = None
msg = util.from_bytes(bytes_data, deserializers, exclude) msg = util.from_bytes(bytes_data, deserializers, exclude)
if "prefix_search" in data and isinstance(data["prefix_search"], str): if "prefix_search" in data and isinstance(data["prefix_search"], str):
self.prefix_search = re.compile(data["prefix_search"]).search self.prefix_search = re.compile(data["prefix_search"]).search
@ -818,22 +810,12 @@ cdef class Tokenizer:
self.suffix_search = re.compile(data["suffix_search"]).search self.suffix_search = re.compile(data["suffix_search"]).search
if "infix_finditer" in data and isinstance(data["infix_finditer"], str): if "infix_finditer" in data and isinstance(data["infix_finditer"], str):
self.infix_finditer = re.compile(data["infix_finditer"]).finditer self.infix_finditer = re.compile(data["infix_finditer"]).finditer
# for token_match and url_match, set to None to override the language
# defaults if no regex is provided
if "token_match" in data and isinstance(data["token_match"], str): if "token_match" in data and isinstance(data["token_match"], str):
self.token_match = re.compile(data["token_match"]).match self.token_match = re.compile(data["token_match"]).match
else:
self.token_match = None
if "url_match" in data and isinstance(data["url_match"], str): if "url_match" in data and isinstance(data["url_match"], str):
self.url_match = re.compile(data["url_match"]).match self.url_match = re.compile(data["url_match"]).match
else:
self.url_match = None
if "rules" in data and isinstance(data["rules"], dict): if "rules" in data and isinstance(data["rules"], dict):
# make sure to hard reset the cache to remove data from the default exceptions self.rules = data["rules"]
self._rules = {}
self._flush_cache()
self._flush_specials()
self._load_special_cases(data["rules"])
return self return self