From 0e2c28446d1f2ebb7dce167f45e5f7c2eebc3b52 Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Thu, 1 Dec 2022 17:52:42 +0100 Subject: [PATCH] Fix predicate keys and matching for SetPredicate with FUZZY and REGEX --- spacy/matcher/matcher.pyx | 20 +++++++++++--------- spacy/tests/matcher/test_matcher_api.py | 9 +++++++++ 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/spacy/matcher/matcher.pyx b/spacy/matcher/matcher.pyx index ab04d5cda..2e41d24a5 100644 --- a/spacy/matcher/matcher.pyx +++ b/spacy/matcher/matcher.pyx @@ -58,7 +58,9 @@ cdef class Matcher: """Create the Matcher. vocab (Vocab): The vocabulary object, which must be shared with the - documents the matcher will operate on. + validate (bool): Validate all patterns added to this matcher. + fuzzy_compare (Callable[[str, str, int], bool]): The comparison method + for the FUZZY operators. """ self._extra_predicates = [] self._patterns = {} @@ -855,12 +857,12 @@ class _FuzzyPredicate: self.value = value self.predicate = predicate self.is_extension = is_extension - self.key = (attr, self.predicate, srsly.json_dumps(value, sort_keys=True)) if self.predicate not in self.operators: raise ValueError(Errors.E126.format(good=self.operators, bad=self.predicate)) fuzz = self.predicate[len("FUZZY"):] # number after prefix self.fuzzy = int(fuzz) if fuzz else -1 self.fuzzy_compare = fuzzy_compare + self.key = (self.attr, self.fuzzy, self.predicate, srsly.json_dumps(value, sort_keys=True)) def __call__(self, Token token): if self.is_extension: @@ -882,7 +884,7 @@ class _RegexPredicate: self.value = re.compile(value) self.predicate = predicate self.is_extension = is_extension - self.key = (attr, self.predicate, srsly.json_dumps(value, sort_keys=True)) + self.key = (self.attr, self.predicate, srsly.json_dumps(value, sort_keys=True)) if self.predicate not in self.operators: raise ValueError(Errors.E126.format(good=self.operators, bad=self.predicate)) @@ -918,7 +920,7 @@ class _SetPredicate: self.value = set(get_string_id(v) for v in value) self.predicate = predicate self.is_extension = is_extension - self.key = (attr, self.predicate, srsly.json_dumps(value, sort_keys=True)) + self.key = (self.attr, self.regex, self.fuzzy, self.predicate, srsly.json_dumps(value, sort_keys=True)) if self.predicate not in self.operators: raise ValueError(Errors.E126.format(good=self.operators, bad=self.predicate)) @@ -949,24 +951,24 @@ class _SetPredicate: if self.regex: value = self.vocab.strings[value] return any(bool(v.search(value)) for v in self.value) - elif value in self.value: - return True elif self.fuzzy is not None: value = self.vocab.strings[value] return any(self.fuzzy_compare(value, self.vocab.strings[v], self.fuzzy) for v in self.value) + elif value in self.value: + return True else: return False elif self.predicate == "NOT_IN": if self.regex: value = self.vocab.strings[value] return not any(bool(v.search(value)) for v in self.value) - elif value in self.value: - return False elif self.fuzzy is not None: value = self.vocab.strings[value] return not any(self.fuzzy_compare(value, self.vocab.strings[v], self.fuzzy) for v in self.value) + elif value in self.value: + return False else: return True elif self.predicate == "IS_SUBSET": @@ -990,7 +992,7 @@ class _ComparisonPredicate: self.value = value self.predicate = predicate self.is_extension = is_extension - self.key = (attr, self.predicate, srsly.json_dumps(value, sort_keys=True)) + self.key = (self.attr, self.predicate, srsly.json_dumps(value, sort_keys=True)) if self.predicate not in self.operators: raise ValueError(Errors.E126.format(good=self.operators, bad=self.predicate)) diff --git a/spacy/tests/matcher/test_matcher_api.py b/spacy/tests/matcher/test_matcher_api.py index 5daedff10..87b0f32e3 100644 --- a/spacy/tests/matcher/test_matcher_api.py +++ b/spacy/tests/matcher/test_matcher_api.py @@ -141,6 +141,15 @@ def test_matcher_match_multi(matcher): }, [(2, 4), (5, 6), (8, 9)], ), + # only the second pattern matches (check that predicate keys used for + # caching don't collide) + ( + { + "A": [[{"ORTH": {"FUZZY": "Javascript"}}]], + "B": [[{"ORTH": {"FUZZY5": "Javascript"}}]], + }, + [(8, 9)], + ), ], ) def test_matcher_match_fuzzy(en_vocab, rules, match_locs):