handle sets inside regex operator

Kevin Humphreys 2022-09-28 16:08:37 -07:00
parent 0da324ab5b
commit bf4b353ce5
3 changed files with 75 additions and 35 deletions
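
The change lets the REGEX operator wrap a set predicate (IN / NOT_IN) instead of only a single pattern string, so a token attribute can be tested against several regexes at once. A minimal usage sketch of the pattern syntax this enables, assuming a build that includes this commit:

    from spacy.lang.en import English
    from spacy.matcher import Matcher

    nlp = English()
    matcher = Matcher(nlp.vocab)
    # ORTH must match at least one regex in the set
    matcher.add("ARTICLE", [[{"ORTH": {"REGEX": {"IN": [r"^a$", r"^an$"]}}}]])
    doc = nlp("an apple a day")
    print([doc[start:end].text for _, start, end in matcher(doc)])  # ['an', 'a']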


@@ -843,7 +843,8 @@ def _get_attr_values(spec, string_store):
 class _FuzzyPredicate:
     operators = ("FUZZY", "FUZZY1", "FUZZY2", "FUZZY3", "FUZZY4", "FUZZY5")

-    def __init__(self, i, attr, value, predicate, is_extension=False, vocab=None, distance=None):
+    def __init__(self, i, attr, value, predicate, is_extension=False, vocab=None,
+                 regex=None, fuzzy=None):
         self.i = i
         self.attr = attr
         self.value = value
@@ -852,8 +853,8 @@ class _FuzzyPredicate:
         self.key = (attr, self.predicate, srsly.json_dumps(value, sort_keys=True))
         if self.predicate not in self.operators:
             raise ValueError(Errors.E126.format(good=self.operators, bad=self.predicate))
-        self.distance = self.predicate[len('FUZZY'):] # number after prefix
-        self.distance = int(self.distance) if self.distance else 0
+        self.fuzzy = self.predicate[len('FUZZY'):] # number after prefix
+        self.fuzzy = int(self.fuzzy) if self.fuzzy else 0

     def __call__(self, Token token):
         if self.is_extension:
@@ -862,13 +863,14 @@ class _FuzzyPredicate:
         value = token.vocab.strings[get_token_attr_for_matcher(token.c, self.attr)]
         if self.value == value:
             return True
-        return Matcher.fuzzy_match(value, self.value, self.distance, token)
+        return Matcher.fuzzy_match(value, self.value, self.fuzzy, token)


 class _RegexPredicate:
     operators = ("REGEX",)

-    def __init__(self, i, attr, value, predicate, is_extension=False, vocab=None, distance=None):
+    def __init__(self, i, attr, value, predicate, is_extension=False, vocab=None,
+                 regex=None, fuzzy=None):
         self.i = i
         self.attr = attr
         self.value = re.compile(value)
@@ -889,16 +891,20 @@ class _RegexPredicate:
 class _SetPredicate:
     operators = ("IN", "NOT_IN", "IS_SUBSET", "IS_SUPERSET", "INTERSECTS")

-    def __init__(self, i, attr, value, predicate, is_extension=False, vocab=None, distance=None):
+    def __init__(self, i, attr, value, predicate, is_extension=False, vocab=None,
+                 regex=None, fuzzy=None):
         self.i = i
         self.attr = attr
         self.vocab = vocab
-        self.distance = distance
+        self.regex = regex
+        self.fuzzy = fuzzy
         if self.attr == MORPH:
             # normalize morph strings
             self.value = set(self.vocab.morphology.add(v) for v in value)
         else:
-            if self.distance is not None:
+            if self.regex:
+                self.value = set(re.compile(v) for v in value)
+            elif self.fuzzy is not None:
                 # add to string store
                 self.value = set(self.vocab.strings.add(v) for v in value)
             else:
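
Two details of the new constructor logic are easy to miss: regex sets are compiled once when the pattern is added rather than on every token, and the two flags are checked differently (self.regex by truthiness, self.fuzzy against None, since a fuzzy distance of 0 is a meaningful value). A standalone sketch contrasting what ends up in self.value in each mode (illustrative only, not the module's code):

    import re
    from spacy.vocab import Vocab

    vocab = Vocab()
    # plain and fuzzy sets hold string-store hashes ...
    hashes = {vocab.strings.add(v) for v in ["a", "an"]}
    # ... while regex sets hold pre-compiled pattern objects
    compiled = {re.compile(v) for v in [r"^a$", r"^an$"]}
    print({type(x).__name__ for x in hashes | compiled})  # {'Pattern', 'int'}
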
@@ -933,23 +939,29 @@ class _SetPredicate:
                 return False

         if self.predicate == "IN":
-            if value in self.value:
+            if self.regex:
+                value = self.vocab.strings[value]
+                return any(bool(v.search(value)) for v in self.value)
+            elif value in self.value:
                 return True
-            elif self.distance is not None:
-                s1 = self.vocab.strings[value]
-                for v in self.value:
-                    if Matcher.fuzzy_match(s1, self.vocab.strings[v], self.distance, token):
-                        return True
-            return False
+            elif self.fuzzy is not None:
+                value = self.vocab.strings[value]
+                return any(Matcher.fuzzy_match(value, self.vocab.strings[v], self.fuzzy, token)
+                           for v in self.value)
+            else:
+                return False
         elif self.predicate == "NOT_IN":
-            if value in self.value:
+            if self.regex:
+                value = self.vocab.strings[value]
+                return not any(bool(v.search(value)) for v in self.value)
+            elif value in self.value:
                 return False
-            elif self.distance is not None:
-                s1 = self.vocab.strings[value]
-                for v in self.value:
-                    if Matcher.fuzzy_match(s1, self.vocab.strings[v], self.distance, token):
-                        return False
-            return True
+            elif self.fuzzy is not None:
+                value = self.vocab.strings[value]
+                return not any(Matcher.fuzzy_match(value, self.vocab.strings[v], self.fuzzy, token)
+                               for v in self.value)
+            else:
+                return True
         elif self.predicate == "IS_SUBSET":
             return value <= self.value
         elif self.predicate == "IS_SUPERSET":
@@ -964,7 +976,8 @@ class _SetPredicate:
 class _ComparisonPredicate:
     operators = ("==", "!=", ">=", "<=", ">", "<")

-    def __init__(self, i, attr, value, predicate, is_extension=False, vocab=None, distance=None):
+    def __init__(self, i, attr, value, predicate, is_extension=False, vocab=None,
+                 regex=None, fuzzy=None):
         self.i = i
         self.attr = attr
         self.value = value
@@ -1036,7 +1049,7 @@ def _get_extra_predicates(spec, extra_predicates, vocab):
 def _get_extra_predicates_dict(attr, value_dict, vocab, predicate_types,
-                               extra_predicates, seen_predicates, distance=None):
+                               extra_predicates, seen_predicates, regex=None, fuzzy=None):
     output = []
     for type_, value in value_dict.items():
         type_ = type_.upper()
@@ -1045,16 +1058,24 @@ def _get_extra_predicates_dict(attr, value_dict, vocab, predicate_types,
             warnings.warn(Warnings.W035.format(pattern=value_dict))
             # ignore unrecognized predicate type
             continue
+        elif cls == _RegexPredicate:
+            if isinstance(value, dict):
+                # add predicates inside regex operator
+                output.extend(_get_extra_predicates_dict(attr, value, vocab, predicate_types,
+                                                         extra_predicates, seen_predicates,
+                                                         regex=True))
+                continue
         elif cls == _FuzzyPredicate:
-            distance = type_[len("FUZZY"):] # number after prefix
-            distance = int(distance) if distance else 0
+            fuzzy = type_[len("FUZZY"):] # number after prefix
+            fuzzy = int(fuzzy) if fuzzy else 0
             if isinstance(value, dict):
                 # add predicates inside fuzzy operator
                 output.extend(_get_extra_predicates_dict(attr, value, vocab, predicate_types,
                                                          extra_predicates, seen_predicates,
-                                                         distance=distance))
+                                                         fuzzy=fuzzy))
                 continue
-        predicate = cls(len(extra_predicates), attr, value, type_, vocab=vocab, distance=distance)
+        predicate = cls(len(extra_predicates), attr, value, type_, vocab=vocab,
+                        regex=regex, fuzzy=fuzzy)
         # Don't create redundant predicates.
         # This helps with efficiency, as we're caching the results.
         if predicate.key in seen_predicates:
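
The dispatcher now recurses when the value under REGEX is itself a dict, re-entering _get_extra_predicates_dict with regex=True so the nested IN / NOT_IN is built as a _SetPredicate carrying the regex flag, mirroring the existing recursion for the fuzzy operators. A simplified standalone sketch of that walk (not the module's actual code):

    def walk(value_dict, regex=False, fuzzy=None):
        # flatten nested REGEX / FUZZYn operators into (op, value, flags) tuples
        preds = []
        for op, val in value_dict.items():
            op = op.upper()
            if op == "REGEX" and isinstance(val, dict):
                preds.extend(walk(val, regex=True))
            elif op.startswith("FUZZY") and isinstance(val, dict):
                preds.extend(walk(val, fuzzy=int(op[len("FUZZY"):] or 0)))
            else:
                preds.append((op, val, regex, fuzzy))
        return preds

    print(walk({"REGEX": {"IN": [r"^a$", r"^an$"]}}))
    # [('IN', ['^a$', '^an$'], True, None)]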


@@ -156,7 +156,7 @@ def validate_token_pattern(obj: list) -> List[str]:
 class TokenPatternString(BaseModel):
-    REGEX: Optional[StrictStr] = Field(None, alias="regex")
+    REGEX: Union[StrictStr, "TokenPatternString"] = Field(None, alias="regex")
     IN: Optional[List[StrictStr]] = Field(None, alias="in")
     NOT_IN: Optional[List[StrictStr]] = Field(None, alias="not_in")
     IS_SUBSET: Optional[List[StrictStr]] = Field(None, alias="is_subset")
@@ -193,11 +193,6 @@ class TokenPatternNumber(BaseModel):
     LEQ: Union[StrictInt, StrictFloat] = Field(None, alias="<=")
     GT: Union[StrictInt, StrictFloat] = Field(None, alias=">")
     LT: Union[StrictInt, StrictFloat] = Field(None, alias="<")
-    FUZZY1: Optional[StrictStr] = Field(None, alias="fuzzy1")
-    FUZZY2: Optional[StrictStr] = Field(None, alias="fuzzy2")
-    FUZZY3: Optional[StrictStr] = Field(None, alias="fuzzy3")
-    FUZZY4: Optional[StrictStr] = Field(None, alias="fuzzy4")
-    FUZZY5: Optional[StrictStr] = Field(None, alias="fuzzy5")

     class Config:
         extra = "forbid"
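
On the schema side, REGEX now accepts either a string or a nested TokenPatternString, i.e. the model refers to itself, and the FUZZY1-FUZZY5 string fields no longer belong on the numeric pattern model. A minimal self-referential model in the same pydantic v1 style spaCy uses (the model name here is illustrative):

    from typing import List, Optional, Union
    from pydantic import BaseModel, Field, StrictStr

    class PatternString(BaseModel):
        REGEX: Optional[Union[StrictStr, "PatternString"]] = Field(None, alias="regex")
        IN: Optional[List[StrictStr]] = Field(None, alias="in")

    PatternString.update_forward_refs()  # resolve the self-reference

    PatternString(**{"regex": "^a$"})                    # plain string still validates
    PatternString(**{"regex": {"in": ["^a$", "^an$"]}})  # nested set now validates too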


@@ -629,6 +629,30 @@ def test_matcher_regex(en_vocab):
     assert len(matches) == 0


+def test_matcher_regex_set_in(en_vocab):
+    matcher = Matcher(en_vocab)
+    pattern = [{"ORTH": {"REGEX": {"IN": [r"(?:a)", r"(?:an)"]}}}]
+    matcher.add("A_OR_AN", [pattern])
+    doc = Doc(en_vocab, words=["an", "a", "hi"])
+    matches = matcher(doc)
+    assert len(matches) == 2
+    doc = Doc(en_vocab, words=["bye"])
+    matches = matcher(doc)
+    assert len(matches) == 0
+
+
+def test_matcher_regex_set_not_in(en_vocab):
+    matcher = Matcher(en_vocab)
+    pattern = [{"ORTH": {"REGEX": {"NOT_IN": [r"(?:a)", r"(?:an)"]}}}]
+    matcher.add("A_OR_AN", [pattern])
+    doc = Doc(en_vocab, words=["an", "a", "hi"])
+    matches = matcher(doc)
+    assert len(matches) == 1
+    doc = Doc(en_vocab, words=["bye"])
+    matches = matcher(doc)
+    assert len(matches) == 1
+
+
 def test_matcher_regex_shape(en_vocab):
     matcher = Matcher(en_vocab)
     pattern = [{"SHAPE": {"REGEX": r"^[^x]+$"}}]