mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Fix on_match callback for DependencyMatcher (#6313)
Fix `DependencyMatcher` so that the callback is called only once per match.
This commit is contained in:
parent
2918923541
commit
5d2cb86c34
|
@ -286,10 +286,10 @@ cdef class DependencyMatcher:
|
|||
self.recurse(_tree, id_to_position, _node_operator_map, 0, [], matched_trees)
|
||||
for matched_tree in matched_trees:
|
||||
matched_key_trees.append((key, matched_tree))
|
||||
for i, (match_id, nodes) in enumerate(matched_key_trees):
|
||||
on_match = self._callbacks.get(match_id)
|
||||
if on_match is not None:
|
||||
on_match(self, doc, i, matched_key_trees)
|
||||
for i, (match_id, nodes) in enumerate(matched_key_trees):
|
||||
on_match = self._callbacks.get(match_id)
|
||||
if on_match is not None:
|
||||
on_match(self, doc, i, matched_key_trees)
|
||||
return matched_key_trees
|
||||
|
||||
def recurse(self, tree, id_to_position, _node_operator_map, int patternLength, visited_nodes, matched_trees):
|
||||
|
|
|
@ -218,11 +218,16 @@ def test_dependency_matcher_callback(en_vocab, doc):
|
|||
pattern = [
|
||||
{"RIGHT_ID": "quick", "RIGHT_ATTRS": {"ORTH": "quick"}},
|
||||
]
|
||||
nomatch_pattern = [
|
||||
{"RIGHT_ID": "quick", "RIGHT_ATTRS": {"ORTH": "NOMATCH"}},
|
||||
]
|
||||
|
||||
matcher = DependencyMatcher(en_vocab)
|
||||
mock = Mock()
|
||||
matcher.add("pattern", [pattern], on_match=mock)
|
||||
matcher.add("nomatch_pattern", [nomatch_pattern], on_match=mock)
|
||||
matches = matcher(doc)
|
||||
assert len(matches) == 1
|
||||
mock.assert_called_once_with(matcher, doc, 0, matches)
|
||||
|
||||
# check that matches with and without callback are the same (#4590)
|
||||
|
|
Loading…
Reference in New Issue
Block a user