spaCy/spacy/tests/matcher/test_dependency_matcher.py

371 lines
11 KiB
Python
Raw Normal View History

import pytest
import pickle
import re
import copy
from mock import Mock
from spacy.matcher import DependencyMatcher
from spacy.tokens import Doc, Token
from ..doc.test_underscore import clean_underscore # noqa: F401
@pytest.fixture
def doc(en_vocab):
2020-09-21 21:43:54 +03:00
words = ["The", "quick", "brown", "fox", "jumped", "over", "the", "lazy", "fox"]
heads = [3, 3, 3, 4, 4, 4, 8, 8, 5]
deps = ["det", "amod", "amod", "nsubj", "ROOT", "prep", "pobj", "det", "amod"]
2020-09-21 21:43:54 +03:00
return Doc(en_vocab, words=words, heads=heads, deps=deps)
@pytest.fixture
def patterns(en_vocab):
def is_brown_yellow(text):
return bool(re.compile(r"brown|yellow").match(text))
IS_BROWN_YELLOW = en_vocab.add_flag(is_brown_yellow)
pattern1 = [
{"RIGHT_ID": "fox", "RIGHT_ATTRS": {"ORTH": "fox"}},
{
"LEFT_ID": "fox",
"REL_OP": ">",
"RIGHT_ID": "q",
"RIGHT_ATTRS": {"ORTH": "quick", "DEP": "amod"},
},
{
"LEFT_ID": "fox",
"REL_OP": ">",
"RIGHT_ID": "r",
"RIGHT_ATTRS": {IS_BROWN_YELLOW: True},
},
]
pattern2 = [
{"RIGHT_ID": "jumped", "RIGHT_ATTRS": {"ORTH": "jumped"}},
{
"LEFT_ID": "jumped",
"REL_OP": ">",
"RIGHT_ID": "fox1",
"RIGHT_ATTRS": {"ORTH": "fox"},
},
{
"LEFT_ID": "jumped",
"REL_OP": ".",
"RIGHT_ID": "over",
"RIGHT_ATTRS": {"ORTH": "over"},
},
]
pattern3 = [
{"RIGHT_ID": "jumped", "RIGHT_ATTRS": {"ORTH": "jumped"}},
{
"LEFT_ID": "jumped",
"REL_OP": ">",
"RIGHT_ID": "fox",
"RIGHT_ATTRS": {"ORTH": "fox"},
},
{
"LEFT_ID": "fox",
"REL_OP": ">>",
"RIGHT_ID": "r",
"RIGHT_ATTRS": {"ORTH": "brown"},
},
]
pattern4 = [
{"RIGHT_ID": "jumped", "RIGHT_ATTRS": {"ORTH": "jumped"}},
{
"LEFT_ID": "jumped",
"REL_OP": ">",
"RIGHT_ID": "fox",
"RIGHT_ATTRS": {"ORTH": "fox"},
2020-09-13 11:55:36 +03:00
},
]
pattern5 = [
{"RIGHT_ID": "jumped", "RIGHT_ATTRS": {"ORTH": "jumped"}},
{
"LEFT_ID": "jumped",
"REL_OP": ">>",
"RIGHT_ID": "fox",
"RIGHT_ATTRS": {"ORTH": "fox"},
},
]
return [pattern1, pattern2, pattern3, pattern4, pattern5]
@pytest.fixture
def dependency_matcher(en_vocab, patterns, doc):
matcher = DependencyMatcher(en_vocab)
mock = Mock()
for i in range(1, len(patterns) + 1):
if i == 1:
matcher.add("pattern1", [patterns[0]], on_match=mock)
else:
matcher.add("pattern" + str(i), [patterns[i - 1]])
return matcher
def test_dependency_matcher(dependency_matcher, doc, patterns):
assert len(dependency_matcher) == 5
assert "pattern3" in dependency_matcher
assert dependency_matcher.get("pattern3") == (None, [patterns[2]])
matches = dependency_matcher(doc)
assert len(matches) == 6
assert matches[0][1] == [3, 1, 2]
assert matches[1][1] == [4, 3, 5]
assert matches[2][1] == [4, 3, 2]
assert matches[3][1] == [4, 3]
assert matches[4][1] == [4, 3]
assert matches[5][1] == [4, 8]
span = doc[0:6]
matches = dependency_matcher(span)
assert len(matches) == 5
assert matches[0][1] == [3, 1, 2]
assert matches[1][1] == [4, 3, 5]
assert matches[2][1] == [4, 3, 2]
assert matches[3][1] == [4, 3]
assert matches[4][1] == [4, 3]
def test_dependency_matcher_pickle(en_vocab, patterns, doc):
matcher = DependencyMatcher(en_vocab)
for i in range(1, len(patterns) + 1):
matcher.add("pattern" + str(i), [patterns[i - 1]])
matches = matcher(doc)
assert matches[0][1] == [3, 1, 2]
assert matches[1][1] == [4, 3, 5]
assert matches[2][1] == [4, 3, 2]
assert matches[3][1] == [4, 3]
assert matches[4][1] == [4, 3]
assert matches[5][1] == [4, 8]
b = pickle.dumps(matcher)
matcher_r = pickle.loads(b)
assert len(matcher) == len(matcher_r)
matches = matcher_r(doc)
assert matches[0][1] == [3, 1, 2]
assert matches[1][1] == [4, 3, 5]
assert matches[2][1] == [4, 3, 2]
assert matches[3][1] == [4, 3]
assert matches[4][1] == [4, 3]
assert matches[5][1] == [4, 8]
def test_dependency_matcher_pattern_validation(en_vocab):
pattern = [
{"RIGHT_ID": "fox", "RIGHT_ATTRS": {"ORTH": "fox"}},
{
"LEFT_ID": "fox",
"REL_OP": ">",
"RIGHT_ID": "q",
"RIGHT_ATTRS": {"ORTH": "quick", "DEP": "amod"},
},
{
"LEFT_ID": "fox",
"REL_OP": ">",
"RIGHT_ID": "r",
"RIGHT_ATTRS": {"ORTH": "brown"},
},
]
matcher = DependencyMatcher(en_vocab)
# original pattern is valid
matcher.add("FOUNDED", [pattern])
# individual pattern not wrapped in a list
with pytest.raises(ValueError):
matcher.add("FOUNDED", pattern)
# no anchor node
with pytest.raises(ValueError):
matcher.add("FOUNDED", [pattern[1:]])
# required keys missing
with pytest.raises(ValueError):
pattern2 = copy.deepcopy(pattern)
del pattern2[0]["RIGHT_ID"]
matcher.add("FOUNDED", [pattern2])
with pytest.raises(ValueError):
pattern2 = copy.deepcopy(pattern)
del pattern2[1]["RIGHT_ID"]
matcher.add("FOUNDED", [pattern2])
with pytest.raises(ValueError):
pattern2 = copy.deepcopy(pattern)
del pattern2[1]["RIGHT_ATTRS"]
matcher.add("FOUNDED", [pattern2])
with pytest.raises(ValueError):
pattern2 = copy.deepcopy(pattern)
del pattern2[1]["LEFT_ID"]
matcher.add("FOUNDED", [pattern2])
with pytest.raises(ValueError):
pattern2 = copy.deepcopy(pattern)
del pattern2[1]["REL_OP"]
matcher.add("FOUNDED", [pattern2])
# invalid operator
with pytest.raises(ValueError):
pattern2 = copy.deepcopy(pattern)
pattern2[1]["REL_OP"] = "!!!"
matcher.add("FOUNDED", [pattern2])
# duplicate node name
with pytest.raises(ValueError):
pattern2 = copy.deepcopy(pattern)
pattern2[1]["RIGHT_ID"] = "fox"
matcher.add("FOUNDED", [pattern2])
def test_dependency_matcher_callback(en_vocab, doc):
pattern = [
{"RIGHT_ID": "quick", "RIGHT_ATTRS": {"ORTH": "quick"}},
]
nomatch_pattern = [
{"RIGHT_ID": "quick", "RIGHT_ATTRS": {"ORTH": "NOMATCH"}},
]
matcher = DependencyMatcher(en_vocab)
mock = Mock()
matcher.add("pattern", [pattern], on_match=mock)
matcher.add("nomatch_pattern", [nomatch_pattern], on_match=mock)
matches = matcher(doc)
assert len(matches) == 1
mock.assert_called_once_with(matcher, doc, 0, matches)
# check that matches with and without callback are the same (#4590)
matcher2 = DependencyMatcher(en_vocab)
matcher2.add("pattern", [pattern])
matches2 = matcher2(doc)
assert matches == matches2
2020-09-13 11:55:36 +03:00
@pytest.mark.parametrize("op,num_matches", [(".", 8), (".*", 20), (";", 8), (";*", 20)])
def test_dependency_matcher_precedence_ops(en_vocab, op, num_matches):
# two sentences to test that all matches are within the same sentence
2020-09-21 21:43:54 +03:00
doc = Doc(
en_vocab,
words=["a", "b", "c", "d", "e"] * 2,
2020-09-21 21:43:54 +03:00
heads=[0, 0, 0, 0, 0, 5, 5, 5, 5, 5],
deps=["dep"] * 10,
)
match_count = 0
for text in ["a", "b", "c", "d", "e"]:
pattern = [
{"RIGHT_ID": "1", "RIGHT_ATTRS": {"ORTH": text}},
2020-09-13 11:55:36 +03:00
{"LEFT_ID": "1", "REL_OP": op, "RIGHT_ID": "2", "RIGHT_ATTRS": {}},
]
matcher = DependencyMatcher(en_vocab)
matcher.add("A", [pattern])
matches = matcher(doc)
match_count += len(matches)
for match in matches:
match_id, token_ids = match
# token_ids[0] op token_ids[1]
if op == ".":
assert token_ids[0] == token_ids[1] - 1
elif op == ";":
assert token_ids[0] == token_ids[1] + 1
elif op == ".*":
assert token_ids[0] < token_ids[1]
elif op == ";*":
assert token_ids[0] > token_ids[1]
# all tokens are within the same sentence
assert doc[token_ids[0]].sent == doc[token_ids[1]].sent
assert match_count == num_matches
@pytest.mark.parametrize(
"left,right,op,num_matches",
[
("fox", "jumped", "<", 1),
("the", "lazy", "<", 0),
("jumped", "jumped", "<", 0),
("fox", "jumped", ">", 0),
("fox", "lazy", ">", 1),
("lazy", "lazy", ">", 0),
("fox", "jumped", "<<", 2),
("jumped", "fox", "<<", 0),
("the", "fox", "<<", 2),
("fox", "jumped", ">>", 0),
("over", "the", ">>", 1),
("fox", "the", ">>", 2),
("fox", "jumped", ".", 1),
("lazy", "fox", ".", 1),
("the", "fox", ".", 0),
("the", "the", ".", 0),
("fox", "jumped", ";", 0),
("lazy", "fox", ";", 0),
("the", "fox", ";", 0),
("the", "the", ";", 0),
("quick", "fox", ".*", 2),
("the", "fox", ".*", 3),
("the", "the", ".*", 1),
("fox", "jumped", ";*", 1),
("quick", "fox", ";*", 0),
("the", "fox", ";*", 1),
("the", "the", ";*", 1),
("quick", "brown", "$+", 1),
("brown", "quick", "$+", 0),
("brown", "brown", "$+", 0),
("quick", "brown", "$-", 0),
("brown", "quick", "$-", 1),
("brown", "brown", "$-", 0),
("the", "brown", "$++", 1),
("brown", "the", "$++", 0),
("brown", "brown", "$++", 0),
("the", "brown", "$--", 0),
("brown", "the", "$--", 1),
("brown", "brown", "$--", 0),
],
)
def test_dependency_matcher_ops(en_vocab, doc, left, right, op, num_matches):
right_id = right
if left == right:
right_id = right + "2"
pattern = [
{"RIGHT_ID": left, "RIGHT_ATTRS": {"LOWER": left}},
{
"LEFT_ID": left,
"REL_OP": op,
"RIGHT_ID": right_id,
"RIGHT_ATTRS": {"LOWER": right},
},
]
matcher = DependencyMatcher(en_vocab)
matcher.add("pattern", [pattern])
matches = matcher(doc)
assert len(matches) == num_matches
def test_dependency_matcher_long_matches(en_vocab, doc):
pattern = [
{"RIGHT_ID": "quick", "RIGHT_ATTRS": {"DEP": "amod", "OP": "+"}},
]
matcher = DependencyMatcher(en_vocab)
with pytest.raises(ValueError):
matcher.add("pattern", [pattern])
@pytest.mark.usefixtures("clean_underscore")
def test_dependency_matcher_span_user_data(en_tokenizer):
doc = en_tokenizer("a b c d e")
for token in doc:
token.head = doc[0]
token.dep_ = "a"
Token.set_extension("is_c", default=False)
doc[2]._.is_c = True
pattern = [
{"RIGHT_ID": "c", "RIGHT_ATTRS": {"_": {"is_c": True}}},
]
matcher = DependencyMatcher(en_tokenizer.vocab)
matcher.add("C", [pattern])
doc_matches = matcher(doc)
offset = 1
span_matches = matcher(doc[offset:])
for doc_match, span_match in zip(sorted(doc_matches), sorted(span_matches)):
assert doc_match[0] == span_match[0]
for doc_t_i, span_t_i in zip(doc_match[1], span_match[1]):
assert doc_t_i == span_t_i + offset