Fix on_match callback and remove empty patterns (#6312)

For the `DependencyMatcher`:

* Fix on_match callback so that it is called once per matched pattern
* Fix results so that patterns with empty match lists are not returned
This commit is contained in:
Adriane Boyd 2020-11-05 09:16:26 +01:00 committed by GitHub
parent 45c9a68828
commit 31de700b0f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 36 additions and 35 deletions

View File

@ -235,8 +235,8 @@ cdef class DependencyMatcher:
matched_trees = [] matched_trees = []
self.recurse(_tree,id_to_position,_node_operator_map,0,[],matched_trees) self.recurse(_tree,id_to_position,_node_operator_map,0,[],matched_trees)
if len(matched_trees) > 0:
matched_key_trees.append((key,matched_trees)) matched_key_trees.append((key,matched_trees))
for i, (ent_id, nodes) in enumerate(matched_key_trees): for i, (ent_id, nodes) in enumerate(matched_key_trees):
on_match = self._callbacks.get(ent_id) on_match = self._callbacks.get(ent_id)
if on_match is not None: if on_match is not None:

View File

@ -7,6 +7,7 @@ from mock import Mock
from spacy.matcher import Matcher, DependencyMatcher from spacy.matcher import Matcher, DependencyMatcher
from spacy.tokens import Doc, Token from spacy.tokens import Doc, Token
from ..doc.test_underscore import clean_underscore # noqa: F401 from ..doc.test_underscore import clean_underscore # noqa: F401
from ..util import get_doc
@pytest.fixture @pytest.fixture
@ -301,22 +302,6 @@ def test_matcher_extension_set_membership(en_vocab):
assert len(matches) == 0 assert len(matches) == 0
@pytest.fixture
def text():
return "The quick brown fox jumped over the lazy fox"
@pytest.fixture
def heads():
return [3, 2, 1, 1, 0, -1, 2, 1, -3]
@pytest.fixture
def deps():
return ["det", "amod", "amod", "nsubj", "prep", "pobj", "det", "amod"]
@pytest.fixture
def dependency_matcher(en_vocab): def dependency_matcher(en_vocab):
def is_brown_yellow(text): def is_brown_yellow(text):
return bool(re.compile(r"brown|yellow|over").match(text)) return bool(re.compile(r"brown|yellow|over").match(text))
@ -359,24 +344,40 @@ def dependency_matcher(en_vocab):
}, },
] ]
# pattern that doesn't match
pattern4 = [
{"SPEC": {"NODE_NAME": "jumped"}, "PATTERN": {"ORTH": "NOMATCH"}},
{
"SPEC": {"NODE_NAME": "fox", "NBOR_RELOP": ">", "NBOR_NAME": "jumped"},
"PATTERN": {"ORTH": "fox"},
},
{
"SPEC": {"NODE_NAME": "r", "NBOR_RELOP": ">>", "NBOR_NAME": "fox"},
"PATTERN": {"ORTH": "brown"},
},
]
matcher = DependencyMatcher(en_vocab) matcher = DependencyMatcher(en_vocab)
matcher.add("pattern1", [pattern1]) on_match = Mock()
matcher.add("pattern2", [pattern2]) matcher.add("pattern1", [pattern1], on_match=on_match)
matcher.add("pattern3", [pattern3]) matcher.add("pattern2", [pattern2], on_match=on_match)
matcher.add("pattern3", [pattern3], on_match=on_match)
matcher.add("pattern4", [pattern4], on_match=on_match)
return matcher assert len(dependency_matcher) == 4
text = "The quick brown fox jumped over the lazy fox"
heads = [3, 2, 1, 1, 0, -1, 2, 1, -3]
deps = ["det", "amod", "amod", "nsubj", "ROOT", "prep", "pobj", "det", "amod"]
def test_dependency_matcher_compile(dependency_matcher): doc = get_doc(dependency_matcher.vocab, text.split(), heads=heads, deps=deps)
assert len(dependency_matcher) == 3 matches = dependency_matcher(doc)
assert len(matches) == 3
# def test_dependency_matcher(dependency_matcher, text, heads, deps): assert matches[0][1] == [[3, 1, 2]]
# doc = get_doc(dependency_matcher.vocab, text.split(), heads=heads, deps=deps) assert matches[1][1] == [[4, 3, 3]]
# matches = dependency_matcher(doc) assert matches[2][1] == [[4, 3, 2]]
# assert matches[0][1] == [[3, 1, 2]] assert on_match.call_count == 3
# assert matches[1][1] == [[4, 3, 3]]
# assert matches[2][1] == [[4, 3, 2]]
def test_matcher_basic_check(en_vocab): def test_matcher_basic_check(en_vocab):