mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-11 12:18:04 +03:00
348d1829c7
* Preserve user data for DependencyMatcher on spans * Clean underscore in test * Modify test to use extensions stored in user data
372 lines
11 KiB
Python
372 lines
11 KiB
Python
import pytest
|
|
import pickle
|
|
import re
|
|
import copy
|
|
from mock import Mock
|
|
from spacy.matcher import DependencyMatcher
|
|
from spacy.tokens import Doc, Token
|
|
|
|
from ..doc.test_underscore import clean_underscore # noqa: F401
|
|
|
|
|
|
@pytest.fixture
|
|
def doc(en_vocab):
|
|
words = ["The", "quick", "brown", "fox", "jumped", "over", "the", "lazy", "fox"]
|
|
heads = [3, 3, 3, 4, 4, 4, 8, 8, 5]
|
|
deps = ["det", "amod", "amod", "nsubj", "ROOT", "prep", "pobj", "det", "amod"]
|
|
return Doc(en_vocab, words=words, heads=heads, deps=deps)
|
|
|
|
|
|
@pytest.fixture
|
|
def patterns(en_vocab):
|
|
def is_brown_yellow(text):
|
|
return bool(re.compile(r"brown|yellow").match(text))
|
|
|
|
IS_BROWN_YELLOW = en_vocab.add_flag(is_brown_yellow)
|
|
|
|
pattern1 = [
|
|
{"RIGHT_ID": "fox", "RIGHT_ATTRS": {"ORTH": "fox"}},
|
|
{
|
|
"LEFT_ID": "fox",
|
|
"REL_OP": ">",
|
|
"RIGHT_ID": "q",
|
|
"RIGHT_ATTRS": {"ORTH": "quick", "DEP": "amod"},
|
|
},
|
|
{
|
|
"LEFT_ID": "fox",
|
|
"REL_OP": ">",
|
|
"RIGHT_ID": "r",
|
|
"RIGHT_ATTRS": {IS_BROWN_YELLOW: True},
|
|
},
|
|
]
|
|
|
|
pattern2 = [
|
|
{"RIGHT_ID": "jumped", "RIGHT_ATTRS": {"ORTH": "jumped"}},
|
|
{
|
|
"LEFT_ID": "jumped",
|
|
"REL_OP": ">",
|
|
"RIGHT_ID": "fox1",
|
|
"RIGHT_ATTRS": {"ORTH": "fox"},
|
|
},
|
|
{
|
|
"LEFT_ID": "jumped",
|
|
"REL_OP": ".",
|
|
"RIGHT_ID": "over",
|
|
"RIGHT_ATTRS": {"ORTH": "over"},
|
|
},
|
|
]
|
|
|
|
pattern3 = [
|
|
{"RIGHT_ID": "jumped", "RIGHT_ATTRS": {"ORTH": "jumped"}},
|
|
{
|
|
"LEFT_ID": "jumped",
|
|
"REL_OP": ">",
|
|
"RIGHT_ID": "fox",
|
|
"RIGHT_ATTRS": {"ORTH": "fox"},
|
|
},
|
|
{
|
|
"LEFT_ID": "fox",
|
|
"REL_OP": ">>",
|
|
"RIGHT_ID": "r",
|
|
"RIGHT_ATTRS": {"ORTH": "brown"},
|
|
},
|
|
]
|
|
|
|
pattern4 = [
|
|
{"RIGHT_ID": "jumped", "RIGHT_ATTRS": {"ORTH": "jumped"}},
|
|
{
|
|
"LEFT_ID": "jumped",
|
|
"REL_OP": ">",
|
|
"RIGHT_ID": "fox",
|
|
"RIGHT_ATTRS": {"ORTH": "fox"},
|
|
},
|
|
]
|
|
|
|
pattern5 = [
|
|
{"RIGHT_ID": "jumped", "RIGHT_ATTRS": {"ORTH": "jumped"}},
|
|
{
|
|
"LEFT_ID": "jumped",
|
|
"REL_OP": ">>",
|
|
"RIGHT_ID": "fox",
|
|
"RIGHT_ATTRS": {"ORTH": "fox"},
|
|
},
|
|
]
|
|
|
|
return [pattern1, pattern2, pattern3, pattern4, pattern5]
|
|
|
|
|
|
@pytest.fixture
|
|
def dependency_matcher(en_vocab, patterns, doc):
|
|
matcher = DependencyMatcher(en_vocab)
|
|
mock = Mock()
|
|
for i in range(1, len(patterns) + 1):
|
|
if i == 1:
|
|
matcher.add("pattern1", [patterns[0]], on_match=mock)
|
|
else:
|
|
matcher.add("pattern" + str(i), [patterns[i - 1]])
|
|
|
|
return matcher
|
|
|
|
|
|
def test_dependency_matcher(dependency_matcher, doc, patterns):
|
|
assert len(dependency_matcher) == 5
|
|
assert "pattern3" in dependency_matcher
|
|
assert dependency_matcher.get("pattern3") == (None, [patterns[2]])
|
|
matches = dependency_matcher(doc)
|
|
assert len(matches) == 6
|
|
assert matches[0][1] == [3, 1, 2]
|
|
assert matches[1][1] == [4, 3, 5]
|
|
assert matches[2][1] == [4, 3, 2]
|
|
assert matches[3][1] == [4, 3]
|
|
assert matches[4][1] == [4, 3]
|
|
assert matches[5][1] == [4, 8]
|
|
|
|
span = doc[0:6]
|
|
matches = dependency_matcher(span)
|
|
assert len(matches) == 5
|
|
assert matches[0][1] == [3, 1, 2]
|
|
assert matches[1][1] == [4, 3, 5]
|
|
assert matches[2][1] == [4, 3, 2]
|
|
assert matches[3][1] == [4, 3]
|
|
assert matches[4][1] == [4, 3]
|
|
|
|
|
|
def test_dependency_matcher_pickle(en_vocab, patterns, doc):
|
|
matcher = DependencyMatcher(en_vocab)
|
|
for i in range(1, len(patterns) + 1):
|
|
matcher.add("pattern" + str(i), [patterns[i - 1]])
|
|
|
|
matches = matcher(doc)
|
|
assert matches[0][1] == [3, 1, 2]
|
|
assert matches[1][1] == [4, 3, 5]
|
|
assert matches[2][1] == [4, 3, 2]
|
|
assert matches[3][1] == [4, 3]
|
|
assert matches[4][1] == [4, 3]
|
|
assert matches[5][1] == [4, 8]
|
|
|
|
b = pickle.dumps(matcher)
|
|
matcher_r = pickle.loads(b)
|
|
|
|
assert len(matcher) == len(matcher_r)
|
|
matches = matcher_r(doc)
|
|
assert matches[0][1] == [3, 1, 2]
|
|
assert matches[1][1] == [4, 3, 5]
|
|
assert matches[2][1] == [4, 3, 2]
|
|
assert matches[3][1] == [4, 3]
|
|
assert matches[4][1] == [4, 3]
|
|
assert matches[5][1] == [4, 8]
|
|
|
|
|
|
def test_dependency_matcher_pattern_validation(en_vocab):
|
|
pattern = [
|
|
{"RIGHT_ID": "fox", "RIGHT_ATTRS": {"ORTH": "fox"}},
|
|
{
|
|
"LEFT_ID": "fox",
|
|
"REL_OP": ">",
|
|
"RIGHT_ID": "q",
|
|
"RIGHT_ATTRS": {"ORTH": "quick", "DEP": "amod"},
|
|
},
|
|
{
|
|
"LEFT_ID": "fox",
|
|
"REL_OP": ">",
|
|
"RIGHT_ID": "r",
|
|
"RIGHT_ATTRS": {"ORTH": "brown"},
|
|
},
|
|
]
|
|
|
|
matcher = DependencyMatcher(en_vocab)
|
|
# original pattern is valid
|
|
matcher.add("FOUNDED", [pattern])
|
|
# individual pattern not wrapped in a list
|
|
with pytest.raises(ValueError):
|
|
matcher.add("FOUNDED", pattern)
|
|
# no anchor node
|
|
with pytest.raises(ValueError):
|
|
matcher.add("FOUNDED", [pattern[1:]])
|
|
# required keys missing
|
|
with pytest.raises(ValueError):
|
|
pattern2 = copy.deepcopy(pattern)
|
|
del pattern2[0]["RIGHT_ID"]
|
|
matcher.add("FOUNDED", [pattern2])
|
|
with pytest.raises(ValueError):
|
|
pattern2 = copy.deepcopy(pattern)
|
|
del pattern2[1]["RIGHT_ID"]
|
|
matcher.add("FOUNDED", [pattern2])
|
|
with pytest.raises(ValueError):
|
|
pattern2 = copy.deepcopy(pattern)
|
|
del pattern2[1]["RIGHT_ATTRS"]
|
|
matcher.add("FOUNDED", [pattern2])
|
|
with pytest.raises(ValueError):
|
|
pattern2 = copy.deepcopy(pattern)
|
|
del pattern2[1]["LEFT_ID"]
|
|
matcher.add("FOUNDED", [pattern2])
|
|
with pytest.raises(ValueError):
|
|
pattern2 = copy.deepcopy(pattern)
|
|
del pattern2[1]["REL_OP"]
|
|
matcher.add("FOUNDED", [pattern2])
|
|
# invalid operator
|
|
with pytest.raises(ValueError):
|
|
pattern2 = copy.deepcopy(pattern)
|
|
pattern2[1]["REL_OP"] = "!!!"
|
|
matcher.add("FOUNDED", [pattern2])
|
|
# duplicate node name
|
|
with pytest.raises(ValueError):
|
|
pattern2 = copy.deepcopy(pattern)
|
|
pattern2[1]["RIGHT_ID"] = "fox"
|
|
matcher.add("FOUNDED", [pattern2])
|
|
|
|
|
|
def test_dependency_matcher_callback(en_vocab, doc):
|
|
pattern = [
|
|
{"RIGHT_ID": "quick", "RIGHT_ATTRS": {"ORTH": "quick"}},
|
|
]
|
|
nomatch_pattern = [
|
|
{"RIGHT_ID": "quick", "RIGHT_ATTRS": {"ORTH": "NOMATCH"}},
|
|
]
|
|
|
|
matcher = DependencyMatcher(en_vocab)
|
|
mock = Mock()
|
|
matcher.add("pattern", [pattern], on_match=mock)
|
|
matcher.add("nomatch_pattern", [nomatch_pattern], on_match=mock)
|
|
matches = matcher(doc)
|
|
assert len(matches) == 1
|
|
mock.assert_called_once_with(matcher, doc, 0, matches)
|
|
|
|
# check that matches with and without callback are the same (#4590)
|
|
matcher2 = DependencyMatcher(en_vocab)
|
|
matcher2.add("pattern", [pattern])
|
|
matches2 = matcher2(doc)
|
|
assert matches == matches2
|
|
|
|
|
|
@pytest.mark.parametrize("op,num_matches", [(".", 8), (".*", 20), (";", 8), (";*", 20)])
|
|
def test_dependency_matcher_precedence_ops(en_vocab, op, num_matches):
|
|
# two sentences to test that all matches are within the same sentence
|
|
doc = Doc(
|
|
en_vocab,
|
|
words=["a", "b", "c", "d", "e"] * 2,
|
|
heads=[0, 0, 0, 0, 0, 5, 5, 5, 5, 5],
|
|
deps=["dep"] * 10,
|
|
)
|
|
match_count = 0
|
|
for text in ["a", "b", "c", "d", "e"]:
|
|
pattern = [
|
|
{"RIGHT_ID": "1", "RIGHT_ATTRS": {"ORTH": text}},
|
|
{"LEFT_ID": "1", "REL_OP": op, "RIGHT_ID": "2", "RIGHT_ATTRS": {}},
|
|
]
|
|
matcher = DependencyMatcher(en_vocab)
|
|
matcher.add("A", [pattern])
|
|
matches = matcher(doc)
|
|
match_count += len(matches)
|
|
for match in matches:
|
|
match_id, token_ids = match
|
|
# token_ids[0] op token_ids[1]
|
|
if op == ".":
|
|
assert token_ids[0] == token_ids[1] - 1
|
|
elif op == ";":
|
|
assert token_ids[0] == token_ids[1] + 1
|
|
elif op == ".*":
|
|
assert token_ids[0] < token_ids[1]
|
|
elif op == ";*":
|
|
assert token_ids[0] > token_ids[1]
|
|
# all tokens are within the same sentence
|
|
assert doc[token_ids[0]].sent == doc[token_ids[1]].sent
|
|
assert match_count == num_matches
|
|
|
|
|
|
@pytest.mark.parametrize(
|
|
"left,right,op,num_matches",
|
|
[
|
|
("fox", "jumped", "<", 1),
|
|
("the", "lazy", "<", 0),
|
|
("jumped", "jumped", "<", 0),
|
|
("fox", "jumped", ">", 0),
|
|
("fox", "lazy", ">", 1),
|
|
("lazy", "lazy", ">", 0),
|
|
("fox", "jumped", "<<", 2),
|
|
("jumped", "fox", "<<", 0),
|
|
("the", "fox", "<<", 2),
|
|
("fox", "jumped", ">>", 0),
|
|
("over", "the", ">>", 1),
|
|
("fox", "the", ">>", 2),
|
|
("fox", "jumped", ".", 1),
|
|
("lazy", "fox", ".", 1),
|
|
("the", "fox", ".", 0),
|
|
("the", "the", ".", 0),
|
|
("fox", "jumped", ";", 0),
|
|
("lazy", "fox", ";", 0),
|
|
("the", "fox", ";", 0),
|
|
("the", "the", ";", 0),
|
|
("quick", "fox", ".*", 2),
|
|
("the", "fox", ".*", 3),
|
|
("the", "the", ".*", 1),
|
|
("fox", "jumped", ";*", 1),
|
|
("quick", "fox", ";*", 0),
|
|
("the", "fox", ";*", 1),
|
|
("the", "the", ";*", 1),
|
|
("quick", "brown", "$+", 1),
|
|
("brown", "quick", "$+", 0),
|
|
("brown", "brown", "$+", 0),
|
|
("quick", "brown", "$-", 0),
|
|
("brown", "quick", "$-", 1),
|
|
("brown", "brown", "$-", 0),
|
|
("the", "brown", "$++", 1),
|
|
("brown", "the", "$++", 0),
|
|
("brown", "brown", "$++", 0),
|
|
("the", "brown", "$--", 0),
|
|
("brown", "the", "$--", 1),
|
|
("brown", "brown", "$--", 0),
|
|
],
|
|
)
|
|
def test_dependency_matcher_ops(en_vocab, doc, left, right, op, num_matches):
|
|
right_id = right
|
|
if left == right:
|
|
right_id = right + "2"
|
|
pattern = [
|
|
{"RIGHT_ID": left, "RIGHT_ATTRS": {"LOWER": left}},
|
|
{
|
|
"LEFT_ID": left,
|
|
"REL_OP": op,
|
|
"RIGHT_ID": right_id,
|
|
"RIGHT_ATTRS": {"LOWER": right},
|
|
},
|
|
]
|
|
|
|
matcher = DependencyMatcher(en_vocab)
|
|
matcher.add("pattern", [pattern])
|
|
matches = matcher(doc)
|
|
assert len(matches) == num_matches
|
|
|
|
|
|
def test_dependency_matcher_long_matches(en_vocab, doc):
|
|
pattern = [
|
|
{"RIGHT_ID": "quick", "RIGHT_ATTRS": {"DEP": "amod", "OP": "+"}},
|
|
]
|
|
|
|
matcher = DependencyMatcher(en_vocab)
|
|
with pytest.raises(ValueError):
|
|
matcher.add("pattern", [pattern])
|
|
|
|
|
|
@pytest.mark.usefixtures("clean_underscore")
|
|
def test_dependency_matcher_span_user_data(en_tokenizer):
|
|
doc = en_tokenizer("a b c d e")
|
|
for token in doc:
|
|
token.head = doc[0]
|
|
token.dep_ = "a"
|
|
get_is_c = lambda token: token.text in ("c",)
|
|
Token.set_extension("is_c", default=False)
|
|
doc[2]._.is_c = True
|
|
pattern = [
|
|
{"RIGHT_ID": "c", "RIGHT_ATTRS": {"_": {"is_c": True}}},
|
|
]
|
|
matcher = DependencyMatcher(en_tokenizer.vocab)
|
|
matcher.add("C", [pattern])
|
|
doc_matches = matcher(doc)
|
|
offset = 1
|
|
span_matches = matcher(doc[offset:])
|
|
for doc_match, span_match in zip(sorted(doc_matches), sorted(span_matches)):
|
|
assert doc_match[0] == span_match[0]
|
|
for doc_t_i, span_t_i in zip(doc_match[1], span_match[1]):
|
|
assert doc_t_i == span_t_i + offset
|