mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-25 05:01:02 +03:00 
			
		
		
		
	Fix on_match callback and remove empty patterns (#6312)
For the `DependencyMatcher`: * Fix on_match callback so that it is called once per matched pattern * Fix results so that patterns with empty match lists are not returned
This commit is contained in:
		
							parent
							
								
									45c9a68828
								
							
						
					
					
						commit
						31de700b0f
					
				|  | @ -235,12 +235,12 @@ cdef class DependencyMatcher: | |||
| 
 | ||||
|                 matched_trees = [] | ||||
|                 self.recurse(_tree,id_to_position,_node_operator_map,0,[],matched_trees) | ||||
|                 matched_key_trees.append((key,matched_trees)) | ||||
| 
 | ||||
|             for i, (ent_id, nodes) in enumerate(matched_key_trees): | ||||
|                 on_match = self._callbacks.get(ent_id) | ||||
|                 if on_match is not None: | ||||
|                     on_match(self, doc, i, matched_key_trees) | ||||
|                 if len(matched_trees) > 0: | ||||
|                     matched_key_trees.append((key,matched_trees)) | ||||
|         for i, (ent_id, nodes) in enumerate(matched_key_trees): | ||||
|             on_match = self._callbacks.get(ent_id) | ||||
|             if on_match is not None: | ||||
|                 on_match(self, doc, i, matched_key_trees) | ||||
|         return matched_key_trees | ||||
| 
 | ||||
|     def recurse(self,tree,id_to_position,_node_operator_map,int patternLength,visitedNodes,matched_trees): | ||||
|  |  | |||
|  | @ -7,6 +7,7 @@ from mock import Mock | |||
| from spacy.matcher import Matcher, DependencyMatcher | ||||
| from spacy.tokens import Doc, Token | ||||
| from ..doc.test_underscore import clean_underscore  # noqa: F401 | ||||
| from ..util import get_doc | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
|  | @ -301,22 +302,6 @@ def test_matcher_extension_set_membership(en_vocab): | |||
|     assert len(matches) == 0 | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
| def text(): | ||||
|     return "The quick brown fox jumped over the lazy fox" | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
| def heads(): | ||||
|     return [3, 2, 1, 1, 0, -1, 2, 1, -3] | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
| def deps(): | ||||
|     return ["det", "amod", "amod", "nsubj", "prep", "pobj", "det", "amod"] | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
| def dependency_matcher(en_vocab): | ||||
|     def is_brown_yellow(text): | ||||
|         return bool(re.compile(r"brown|yellow|over").match(text)) | ||||
|  | @ -359,24 +344,40 @@ def dependency_matcher(en_vocab): | |||
|         }, | ||||
|     ] | ||||
| 
 | ||||
|     # pattern that doesn't match | ||||
|     pattern4 = [ | ||||
|         {"SPEC": {"NODE_NAME": "jumped"}, "PATTERN": {"ORTH": "NOMATCH"}}, | ||||
|         { | ||||
|             "SPEC": {"NODE_NAME": "fox", "NBOR_RELOP": ">", "NBOR_NAME": "jumped"}, | ||||
|             "PATTERN": {"ORTH": "fox"}, | ||||
|         }, | ||||
|         { | ||||
|             "SPEC": {"NODE_NAME": "r", "NBOR_RELOP": ">>", "NBOR_NAME": "fox"}, | ||||
|             "PATTERN": {"ORTH": "brown"}, | ||||
|         }, | ||||
|     ] | ||||
| 
 | ||||
|     matcher = DependencyMatcher(en_vocab) | ||||
|     matcher.add("pattern1", [pattern1]) | ||||
|     matcher.add("pattern2", [pattern2]) | ||||
|     matcher.add("pattern3", [pattern3]) | ||||
|     on_match = Mock() | ||||
|     matcher.add("pattern1", [pattern1], on_match=on_match) | ||||
|     matcher.add("pattern2", [pattern2], on_match=on_match) | ||||
|     matcher.add("pattern3", [pattern3], on_match=on_match) | ||||
|     matcher.add("pattern4", [pattern4], on_match=on_match) | ||||
| 
 | ||||
|     return matcher | ||||
|     assert len(dependency_matcher) == 4 | ||||
| 
 | ||||
|     text = "The quick brown fox jumped over the lazy fox" | ||||
|     heads = [3, 2, 1, 1, 0, -1, 2, 1, -3] | ||||
|     deps = ["det", "amod", "amod", "nsubj", "ROOT", "prep", "pobj", "det", "amod"] | ||||
| 
 | ||||
| def test_dependency_matcher_compile(dependency_matcher): | ||||
|     assert len(dependency_matcher) == 3 | ||||
|     doc = get_doc(dependency_matcher.vocab, text.split(), heads=heads, deps=deps) | ||||
|     matches = dependency_matcher(doc) | ||||
| 
 | ||||
| 
 | ||||
| # def test_dependency_matcher(dependency_matcher, text, heads, deps): | ||||
| #     doc = get_doc(dependency_matcher.vocab, text.split(), heads=heads, deps=deps) | ||||
| #     matches = dependency_matcher(doc) | ||||
| #     assert matches[0][1] == [[3, 1, 2]] | ||||
| #     assert matches[1][1] == [[4, 3, 3]] | ||||
| #     assert matches[2][1] == [[4, 3, 2]] | ||||
|     assert len(matches) == 3 | ||||
|     assert matches[0][1] == [[3, 1, 2]] | ||||
|     assert matches[1][1] == [[4, 3, 3]] | ||||
|     assert matches[2][1] == [[4, 3, 2]] | ||||
|     assert on_match.call_count == 3 | ||||
| 
 | ||||
| 
 | ||||
| def test_matcher_basic_check(en_vocab): | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user