From 31de700b0ff8d537b8e3215a7c7c5df0c86f71f9 Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Thu, 5 Nov 2020 09:16:26 +0100 Subject: [PATCH] Fix on_match callback and remove empty patterns (#6312) For the `DependencyMatcher`: * Fix on_match callback so that it is called once per matched pattern * Fix results so that patterns with empty match lists are not returned --- spacy/matcher/dependencymatcher.pyx | 12 ++--- spacy/tests/matcher/test_matcher_api.py | 59 +++++++++++++------------ 2 files changed, 36 insertions(+), 35 deletions(-) diff --git a/spacy/matcher/dependencymatcher.pyx b/spacy/matcher/dependencymatcher.pyx index 56d27024d..e93416043 100644 --- a/spacy/matcher/dependencymatcher.pyx +++ b/spacy/matcher/dependencymatcher.pyx @@ -235,12 +235,12 @@ cdef class DependencyMatcher: matched_trees = [] self.recurse(_tree,id_to_position,_node_operator_map,0,[],matched_trees) - matched_key_trees.append((key,matched_trees)) - - for i, (ent_id, nodes) in enumerate(matched_key_trees): - on_match = self._callbacks.get(ent_id) - if on_match is not None: - on_match(self, doc, i, matched_key_trees) + if len(matched_trees) > 0: + matched_key_trees.append((key,matched_trees)) + for i, (ent_id, nodes) in enumerate(matched_key_trees): + on_match = self._callbacks.get(ent_id) + if on_match is not None: + on_match(self, doc, i, matched_key_trees) return matched_key_trees def recurse(self,tree,id_to_position,_node_operator_map,int patternLength,visitedNodes,matched_trees): diff --git a/spacy/tests/matcher/test_matcher_api.py b/spacy/tests/matcher/test_matcher_api.py index a23eb392d..1873d3d2b 100644 --- a/spacy/tests/matcher/test_matcher_api.py +++ b/spacy/tests/matcher/test_matcher_api.py @@ -7,6 +7,7 @@ from mock import Mock from spacy.matcher import Matcher, DependencyMatcher from spacy.tokens import Doc, Token from ..doc.test_underscore import clean_underscore # noqa: F401 +from ..util import get_doc @pytest.fixture @@ -301,22 +302,6 @@ def test_matcher_extension_set_membership(en_vocab): assert len(matches) == 0 -@pytest.fixture -def text(): - return "The quick brown fox jumped over the lazy fox" - - -@pytest.fixture -def heads(): - return [3, 2, 1, 1, 0, -1, 2, 1, -3] - - -@pytest.fixture -def deps(): - return ["det", "amod", "amod", "nsubj", "prep", "pobj", "det", "amod"] - - -@pytest.fixture def dependency_matcher(en_vocab): def is_brown_yellow(text): return bool(re.compile(r"brown|yellow|over").match(text)) @@ -359,24 +344,40 @@ def dependency_matcher(en_vocab): }, ] + # pattern that doesn't match + pattern4 = [ + {"SPEC": {"NODE_NAME": "jumped"}, "PATTERN": {"ORTH": "NOMATCH"}}, + { + "SPEC": {"NODE_NAME": "fox", "NBOR_RELOP": ">", "NBOR_NAME": "jumped"}, + "PATTERN": {"ORTH": "fox"}, + }, + { + "SPEC": {"NODE_NAME": "r", "NBOR_RELOP": ">>", "NBOR_NAME": "fox"}, + "PATTERN": {"ORTH": "brown"}, + }, + ] + matcher = DependencyMatcher(en_vocab) - matcher.add("pattern1", [pattern1]) - matcher.add("pattern2", [pattern2]) - matcher.add("pattern3", [pattern3]) + on_match = Mock() + matcher.add("pattern1", [pattern1], on_match=on_match) + matcher.add("pattern2", [pattern2], on_match=on_match) + matcher.add("pattern3", [pattern3], on_match=on_match) + matcher.add("pattern4", [pattern4], on_match=on_match) - return matcher + assert len(dependency_matcher) == 4 + text = "The quick brown fox jumped over the lazy fox" + heads = [3, 2, 1, 1, 0, -1, 2, 1, -3] + deps = ["det", "amod", "amod", "nsubj", "ROOT", "prep", "pobj", "det", "amod"] -def test_dependency_matcher_compile(dependency_matcher): - assert len(dependency_matcher) == 3 + doc = get_doc(dependency_matcher.vocab, text.split(), heads=heads, deps=deps) + matches = dependency_matcher(doc) - -# def test_dependency_matcher(dependency_matcher, text, heads, deps): -# doc = get_doc(dependency_matcher.vocab, text.split(), heads=heads, deps=deps) -# matches = dependency_matcher(doc) -# assert matches[0][1] == [[3, 1, 2]] -# assert matches[1][1] == [[4, 3, 3]] -# assert matches[2][1] == [[4, 3, 2]] + assert len(matches) == 3 + assert matches[0][1] == [[3, 1, 2]] + assert matches[1][1] == [[4, 3, 3]] + assert matches[2][1] == [[4, 3, 2]] + assert on_match.call_count == 3 def test_matcher_basic_check(en_vocab):