mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
Better handling of unexpected types in SetPredicate
(#11312)
* `Matcher`: Better type checking of values in `SetPredicate` `SetPredicate`: Emit warning and return `False` on unexpected value types * Rename `value_type_mismatch` variable * Inline warning * Remove unexpected type warning from `_SetPredicate` * Ensure that `str` values are not interpreted as sequences Check elements of sequence values for convertibility to `str` or `int` * Add more `INTERSECT` and `IN` test cases * Test for inputs with multiple characters * Return `False` early instead of using a boolean flag * Remove superfluous `int` check, parentheses * Apply suggestions from code review Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Appy suggestions from code review * Clarify test comment Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
parent
78f5503a29
commit
d1760ebe02
|
@ -1,5 +1,5 @@
|
|||
# cython: infer_types=True, cython: profile=True
|
||||
from typing import List
|
||||
from typing import List, Iterable
|
||||
|
||||
from libcpp.vector cimport vector
|
||||
from libc.stdint cimport int32_t, int8_t
|
||||
|
@ -867,20 +867,27 @@ class _SetPredicate:
|
|||
|
||||
def __call__(self, Token token):
|
||||
if self.is_extension:
|
||||
value = get_string_id(token._.get(self.attr))
|
||||
value = token._.get(self.attr)
|
||||
else:
|
||||
value = get_token_attr_for_matcher(token.c, self.attr)
|
||||
|
||||
if self.predicate in ("IS_SUBSET", "IS_SUPERSET", "INTERSECTS"):
|
||||
if self.predicate in ("IN", "NOT_IN"):
|
||||
if isinstance(value, (str, int)):
|
||||
value = get_string_id(value)
|
||||
else:
|
||||
return False
|
||||
elif self.predicate in ("IS_SUBSET", "IS_SUPERSET", "INTERSECTS"):
|
||||
# ensure that all values are enclosed in a set
|
||||
if self.attr == MORPH:
|
||||
# break up MORPH into individual Feat=Val values
|
||||
value = set(get_string_id(v) for v in MorphAnalysis.from_id(self.vocab, value))
|
||||
elif isinstance(value, (str, int)):
|
||||
value = set((get_string_id(value),))
|
||||
elif isinstance(value, Iterable) and all(isinstance(v, (str, int)) for v in value):
|
||||
value = set(get_string_id(v) for v in value)
|
||||
else:
|
||||
# treat a single value as a list
|
||||
if isinstance(value, (str, int)):
|
||||
value = set([get_string_id(value)])
|
||||
else:
|
||||
value = set(get_string_id(v) for v in value)
|
||||
return False
|
||||
|
||||
if self.predicate == "IN":
|
||||
return value in self.value
|
||||
elif self.predicate == "NOT_IN":
|
||||
|
|
|
@ -368,6 +368,16 @@ def test_matcher_intersect_value_operator(en_vocab):
|
|||
doc[0]._.ext = ["A", "B"]
|
||||
assert len(matcher(doc)) == 1
|
||||
|
||||
# INTERSECTS matches nothing for iterables that aren't all str or int
|
||||
matcher = Matcher(en_vocab)
|
||||
pattern = [{"_": {"ext": {"INTERSECTS": ["Abx", "C"]}}}]
|
||||
matcher.add("M", [pattern])
|
||||
doc = Doc(en_vocab, words=["a", "b", "c"])
|
||||
doc[0]._.ext = [["Abx"], "B"]
|
||||
assert len(matcher(doc)) == 0
|
||||
doc[0]._.ext = ["Abx", "B"]
|
||||
assert len(matcher(doc)) == 1
|
||||
|
||||
# INTERSECTS with an empty pattern list matches nothing
|
||||
matcher = Matcher(en_vocab)
|
||||
pattern = [{"_": {"ext": {"INTERSECTS": []}}}]
|
||||
|
@ -476,14 +486,22 @@ def test_matcher_extension_set_membership(en_vocab):
|
|||
assert len(matches) == 0
|
||||
|
||||
|
||||
@pytest.mark.xfail(reason="IN predicate must handle sequence values in extensions")
|
||||
def test_matcher_extension_in_set_predicate(en_vocab):
|
||||
matcher = Matcher(en_vocab)
|
||||
Token.set_extension("ext", default=[])
|
||||
pattern = [{"_": {"ext": {"IN": ["A", "C"]}}}]
|
||||
matcher.add("M", [pattern])
|
||||
doc = Doc(en_vocab, words=["a", "b", "c"])
|
||||
|
||||
# The IN predicate expects an exact match between the
|
||||
# extension value and one of the pattern's values.
|
||||
doc[0]._.ext = ["A", "B"]
|
||||
assert len(matcher(doc)) == 0
|
||||
|
||||
doc[0]._.ext = ["A"]
|
||||
assert len(matcher(doc)) == 0
|
||||
|
||||
doc[0]._.ext = "A"
|
||||
assert len(matcher(doc)) == 1
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user