diff --git a/spacy/tests/pattern/parser.py b/spacy/tests/pattern/parser.py index a56bda20a..50dd3ac60 100644 --- a/spacy/tests/pattern/parser.py +++ b/spacy/tests/pattern/parser.py @@ -44,25 +44,33 @@ class TestPatternParser: assert pattern.number_of_nodes() == 2 assert pattern.number_of_edges() == 1 - base_node_id = list(pattern.adjacency.keys())[0] - adj_map = pattern.adjacency[base_node_id] + quick_id = [node_id for node_id, node_attr in pattern.nodes.items() + if node_attr['word'] == 'quick'][0] - assert len(adj_map) == 1 - head_node_id = list(adj_map.keys())[0] - dep = adj_map[head_node_id] + fox_id = [node_id for node_id, node_attr in pattern.nodes.items() + if node_attr['word'] == 'fox'][0] + + quick_map = pattern.adjacency[quick_id] + fox_map = pattern.adjacency[fox_id] + + assert len(quick_map) == 0 + assert len(fox_map) == 1 + + dep = fox_map[quick_id] assert dep == 'amod' - assert pattern[base_node_id]['word'] == 'fox' - assert pattern[head_node_id]['word'] == 'quick' def test_define_edge_with_regex(self): query = "[word:quick] >/amod|nsubj/ [word:fox]" pattern = PatternParser.parse(query) - base_node_id = list(pattern.adjacency.keys())[0] - adj_map = pattern.adjacency[base_node_id] + quick_id = [node_id for node_id, node_attr in pattern.nodes.items() + if node_attr['word'] == 'quick'][0] + + fox_id = [node_id for node_id, node_attr in pattern.nodes.items() + if node_attr['word'] == 'fox'][0] + + fox_map = pattern.adjacency[fox_id] + dep = fox_map[quick_id] - assert len(adj_map) == 1 - head_node_id = list(adj_map.keys())[0] - dep = adj_map[head_node_id] assert dep == re.compile(r'amod|nsubj', re.U)