mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	Add basic unit tests for Pattern
This commit is contained in:
		
							parent
							
								
									1849a110e3
								
							
						
					
					
						commit
						46637369aa
					
				
							
								
								
									
										1
									
								
								spacy/tests/pattern/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								spacy/tests/pattern/__init__.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1 @@
 | 
				
			||||||
 | 
					# coding: utf-8
 | 
				
			||||||
							
								
								
									
										68
									
								
								spacy/tests/pattern/parser.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										68
									
								
								spacy/tests/pattern/parser.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,68 @@
 | 
				
			||||||
 | 
					# coding: utf-8
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import re
 | 
				
			||||||
 | 
					from ...pattern.parser import PatternParser
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TestPatternParser:
					    """Unit tests for ``PatternParser.parse`` on small hand-written queries."""

					    @staticmethod
					    def _sole_edge(pattern):
					        """Return ``(source_id, target_id, dep)`` for the first adjacency entry.

					        Takes the first key of ``pattern.adjacency``, asserts that its
					        adjacency map holds exactly one entry, and returns that entry's key
					        and value alongside the first key.
					        """
					        # NOTE(review): relies on the first adjacency key being the node
					        # that carries the edge -- confirm how ``adjacency`` is populated.
					        source_id = list(pattern.adjacency.keys())[0]
					        targets = pattern.adjacency[source_id]

					        assert len(targets) == 1
					        target_id = list(targets.keys())[0]
					        return source_id, target_id, targets[target_id]

					    def test_empty_query(self):
					        """Empty and whitespace-only queries parse to ``None``."""
					        assert PatternParser.parse('') is None
					        assert PatternParser.parse('      ') is None

					    def test_define_node(self):
					        """One node definition gives one node, its attributes and no edges."""
					        query = "fox [lemma:fox,word:fox]=alias"
					        pattern = PatternParser.parse(query)

					        assert pattern is not None
					        assert pattern.number_of_nodes() == 1
					        assert pattern.number_of_edges() == 0

					        assert 'fox' in pattern.nodes

					        attrs = pattern['fox']
					        assert attrs.get('lemma') == 'fox'
					        assert attrs.get('word') == 'fox'
					        assert attrs.get('_name') == 'fox'
					        assert attrs.get('_alias') == 'alias'

					        # No edge was declared, so every adjacency entry must be empty.
					        for adj_list in pattern.adjacency.values():
					            assert not adj_list

					    def test_define_node_with_regex(self):
					        """A ``/.../`` attribute value is stored as a compiled regex."""
					        query = "fox [lemma:/fo.*/]"
					        pattern = PatternParser.parse(query)

					        # Fail with a clear message instead of a TypeError on subscripting.
					        assert pattern is not None

					        attrs = pattern['fox']
					        assert attrs.get('lemma') == re.compile(r'fo.*', re.U)

					    def test_define_edge(self):
					        """Two anonymous nodes joined by ``>amod`` give a single labelled edge."""
					        query = "[word:quick] >amod [word:fox]"
					        pattern = PatternParser.parse(query)

					        assert pattern is not None
					        assert pattern.number_of_nodes() == 2
					        assert pattern.number_of_edges() == 1

					        source_id, target_id, dep = self._sole_edge(pattern)

					        assert dep == 'amod'
					        assert pattern[source_id]['word'] == 'fox'
					        assert pattern[target_id]['word'] == 'quick'

					    def test_define_edge_with_regex(self):
					        """A ``>/.../`` edge label is stored as a compiled regex."""
					        query = "[word:quick] >/amod|nsubj/ [word:fox]"
					        pattern = PatternParser.parse(query)

					        # Fail with a clear message instead of an AttributeError below.
					        assert pattern is not None

					        _, _, dep = self._sole_edge(pattern)
					        assert dep == re.compile(r'amod|nsubj', re.U)
 | 
				
			||||||
							
								
								
									
										61
									
								
								spacy/tests/pattern/pattern.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										61
									
								
								spacy/tests/pattern/pattern.py
									
									
									
									
									
										Normal file
									
								
							| 
						 | 
					@ -0,0 +1,61 @@
 | 
				
			||||||
 | 
					# coding: utf-8
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from ..util import get_doc
 | 
				
			||||||
 | 
					from ...pattern.pattern import Tree, DependencyTree
 | 
				
			||||||
 | 
					from ...pattern.parser import PatternParser
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import pytest
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					import logging
 | 
				
			||||||
 | 
					# Route every log record (DEBUG and above) from the root logger to a
					# stream handler so library debug output is visible while the tests run.
					logger = logging.getLogger()
					logger.setLevel(logging.DEBUG)
					_stream_handler = logging.StreamHandler()
					logger.addHandler(_stream_handler)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@pytest.fixture
					def doc(en_vocab):
					    """Return the sample sentence built by ``get_doc`` from ``en_vocab``.

					    The words, dependency labels and heads below are passed to ``get_doc``
					    unchanged.
					    """
					    words = ['I', "'m", 'going', 'to', 'the', 'zoo', 'next', 'week', '.']
					    deps = ['nsubj', 'aux', 'ROOT', 'prep', 'det', 'pobj',
					            'amod', 'npadvmod', 'punct']
					    # NOTE(review): heads look like offsets relative to each token
					    # (the ROOT entry is 0) -- confirm against ``get_doc``.
					    heads = [2, 1, 0, -1, 1, -2, 1, -5, -6]
					    return get_doc(en_vocab, words=words, deps=deps, heads=heads)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TestTree:
					    """Unit tests for the ``Tree`` container."""

					    def test_is_connected(self):
					        """Two linked nodes are connected; adding an unlinked third is not."""
					        graph = Tree()
					        for node_id in (1, 2):
					            graph.add_node(node_id)
					        graph.add_edge(1, 2)
					        assert graph.is_connected()

					        # Node 3 is added without any edge to the existing nodes.
					        graph.add_node(3)
					        assert not graph.is_connected()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class TestDependencyTree:
					    """Unit tests for ``DependencyTree`` construction and pattern matching."""

					    def test_from_doc(self, doc):
					        """The tree has one node per token, is connected, and has n - 1 edges."""
					        dep_tree = DependencyTree(doc)
					        n_tokens = len(doc)

					        assert len(dep_tree) == n_tokens
					        assert dep_tree.is_connected()
					        assert dep_tree.number_of_edges() == n_tokens - 1

					    def test_simple_matching(self, doc):
					        """The query matches once; its aliases map to tokens 5 and 7 of ``doc``."""
					        dep_tree = DependencyTree(doc)
					        query = """root [word:going]
					                                         to [word:to]
					                                         [word:week]=date > root
					                                         [word:/zoo|park/]=place >pobj to
					                                         to >prep root
					                                         """
					        pattern = PatternParser.parse(query)
					        assert pattern is not None

					        matches = dep_tree.match(pattern)
					        assert len(matches) == 1

					        # Aliased nodes are looked up by alias name in the match result.
					        result = matches[0]
					        assert result['place'] == doc[5]
					        assert result['date'] == doc[7]
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user