mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-15 03:56:23 +03:00
Add basic unit tests for Pattern
This commit is contained in:
parent
1849a110e3
commit
46637369aa
1
spacy/tests/pattern/__init__.py
Normal file
1
spacy/tests/pattern/__init__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
# coding: utf-8
|
68
spacy/tests/pattern/parser.py
Normal file
68
spacy/tests/pattern/parser.py
Normal file
|
@ -0,0 +1,68 @@
|
||||||
|
# coding: utf-8
|
||||||
|
|
||||||
|
|
||||||
|
import re
|
||||||
|
from ...pattern.parser import PatternParser
|
||||||
|
|
||||||
|
|
||||||
|
class TestPatternParser:
|
||||||
|
def test_empty_query(self):
|
||||||
|
assert PatternParser.parse('') is None
|
||||||
|
assert PatternParser.parse(' ') is None
|
||||||
|
|
||||||
|
def test_define_node(self):
|
||||||
|
query = "fox [lemma:fox,word:fox]=alias"
|
||||||
|
pattern = PatternParser.parse(query)
|
||||||
|
|
||||||
|
assert pattern is not None
|
||||||
|
assert pattern.number_of_nodes() == 1
|
||||||
|
assert pattern.number_of_edges() == 0
|
||||||
|
|
||||||
|
assert 'fox' in pattern.nodes
|
||||||
|
|
||||||
|
attrs = pattern['fox']
|
||||||
|
assert attrs.get('lemma') == 'fox'
|
||||||
|
assert attrs.get('word') == 'fox'
|
||||||
|
assert attrs.get('_name') == 'fox'
|
||||||
|
assert attrs.get('_alias') == 'alias'
|
||||||
|
|
||||||
|
for adj_list in pattern.adjacency.values():
|
||||||
|
assert not adj_list
|
||||||
|
|
||||||
|
def test_define_node_with_regex(self):
|
||||||
|
query = "fox [lemma:/fo.*/]"
|
||||||
|
pattern = PatternParser.parse(query)
|
||||||
|
|
||||||
|
attrs = pattern['fox']
|
||||||
|
assert attrs.get('lemma') == re.compile(r'fo.*', re.U)
|
||||||
|
|
||||||
|
def test_define_edge(self):
|
||||||
|
query = "[word:quick] >amod [word:fox]"
|
||||||
|
pattern = PatternParser.parse(query)
|
||||||
|
|
||||||
|
assert pattern is not None
|
||||||
|
assert pattern.number_of_nodes() == 2
|
||||||
|
assert pattern.number_of_edges() == 1
|
||||||
|
|
||||||
|
base_node_id = list(pattern.adjacency.keys())[0]
|
||||||
|
adj_map = pattern.adjacency[base_node_id]
|
||||||
|
|
||||||
|
assert len(adj_map) == 1
|
||||||
|
head_node_id = list(adj_map.keys())[0]
|
||||||
|
dep = adj_map[head_node_id]
|
||||||
|
|
||||||
|
assert dep == 'amod'
|
||||||
|
assert pattern[base_node_id]['word'] == 'fox'
|
||||||
|
assert pattern[head_node_id]['word'] == 'quick'
|
||||||
|
|
||||||
|
def test_define_edge_with_regex(self):
|
||||||
|
query = "[word:quick] >/amod|nsubj/ [word:fox]"
|
||||||
|
pattern = PatternParser.parse(query)
|
||||||
|
|
||||||
|
base_node_id = list(pattern.adjacency.keys())[0]
|
||||||
|
adj_map = pattern.adjacency[base_node_id]
|
||||||
|
|
||||||
|
assert len(adj_map) == 1
|
||||||
|
head_node_id = list(adj_map.keys())[0]
|
||||||
|
dep = adj_map[head_node_id]
|
||||||
|
assert dep == re.compile(r'amod|nsubj', re.U)
|
61
spacy/tests/pattern/pattern.py
Normal file
61
spacy/tests/pattern/pattern.py
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
# coding: utf-8
|
||||||
|
|
||||||
|
from ..util import get_doc
|
||||||
|
from ...pattern.pattern import Tree, DependencyTree
|
||||||
|
from ...pattern.parser import PatternParser
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger()
|
||||||
|
logger.addHandler(logging.StreamHandler())
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def doc(en_vocab):
|
||||||
|
words = ['I', "'m", 'going', 'to', 'the', 'zoo', 'next', 'week', '.']
|
||||||
|
doc = get_doc(en_vocab,
|
||||||
|
words=words,
|
||||||
|
deps=['nsubj', 'aux', 'ROOT', 'prep', 'det', 'pobj',
|
||||||
|
'amod', 'npadvmod', 'punct'],
|
||||||
|
heads=[2, 1, 0, -1, 1, -2, 1, -5, -6])
|
||||||
|
return doc
|
||||||
|
|
||||||
|
|
||||||
|
class TestTree:
|
||||||
|
def test_is_connected(self):
|
||||||
|
tree = Tree()
|
||||||
|
tree.add_node(1)
|
||||||
|
tree.add_node(2)
|
||||||
|
tree.add_edge(1, 2)
|
||||||
|
|
||||||
|
assert tree.is_connected()
|
||||||
|
|
||||||
|
tree.add_node(3)
|
||||||
|
assert not tree.is_connected()
|
||||||
|
|
||||||
|
|
||||||
|
class TestDependencyTree:
|
||||||
|
def test_from_doc(self, doc):
|
||||||
|
dep_tree = DependencyTree(doc)
|
||||||
|
|
||||||
|
assert len(dep_tree) == len(doc)
|
||||||
|
assert dep_tree.is_connected()
|
||||||
|
assert dep_tree.number_of_edges() == len(doc) - 1
|
||||||
|
|
||||||
|
def test_simple_matching(self, doc):
|
||||||
|
dep_tree = DependencyTree(doc)
|
||||||
|
pattern = PatternParser.parse("""root [word:going]
|
||||||
|
to [word:to]
|
||||||
|
[word:week]=date > root
|
||||||
|
[word:/zoo|park/]=place >pobj to
|
||||||
|
to >prep root
|
||||||
|
""")
|
||||||
|
assert pattern is not None
|
||||||
|
matches = dep_tree.match(pattern)
|
||||||
|
assert len(matches) == 1
|
||||||
|
|
||||||
|
match = matches[0]
|
||||||
|
assert match['place'] == doc[5]
|
||||||
|
assert match['date'] == doc[7]
|
Loading…
Reference in New Issue
Block a user