Fix predicate keys and matching for SetPredicate with FUZZY and REGEX

This commit is contained in:
Adriane Boyd 2022-12-01 17:52:42 +01:00
parent 8a749fccbc
commit 0e2c28446d
2 changed files with 20 additions and 9 deletions

View File

@ -58,7 +58,9 @@ cdef class Matcher:
"""Create the Matcher.
vocab (Vocab): The vocabulary object, which must be shared with the
documents the matcher will operate on.
validate (bool): Validate all patterns added to this matcher.
fuzzy_compare (Callable[[str, str, int], bool]): The comparison method
for the FUZZY operators.
"""
self._extra_predicates = []
self._patterns = {}
@ -855,12 +857,12 @@ class _FuzzyPredicate:
self.value = value
self.predicate = predicate
self.is_extension = is_extension
self.key = (attr, self.predicate, srsly.json_dumps(value, sort_keys=True))
if self.predicate not in self.operators:
raise ValueError(Errors.E126.format(good=self.operators, bad=self.predicate))
fuzz = self.predicate[len("FUZZY"):] # number after prefix
self.fuzzy = int(fuzz) if fuzz else -1
self.fuzzy_compare = fuzzy_compare
self.key = (self.attr, self.fuzzy, self.predicate, srsly.json_dumps(value, sort_keys=True))
def __call__(self, Token token):
if self.is_extension:
@ -882,7 +884,7 @@ class _RegexPredicate:
self.value = re.compile(value)
self.predicate = predicate
self.is_extension = is_extension
self.key = (attr, self.predicate, srsly.json_dumps(value, sort_keys=True))
self.key = (self.attr, self.predicate, srsly.json_dumps(value, sort_keys=True))
if self.predicate not in self.operators:
raise ValueError(Errors.E126.format(good=self.operators, bad=self.predicate))
@ -918,7 +920,7 @@ class _SetPredicate:
self.value = set(get_string_id(v) for v in value)
self.predicate = predicate
self.is_extension = is_extension
self.key = (attr, self.predicate, srsly.json_dumps(value, sort_keys=True))
self.key = (self.attr, self.regex, self.fuzzy, self.predicate, srsly.json_dumps(value, sort_keys=True))
if self.predicate not in self.operators:
raise ValueError(Errors.E126.format(good=self.operators, bad=self.predicate))
@ -949,24 +951,24 @@ class _SetPredicate:
if self.regex:
value = self.vocab.strings[value]
return any(bool(v.search(value)) for v in self.value)
elif value in self.value:
return True
elif self.fuzzy is not None:
value = self.vocab.strings[value]
return any(self.fuzzy_compare(value, self.vocab.strings[v], self.fuzzy)
for v in self.value)
elif value in self.value:
return True
else:
return False
elif self.predicate == "NOT_IN":
if self.regex:
value = self.vocab.strings[value]
return not any(bool(v.search(value)) for v in self.value)
elif value in self.value:
return False
elif self.fuzzy is not None:
value = self.vocab.strings[value]
return not any(self.fuzzy_compare(value, self.vocab.strings[v], self.fuzzy)
for v in self.value)
elif value in self.value:
return False
else:
return True
elif self.predicate == "IS_SUBSET":
@ -990,7 +992,7 @@ class _ComparisonPredicate:
self.value = value
self.predicate = predicate
self.is_extension = is_extension
self.key = (attr, self.predicate, srsly.json_dumps(value, sort_keys=True))
self.key = (self.attr, self.predicate, srsly.json_dumps(value, sort_keys=True))
if self.predicate not in self.operators:
raise ValueError(Errors.E126.format(good=self.operators, bad=self.predicate))

View File

@ -141,6 +141,15 @@ def test_matcher_match_multi(matcher):
},
[(2, 4), (5, 6), (8, 9)],
),
# only the second pattern matches (check that predicate keys used for
# caching don't collide)
(
{
"A": [[{"ORTH": {"FUZZY": "Javascript"}}]],
"B": [[{"ORTH": {"FUZZY5": "Javascript"}}]],
},
[(8, 9)],
),
],
)
def test_matcher_match_fuzzy(en_vocab, rules, match_locs):