Better handling of unexpected types in SetPredicate (#11312)

* `Matcher`: Better type checking of values in `SetPredicate`
`SetPredicate`: Emit warning and return `False` on unexpected value types

* Rename `value_type_mismatch` variable

* Inline warning

* Remove unexpected type warning from `_SetPredicate`

* Ensure that `str` values are not interpreted as sequences
Check elements of sequence values for convertibility to `str` or `int`

* Add more `INTERSECT` and `IN` test cases

* Test for inputs with multiple characters

* Return `False` early instead of using a boolean flag

* Remove superfluous `int` check, parentheses

* Apply suggestions from code review

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>

* Appy suggestions from code review

* Clarify test comment

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
Madeesh Kannan 2022-09-02 09:09:48 +02:00 committed by GitHub
parent 78f5503a29
commit d1760ebe02
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 34 additions and 9 deletions

View File

@ -1,5 +1,5 @@
# cython: infer_types=True, cython: profile=True # cython: infer_types=True, cython: profile=True
from typing import List from typing import List, Iterable
from libcpp.vector cimport vector from libcpp.vector cimport vector
from libc.stdint cimport int32_t, int8_t from libc.stdint cimport int32_t, int8_t
@ -867,20 +867,27 @@ class _SetPredicate:
def __call__(self, Token token): def __call__(self, Token token):
if self.is_extension: if self.is_extension:
value = get_string_id(token._.get(self.attr)) value = token._.get(self.attr)
else: else:
value = get_token_attr_for_matcher(token.c, self.attr) value = get_token_attr_for_matcher(token.c, self.attr)
if self.predicate in ("IS_SUBSET", "IS_SUPERSET", "INTERSECTS"): if self.predicate in ("IN", "NOT_IN"):
if isinstance(value, (str, int)):
value = get_string_id(value)
else:
return False
elif self.predicate in ("IS_SUBSET", "IS_SUPERSET", "INTERSECTS"):
# ensure that all values are enclosed in a set
if self.attr == MORPH: if self.attr == MORPH:
# break up MORPH into individual Feat=Val values # break up MORPH into individual Feat=Val values
value = set(get_string_id(v) for v in MorphAnalysis.from_id(self.vocab, value)) value = set(get_string_id(v) for v in MorphAnalysis.from_id(self.vocab, value))
else: elif isinstance(value, (str, int)):
# treat a single value as a list value = set((get_string_id(value),))
if isinstance(value, (str, int)): elif isinstance(value, Iterable) and all(isinstance(v, (str, int)) for v in value):
value = set([get_string_id(value)])
else:
value = set(get_string_id(v) for v in value) value = set(get_string_id(v) for v in value)
else:
return False
if self.predicate == "IN": if self.predicate == "IN":
return value in self.value return value in self.value
elif self.predicate == "NOT_IN": elif self.predicate == "NOT_IN":

View File

@ -368,6 +368,16 @@ def test_matcher_intersect_value_operator(en_vocab):
doc[0]._.ext = ["A", "B"] doc[0]._.ext = ["A", "B"]
assert len(matcher(doc)) == 1 assert len(matcher(doc)) == 1
# INTERSECTS matches nothing for iterables that aren't all str or int
matcher = Matcher(en_vocab)
pattern = [{"_": {"ext": {"INTERSECTS": ["Abx", "C"]}}}]
matcher.add("M", [pattern])
doc = Doc(en_vocab, words=["a", "b", "c"])
doc[0]._.ext = [["Abx"], "B"]
assert len(matcher(doc)) == 0
doc[0]._.ext = ["Abx", "B"]
assert len(matcher(doc)) == 1
# INTERSECTS with an empty pattern list matches nothing # INTERSECTS with an empty pattern list matches nothing
matcher = Matcher(en_vocab) matcher = Matcher(en_vocab)
pattern = [{"_": {"ext": {"INTERSECTS": []}}}] pattern = [{"_": {"ext": {"INTERSECTS": []}}}]
@ -476,14 +486,22 @@ def test_matcher_extension_set_membership(en_vocab):
assert len(matches) == 0 assert len(matches) == 0
@pytest.mark.xfail(reason="IN predicate must handle sequence values in extensions")
def test_matcher_extension_in_set_predicate(en_vocab): def test_matcher_extension_in_set_predicate(en_vocab):
matcher = Matcher(en_vocab) matcher = Matcher(en_vocab)
Token.set_extension("ext", default=[]) Token.set_extension("ext", default=[])
pattern = [{"_": {"ext": {"IN": ["A", "C"]}}}] pattern = [{"_": {"ext": {"IN": ["A", "C"]}}}]
matcher.add("M", [pattern]) matcher.add("M", [pattern])
doc = Doc(en_vocab, words=["a", "b", "c"]) doc = Doc(en_vocab, words=["a", "b", "c"])
# The IN predicate expects an exact match between the
# extension value and one of the pattern's values.
doc[0]._.ext = ["A", "B"] doc[0]._.ext = ["A", "B"]
assert len(matcher(doc)) == 0
doc[0]._.ext = ["A"]
assert len(matcher(doc)) == 0
doc[0]._.ext = "A"
assert len(matcher(doc)) == 1 assert len(matcher(doc)) == 1