From bf4b353ce50a39f73004f8b807ea0cf514b3e558 Mon Sep 17 00:00:00 2001 From: Kevin Humphreys Date: Wed, 28 Sep 2022 16:08:37 -0700 Subject: [PATCH] handle sets inside regex operator --- spacy/matcher/matcher.pyx | 79 ++++++++++++++++--------- spacy/schemas.py | 7 +-- spacy/tests/matcher/test_matcher_api.py | 24 ++++++++ 3 files changed, 75 insertions(+), 35 deletions(-) diff --git a/spacy/matcher/matcher.pyx b/spacy/matcher/matcher.pyx index c49bb92de..2e3dfd2da 100644 --- a/spacy/matcher/matcher.pyx +++ b/spacy/matcher/matcher.pyx @@ -843,7 +843,8 @@ def _get_attr_values(spec, string_store): class _FuzzyPredicate: operators = ("FUZZY", "FUZZY1", "FUZZY2", "FUZZY3", "FUZZY4", "FUZZY5") - def __init__(self, i, attr, value, predicate, is_extension=False, vocab=None, distance=None): + def __init__(self, i, attr, value, predicate, is_extension=False, vocab=None, + regex=None, fuzzy=None): self.i = i self.attr = attr self.value = value @@ -852,8 +853,8 @@ class _FuzzyPredicate: self.key = (attr, self.predicate, srsly.json_dumps(value, sort_keys=True)) if self.predicate not in self.operators: raise ValueError(Errors.E126.format(good=self.operators, bad=self.predicate)) - self.distance = self.predicate[len('FUZZY'):] # number after prefix - self.distance = int(self.distance) if self.distance else 0 + self.fuzzy = self.predicate[len('FUZZY'):] # number after prefix + self.fuzzy = int(self.fuzzy) if self.fuzzy else 0 def __call__(self, Token token): if self.is_extension: @@ -862,13 +863,14 @@ class _FuzzyPredicate: value = token.vocab.strings[get_token_attr_for_matcher(token.c, self.attr)] if self.value == value: return True - return Matcher.fuzzy_match(value, self.value, self.distance, token) + return Matcher.fuzzy_match(value, self.value, self.fuzzy, token) class _RegexPredicate: operators = ("REGEX",) - def __init__(self, i, attr, value, predicate, is_extension=False, vocab=None, distance=None): + def __init__(self, i, attr, value, predicate, is_extension=False, vocab=None, + regex=None, fuzzy=None): self.i = i self.attr = attr self.value = re.compile(value) @@ -889,16 +891,20 @@ class _RegexPredicate: class _SetPredicate: operators = ("IN", "NOT_IN", "IS_SUBSET", "IS_SUPERSET", "INTERSECTS") - def __init__(self, i, attr, value, predicate, is_extension=False, vocab=None, distance=None): + def __init__(self, i, attr, value, predicate, is_extension=False, vocab=None, + regex=None, fuzzy=None): self.i = i self.attr = attr self.vocab = vocab - self.distance = distance + self.regex = regex + self.fuzzy = fuzzy if self.attr == MORPH: # normalize morph strings self.value = set(self.vocab.morphology.add(v) for v in value) else: - if self.distance is not None: + if self.regex: + self.value = set(re.compile(v) for v in value) + elif self.fuzzy is not None: # add to string store self.value = set(self.vocab.strings.add(v) for v in value) else: @@ -933,23 +939,29 @@ class _SetPredicate: return False if self.predicate == "IN": - if value in self.value: + if self.regex: + value = self.vocab.strings[value] + return any(bool(v.search(value)) for v in self.value) + elif value in self.value: return True - elif self.distance is not None: - s1 = self.vocab.strings[value] - for v in self.value: - if Matcher.fuzzy_match(s1, self.vocab.strings[v], self.distance, token): - return True - return False - elif self.predicate == "NOT_IN": - if value in self.value: + elif self.fuzzy is not None: + value = self.vocab.strings[value] + return any(Matcher.fuzzy_match(value, self.vocab.strings[v], self.fuzzy, token) + for v in self.value) + else: return False - elif self.distance is not None: - s1 = self.vocab.strings[value] - for v in self.value: - if Matcher.fuzzy_match(s1, self.vocab.strings[v], self.distance, token): - return False - return True + elif self.predicate == "NOT_IN": + if self.regex: + value = self.vocab.strings[value] + return not any(bool(v.search(value)) for v in self.value) + elif value in self.value: + return False + elif self.fuzzy is not None: + value = self.vocab.strings[value] + return not any(Matcher.fuzzy_match(value, self.vocab.strings[v], self.fuzzy, token) + for v in self.value) + else: + return True elif self.predicate == "IS_SUBSET": return value <= self.value elif self.predicate == "IS_SUPERSET": @@ -964,7 +976,8 @@ class _SetPredicate: class _ComparisonPredicate: operators = ("==", "!=", ">=", "<=", ">", "<") - def __init__(self, i, attr, value, predicate, is_extension=False, vocab=None, distance=None): + def __init__(self, i, attr, value, predicate, is_extension=False, vocab=None, + regex=None, fuzzy=None): self.i = i self.attr = attr self.value = value @@ -1036,7 +1049,7 @@ def _get_extra_predicates(spec, extra_predicates, vocab): def _get_extra_predicates_dict(attr, value_dict, vocab, predicate_types, - extra_predicates, seen_predicates, distance=None): + extra_predicates, seen_predicates, regex=None, fuzzy=None): output = [] for type_, value in value_dict.items(): type_ = type_.upper() @@ -1045,16 +1058,24 @@ def _get_extra_predicates_dict(attr, value_dict, vocab, predicate_types, warnings.warn(Warnings.W035.format(pattern=value_dict)) # ignore unrecognized predicate type continue + elif cls == _RegexPredicate: + if isinstance(value, dict): + # add predicates inside regex operator + output.extend(_get_extra_predicates_dict(attr, value, vocab, predicate_types, + extra_predicates, seen_predicates, + regex=True)) + continue elif cls == _FuzzyPredicate: - distance = type_[len("FUZZY"):] # number after prefix - distance = int(distance) if distance else 0 + fuzzy = type_[len("FUZZY"):] # number after prefix + fuzzy = int(fuzzy) if fuzzy else 0 if isinstance(value, dict): # add predicates inside fuzzy operator output.extend(_get_extra_predicates_dict(attr, value, vocab, predicate_types, extra_predicates, seen_predicates, - distance=distance)) + fuzzy=fuzzy)) continue - predicate = cls(len(extra_predicates), attr, value, type_, vocab=vocab, distance=distance) + predicate = cls(len(extra_predicates), attr, value, type_, vocab=vocab, + regex=regex, fuzzy=fuzzy) # Don't create a redundant predicates. # This helps with efficiency, as we're caching the results. if predicate.key in seen_predicates: diff --git a/spacy/schemas.py b/spacy/schemas.py index f2be4428b..cc2ce792b 100644 --- a/spacy/schemas.py +++ b/spacy/schemas.py @@ -156,7 +156,7 @@ def validate_token_pattern(obj: list) -> List[str]: class TokenPatternString(BaseModel): - REGEX: Optional[StrictStr] = Field(None, alias="regex") + REGEX: Union[StrictStr, "TokenPatternString"] = Field(None, alias="regex") IN: Optional[List[StrictStr]] = Field(None, alias="in") NOT_IN: Optional[List[StrictStr]] = Field(None, alias="not_in") IS_SUBSET: Optional[List[StrictStr]] = Field(None, alias="is_subset") @@ -193,11 +193,6 @@ class TokenPatternNumber(BaseModel): LEQ: Union[StrictInt, StrictFloat] = Field(None, alias="<=") GT: Union[StrictInt, StrictFloat] = Field(None, alias=">") LT: Union[StrictInt, StrictFloat] = Field(None, alias="<") - FUZZY1: Optional[StrictStr] = Field(None, alias="fuzzy1") - FUZZY2: Optional[StrictStr] = Field(None, alias="fuzzy2") - FUZZY3: Optional[StrictStr] = Field(None, alias="fuzzy3") - FUZZY4: Optional[StrictStr] = Field(None, alias="fuzzy4") - FUZZY5: Optional[StrictStr] = Field(None, alias="fuzzy5") class Config: extra = "forbid" diff --git a/spacy/tests/matcher/test_matcher_api.py b/spacy/tests/matcher/test_matcher_api.py index 3c50489b8..2fbdc7a4f 100644 --- a/spacy/tests/matcher/test_matcher_api.py +++ b/spacy/tests/matcher/test_matcher_api.py @@ -629,6 +629,30 @@ def test_matcher_regex(en_vocab): assert len(matches) == 0 +def test_matcher_regex_set_in(en_vocab): + matcher = Matcher(en_vocab) + pattern = [{"ORTH": {"REGEX": {"IN": [r"(?:a)", r"(?:an)"]}}}] + matcher.add("A_OR_AN", [pattern]) + doc = Doc(en_vocab, words=["an", "a", "hi"]) + matches = matcher(doc) + assert len(matches) == 2 + doc = Doc(en_vocab, words=["bye"]) + matches = matcher(doc) + assert len(matches) == 0 + + +def test_matcher_regex_set_not_in(en_vocab): + matcher = Matcher(en_vocab) + pattern = [{"ORTH": {"REGEX": {"NOT_IN": [r"(?:a)", r"(?:an)"]}}}] + matcher.add("A_OR_AN", [pattern]) + doc = Doc(en_vocab, words=["an", "a", "hi"]) + matches = matcher(doc) + assert len(matches) == 1 + doc = Doc(en_vocab, words=["bye"]) + matches = matcher(doc) + assert len(matches) == 1 + + def test_matcher_regex_shape(en_vocab): matcher = Matcher(en_vocab) pattern = [{"SHAPE": {"REGEX": r"^[^x]+$"}}]