mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 01:16:28 +03:00
Better handling of unexpected types in SetPredicate
(#11312)
* `Matcher`: Better type checking of values in `SetPredicate` `SetPredicate`: Emit warning and return `False` on unexpected value types * Rename `value_type_mismatch` variable * Inline warning * Remove unexpected type warning from `_SetPredicate` * Ensure that `str` values are not interpreted as sequences Check elements of sequence values for convertibility to `str` or `int` * Add more `INTERSECT` and `IN` test cases * Test for inputs with multiple characters * Return `False` early instead of using a boolean flag * Remove superfluous `int` check, parentheses * Apply suggestions from code review Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Apply suggestions from code review * Clarify test comment Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
parent
78f5503a29
commit
d1760ebe02
|
@ -1,5 +1,5 @@
|
||||||
# cython: infer_types=True, cython: profile=True
|
# cython: infer_types=True, cython: profile=True
|
||||||
from typing import List
|
from typing import List, Iterable
|
||||||
|
|
||||||
from libcpp.vector cimport vector
|
from libcpp.vector cimport vector
|
||||||
from libc.stdint cimport int32_t, int8_t
|
from libc.stdint cimport int32_t, int8_t
|
||||||
|
@ -867,20 +867,27 @@ class _SetPredicate:
|
||||||
|
|
||||||
def __call__(self, Token token):
|
def __call__(self, Token token):
|
||||||
if self.is_extension:
|
if self.is_extension:
|
||||||
value = get_string_id(token._.get(self.attr))
|
value = token._.get(self.attr)
|
||||||
else:
|
else:
|
||||||
value = get_token_attr_for_matcher(token.c, self.attr)
|
value = get_token_attr_for_matcher(token.c, self.attr)
|
||||||
|
|
||||||
if self.predicate in ("IS_SUBSET", "IS_SUPERSET", "INTERSECTS"):
|
if self.predicate in ("IN", "NOT_IN"):
|
||||||
|
if isinstance(value, (str, int)):
|
||||||
|
value = get_string_id(value)
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
elif self.predicate in ("IS_SUBSET", "IS_SUPERSET", "INTERSECTS"):
|
||||||
|
# ensure that all values are enclosed in a set
|
||||||
if self.attr == MORPH:
|
if self.attr == MORPH:
|
||||||
# break up MORPH into individual Feat=Val values
|
# break up MORPH into individual Feat=Val values
|
||||||
value = set(get_string_id(v) for v in MorphAnalysis.from_id(self.vocab, value))
|
value = set(get_string_id(v) for v in MorphAnalysis.from_id(self.vocab, value))
|
||||||
|
elif isinstance(value, (str, int)):
|
||||||
|
value = set((get_string_id(value),))
|
||||||
|
elif isinstance(value, Iterable) and all(isinstance(v, (str, int)) for v in value):
|
||||||
|
value = set(get_string_id(v) for v in value)
|
||||||
else:
|
else:
|
||||||
# treat a single value as a list
|
return False
|
||||||
if isinstance(value, (str, int)):
|
|
||||||
value = set([get_string_id(value)])
|
|
||||||
else:
|
|
||||||
value = set(get_string_id(v) for v in value)
|
|
||||||
if self.predicate == "IN":
|
if self.predicate == "IN":
|
||||||
return value in self.value
|
return value in self.value
|
||||||
elif self.predicate == "NOT_IN":
|
elif self.predicate == "NOT_IN":
|
||||||
|
|
|
@ -368,6 +368,16 @@ def test_matcher_intersect_value_operator(en_vocab):
|
||||||
doc[0]._.ext = ["A", "B"]
|
doc[0]._.ext = ["A", "B"]
|
||||||
assert len(matcher(doc)) == 1
|
assert len(matcher(doc)) == 1
|
||||||
|
|
||||||
|
# INTERSECTS matches nothing for iterables that aren't all str or int
|
||||||
|
matcher = Matcher(en_vocab)
|
||||||
|
pattern = [{"_": {"ext": {"INTERSECTS": ["Abx", "C"]}}}]
|
||||||
|
matcher.add("M", [pattern])
|
||||||
|
doc = Doc(en_vocab, words=["a", "b", "c"])
|
||||||
|
doc[0]._.ext = [["Abx"], "B"]
|
||||||
|
assert len(matcher(doc)) == 0
|
||||||
|
doc[0]._.ext = ["Abx", "B"]
|
||||||
|
assert len(matcher(doc)) == 1
|
||||||
|
|
||||||
# INTERSECTS with an empty pattern list matches nothing
|
# INTERSECTS with an empty pattern list matches nothing
|
||||||
matcher = Matcher(en_vocab)
|
matcher = Matcher(en_vocab)
|
||||||
pattern = [{"_": {"ext": {"INTERSECTS": []}}}]
|
pattern = [{"_": {"ext": {"INTERSECTS": []}}}]
|
||||||
|
@ -476,14 +486,22 @@ def test_matcher_extension_set_membership(en_vocab):
|
||||||
assert len(matches) == 0
|
assert len(matches) == 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.xfail(reason="IN predicate must handle sequence values in extensions")
|
|
||||||
def test_matcher_extension_in_set_predicate(en_vocab):
|
def test_matcher_extension_in_set_predicate(en_vocab):
|
||||||
matcher = Matcher(en_vocab)
|
matcher = Matcher(en_vocab)
|
||||||
Token.set_extension("ext", default=[])
|
Token.set_extension("ext", default=[])
|
||||||
pattern = [{"_": {"ext": {"IN": ["A", "C"]}}}]
|
pattern = [{"_": {"ext": {"IN": ["A", "C"]}}}]
|
||||||
matcher.add("M", [pattern])
|
matcher.add("M", [pattern])
|
||||||
doc = Doc(en_vocab, words=["a", "b", "c"])
|
doc = Doc(en_vocab, words=["a", "b", "c"])
|
||||||
|
|
||||||
|
# The IN predicate expects an exact match between the
|
||||||
|
# extension value and one of the pattern's values.
|
||||||
doc[0]._.ext = ["A", "B"]
|
doc[0]._.ext = ["A", "B"]
|
||||||
|
assert len(matcher(doc)) == 0
|
||||||
|
|
||||||
|
doc[0]._.ext = ["A"]
|
||||||
|
assert len(matcher(doc)) == 0
|
||||||
|
|
||||||
|
doc[0]._.ext = "A"
|
||||||
assert len(matcher(doc)) == 1
|
assert len(matcher(doc)) == 1
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user