mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 02:06:31 +03:00
Fix on_match callback and remove empty patterns (#6312)
For the `DependencyMatcher`: (1) fix the `on_match` callback so that it is called once per matched pattern; (2) fix the results so that patterns with empty match lists are not returned.
This commit is contained in:
parent
45c9a68828
commit
31de700b0f
|
@ -235,12 +235,12 @@ cdef class DependencyMatcher:
|
|||
|
||||
matched_trees = []
|
||||
self.recurse(_tree,id_to_position,_node_operator_map,0,[],matched_trees)
|
||||
matched_key_trees.append((key,matched_trees))
|
||||
|
||||
for i, (ent_id, nodes) in enumerate(matched_key_trees):
|
||||
on_match = self._callbacks.get(ent_id)
|
||||
if on_match is not None:
|
||||
on_match(self, doc, i, matched_key_trees)
|
||||
if len(matched_trees) > 0:
|
||||
matched_key_trees.append((key,matched_trees))
|
||||
for i, (ent_id, nodes) in enumerate(matched_key_trees):
|
||||
on_match = self._callbacks.get(ent_id)
|
||||
if on_match is not None:
|
||||
on_match(self, doc, i, matched_key_trees)
|
||||
return matched_key_trees
|
||||
|
||||
def recurse(self,tree,id_to_position,_node_operator_map,int patternLength,visitedNodes,matched_trees):
|
||||
|
|
|
@ -7,6 +7,7 @@ from mock import Mock
|
|||
from spacy.matcher import Matcher, DependencyMatcher
|
||||
from spacy.tokens import Doc, Token
|
||||
from ..doc.test_underscore import clean_underscore # noqa: F401
|
||||
from ..util import get_doc
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -301,22 +302,6 @@ def test_matcher_extension_set_membership(en_vocab):
|
|||
assert len(matches) == 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def text():
|
||||
return "The quick brown fox jumped over the lazy fox"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def heads():
|
||||
return [3, 2, 1, 1, 0, -1, 2, 1, -3]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def deps():
|
||||
return ["det", "amod", "amod", "nsubj", "prep", "pobj", "det", "amod"]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def dependency_matcher(en_vocab):
|
||||
def is_brown_yellow(text):
|
||||
return bool(re.compile(r"brown|yellow|over").match(text))
|
||||
|
@ -359,24 +344,40 @@ def dependency_matcher(en_vocab):
|
|||
},
|
||||
]
|
||||
|
||||
# pattern that doesn't match
|
||||
pattern4 = [
|
||||
{"SPEC": {"NODE_NAME": "jumped"}, "PATTERN": {"ORTH": "NOMATCH"}},
|
||||
{
|
||||
"SPEC": {"NODE_NAME": "fox", "NBOR_RELOP": ">", "NBOR_NAME": "jumped"},
|
||||
"PATTERN": {"ORTH": "fox"},
|
||||
},
|
||||
{
|
||||
"SPEC": {"NODE_NAME": "r", "NBOR_RELOP": ">>", "NBOR_NAME": "fox"},
|
||||
"PATTERN": {"ORTH": "brown"},
|
||||
},
|
||||
]
|
||||
|
||||
matcher = DependencyMatcher(en_vocab)
|
||||
matcher.add("pattern1", [pattern1])
|
||||
matcher.add("pattern2", [pattern2])
|
||||
matcher.add("pattern3", [pattern3])
|
||||
on_match = Mock()
|
||||
matcher.add("pattern1", [pattern1], on_match=on_match)
|
||||
matcher.add("pattern2", [pattern2], on_match=on_match)
|
||||
matcher.add("pattern3", [pattern3], on_match=on_match)
|
||||
matcher.add("pattern4", [pattern4], on_match=on_match)
|
||||
|
||||
return matcher
|
||||
assert len(dependency_matcher) == 4
|
||||
|
||||
text = "The quick brown fox jumped over the lazy fox"
|
||||
heads = [3, 2, 1, 1, 0, -1, 2, 1, -3]
|
||||
deps = ["det", "amod", "amod", "nsubj", "ROOT", "prep", "pobj", "det", "amod"]
|
||||
|
||||
def test_dependency_matcher_compile(dependency_matcher):
|
||||
assert len(dependency_matcher) == 3
|
||||
doc = get_doc(dependency_matcher.vocab, text.split(), heads=heads, deps=deps)
|
||||
matches = dependency_matcher(doc)
|
||||
|
||||
|
||||
# def test_dependency_matcher(dependency_matcher, text, heads, deps):
|
||||
# doc = get_doc(dependency_matcher.vocab, text.split(), heads=heads, deps=deps)
|
||||
# matches = dependency_matcher(doc)
|
||||
# assert matches[0][1] == [[3, 1, 2]]
|
||||
# assert matches[1][1] == [[4, 3, 3]]
|
||||
# assert matches[2][1] == [[4, 3, 2]]
|
||||
assert len(matches) == 3
|
||||
assert matches[0][1] == [[3, 1, 2]]
|
||||
assert matches[1][1] == [[4, 3, 3]]
|
||||
assert matches[2][1] == [[4, 3, 2]]
|
||||
assert on_match.call_count == 3
|
||||
|
||||
|
||||
def test_matcher_basic_check(en_vocab):
|
||||
|
|
Loading…
Reference in New Issue
Block a user