mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 18:56:36 +03:00
Merge pull request #1120 from raphael0202/pattern
Implementation of dependency pattern-matching algorithm
This commit is contained in:
commit
8775efbfdf
|
@ -16,6 +16,10 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
import copyreg as copy_reg
|
import copyreg as copy_reg
|
||||||
|
|
||||||
|
try:
|
||||||
|
import Queue as queue
|
||||||
|
except ImportError:
|
||||||
|
import queue
|
||||||
|
|
||||||
is_python2 = six.PY2
|
is_python2 = six.PY2
|
||||||
is_python3 = six.PY3
|
is_python3 = six.PY3
|
||||||
|
@ -32,6 +36,7 @@ if is_python2:
|
||||||
basestring_ = basestring
|
basestring_ = basestring
|
||||||
input_ = raw_input
|
input_ = raw_input
|
||||||
json_dumps = lambda data: ujson.dumps(data, indent=2).decode('utf8')
|
json_dumps = lambda data: ujson.dumps(data, indent=2).decode('utf8')
|
||||||
|
intern = intern
|
||||||
|
|
||||||
elif is_python3:
|
elif is_python3:
|
||||||
bytes_ = bytes
|
bytes_ = bytes
|
||||||
|
@ -39,6 +44,7 @@ elif is_python3:
|
||||||
basestring_ = str
|
basestring_ = str
|
||||||
input_ = input
|
input_ = input
|
||||||
json_dumps = lambda data: ujson.dumps(data, indent=2)
|
json_dumps = lambda data: ujson.dumps(data, indent=2)
|
||||||
|
intern = sys.intern
|
||||||
|
|
||||||
|
|
||||||
def symlink_to(orig, dest):
|
def symlink_to(orig, dest):
|
||||||
|
|
4
spacy/pattern/__init__.py
Normal file
4
spacy/pattern/__init__.py
Normal file
|
@ -0,0 +1,4 @@
|
||||||
|
# coding: utf-8
|
||||||
|
|
||||||
|
from .pattern import DependencyTree
|
||||||
|
from .parser import PatternParser
|
377
spacy/pattern/parser.py
Normal file
377
spacy/pattern/parser.py
Normal file
|
@ -0,0 +1,377 @@
|
||||||
|
# coding: utf-8
|
||||||
|
|
||||||
|
from spacy.compat import intern, queue
|
||||||
|
from spacy.strings import hash_string
|
||||||
|
from operator import itemgetter
|
||||||
|
import re
|
||||||
|
import json
|
||||||
|
|
||||||
|
from .pattern import DependencyPattern
|
||||||
|
|
||||||
|
TOKEN_INITIAL = intern('initial')
|
||||||
|
|
||||||
|
|
||||||
|
class PatternParser(object):
|
||||||
|
"""Compile a Pattern query into a :class:`Pattern`, that can be used to
|
||||||
|
match :class:`DependencyTree`s."""
|
||||||
|
whitespace_re = re.compile(r'\s+', re.U)
|
||||||
|
newline_re = re.compile(r'(\r\n|\r|\n)')
|
||||||
|
name_re = re.compile(r'\w+', re.U)
|
||||||
|
|
||||||
|
TOKEN_BLOCK_BEGIN = '['
|
||||||
|
TOKEN_BLOCK_END = ']'
|
||||||
|
EDGE_BLOCK_BEGIN = '>'
|
||||||
|
WHITESPACE = ' '
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def parse(cls, query):
|
||||||
|
"""Parse the given `query`, and compile it into a :class:`Pattern`."""
|
||||||
|
pattern = DependencyPattern()
|
||||||
|
|
||||||
|
for lineno, token_stream in enumerate(cls.tokenize(query)):
|
||||||
|
try:
|
||||||
|
cls._parse_line(token_stream, pattern, lineno+1)
|
||||||
|
except StopIteration:
|
||||||
|
raise SyntaxError("A token is missing, please check your "
|
||||||
|
"query.")
|
||||||
|
|
||||||
|
if not pattern.nodes:
|
||||||
|
return
|
||||||
|
|
||||||
|
cls.check_pattern(pattern)
|
||||||
|
return pattern
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def check_pattern(pattern):
|
||||||
|
if not pattern.is_connected():
|
||||||
|
raise ValueError("The pattern tree must be a fully connected "
|
||||||
|
"graph.")
|
||||||
|
|
||||||
|
if pattern.root_node is None:
|
||||||
|
raise ValueError("The root node of the tree could not be found.")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _parse_line(cls, stream, pattern, lineno):
|
||||||
|
while not stream.closed:
|
||||||
|
token = stream.current
|
||||||
|
|
||||||
|
if token.type == 'name':
|
||||||
|
next_token = stream.look()
|
||||||
|
|
||||||
|
if next_token.type == 'node':
|
||||||
|
cls.parse_node_def(stream, pattern)
|
||||||
|
|
||||||
|
elif next_token.type == 'edge':
|
||||||
|
cls.parse_edge_def(stream, pattern)
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise SyntaxError("line %d: A 'node' or 'edge' token must "
|
||||||
|
"follow a 'name' token." % lineno)
|
||||||
|
|
||||||
|
elif token.type == 'node':
|
||||||
|
next_token = stream.look()
|
||||||
|
|
||||||
|
if next_token.type == 'edge':
|
||||||
|
cls.parse_edge_def(stream, pattern)
|
||||||
|
else:
|
||||||
|
raise SyntaxError("line %d: an 'edge' token is "
|
||||||
|
"expected." % lineno)
|
||||||
|
|
||||||
|
if not stream.closed:
|
||||||
|
next(stream)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def parse_node_def(cls, stream, pattern):
|
||||||
|
name_token = stream.current
|
||||||
|
next(stream)
|
||||||
|
node_token = stream.current
|
||||||
|
cls.add_node(node_token, pattern, name_token)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def add_node(cls, node_token, pattern, name_token=None):
|
||||||
|
token_name = None
|
||||||
|
if name_token is not None:
|
||||||
|
token_id = name_token.value
|
||||||
|
token_name = name_token.value
|
||||||
|
else:
|
||||||
|
token_id = node_token.hash()
|
||||||
|
|
||||||
|
if token_id in pattern.nodes:
|
||||||
|
raise SyntaxError("Token with ID '{}' already registered.".format(
|
||||||
|
token_id))
|
||||||
|
|
||||||
|
token_attr = cls.parse_node_attributes(node_token.value)
|
||||||
|
token_attr['_name'] = token_name
|
||||||
|
pattern.add_node(token_id, token_attr)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def parse_edge_def(cls, stream, pattern):
|
||||||
|
token = stream.current
|
||||||
|
|
||||||
|
if token.type == 'name':
|
||||||
|
token_id = token.value
|
||||||
|
if token_id not in pattern.nodes:
|
||||||
|
raise SyntaxError("Token '{}' with ID '{}' is not "
|
||||||
|
"defined.".format(token, token_id))
|
||||||
|
|
||||||
|
elif token.type == 'node':
|
||||||
|
token_id = token.hash()
|
||||||
|
cls.add_node(token, pattern)
|
||||||
|
|
||||||
|
next(stream)
|
||||||
|
edge_attr = cls.parse_edge_attributes(stream.current.value)
|
||||||
|
next(stream)
|
||||||
|
|
||||||
|
head_token = stream.current
|
||||||
|
if head_token.type == 'name':
|
||||||
|
head_token_id = head_token.value
|
||||||
|
if head_token_id not in pattern.nodes:
|
||||||
|
raise SyntaxError("Token '{}' with ID '{}' is not "
|
||||||
|
"defined.".format(head_token, head_token_id))
|
||||||
|
elif head_token.type == 'node':
|
||||||
|
head_token_id = head_token.hash()
|
||||||
|
cls.add_node(head_token, pattern)
|
||||||
|
else:
|
||||||
|
raise SyntaxError("A 'node' or 'name' token was expected.")
|
||||||
|
|
||||||
|
# inverse the dependency to have an actual tree
|
||||||
|
pattern.add_edge(head_token_id, token_id, edge_attr)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def parse_node_attributes(cls, string):
|
||||||
|
string = string[1:] # remove the trailing '['
|
||||||
|
end_delimiter_idx = string.find(']')
|
||||||
|
|
||||||
|
attr_str = string[:end_delimiter_idx]
|
||||||
|
attr = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
attr = json.loads(attr_str)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
for pair in attr_str.split(","):
|
||||||
|
key, value = pair.split(':')
|
||||||
|
attr[key] = value
|
||||||
|
|
||||||
|
for key, value in attr.items():
|
||||||
|
attr[key] = cls.compile_expression(value)
|
||||||
|
|
||||||
|
alias = string[end_delimiter_idx+2:]
|
||||||
|
|
||||||
|
if alias:
|
||||||
|
attr['_alias'] = alias
|
||||||
|
|
||||||
|
return attr
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def parse_edge_attributes(cls, string):
|
||||||
|
string = string[1:] # remove the trailing '>'
|
||||||
|
|
||||||
|
if not string:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return cls.compile_expression(string)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def compile_expression(expr):
|
||||||
|
if expr.startswith('/') and expr.endswith('/'):
|
||||||
|
string = expr[1:-1]
|
||||||
|
return re.compile(string, re.U)
|
||||||
|
|
||||||
|
return expr
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tokenize(cls, text):
|
||||||
|
lines = text.splitlines()
|
||||||
|
|
||||||
|
for lineno, line in enumerate(lines):
|
||||||
|
yield TokenStream(cls._tokenize_line(line, lineno+1))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _tokenize_line(cls, line, lineno):
|
||||||
|
reader = Reader(line)
|
||||||
|
|
||||||
|
while reader.remaining():
|
||||||
|
char = reader.next()
|
||||||
|
|
||||||
|
if char == cls.TOKEN_BLOCK_BEGIN:
|
||||||
|
token = 'node'
|
||||||
|
idx = reader.find(cls.TOKEN_BLOCK_END)
|
||||||
|
|
||||||
|
if idx == -1:
|
||||||
|
raise SyntaxError("A token block end ']' was expected.")
|
||||||
|
|
||||||
|
idx += 1
|
||||||
|
if len(reader) > idx and reader[idx] == '=':
|
||||||
|
# The node has a name
|
||||||
|
idx = reader.find(cls.WHITESPACE, start=idx)
|
||||||
|
|
||||||
|
if idx == -1:
|
||||||
|
idx = reader.remaining()
|
||||||
|
|
||||||
|
elif char == cls.EDGE_BLOCK_BEGIN:
|
||||||
|
token = 'edge'
|
||||||
|
idx = reader.find(cls.WHITESPACE)
|
||||||
|
|
||||||
|
elif cls.name_re.match(char):
|
||||||
|
token = 'name'
|
||||||
|
idx = reader.find(cls.WHITESPACE)
|
||||||
|
|
||||||
|
if idx == -1:
|
||||||
|
whole_name_match = cls.name_re.match(str(reader))
|
||||||
|
idx = whole_name_match.end()
|
||||||
|
|
||||||
|
elif cls.newline_re.match(char) or cls.whitespace_re.match(char):
|
||||||
|
# skip the whitespace
|
||||||
|
reader.consume()
|
||||||
|
continue
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise SyntaxError("Unrecognized token BEGIN char: '{"
|
||||||
|
"}'".format(char))
|
||||||
|
|
||||||
|
if idx == -1:
|
||||||
|
raise SyntaxError("Ending character of token '{}' not "
|
||||||
|
"found.".format(token))
|
||||||
|
value = reader.consume(idx)
|
||||||
|
|
||||||
|
yield Token(lineno, token, value)
|
||||||
|
|
||||||
|
|
||||||
|
class Reader(object):
|
||||||
|
"""A class used by the :class:`PatternParser` to tokenize the `text`."""
|
||||||
|
__slots__ = ('text', 'pos')
|
||||||
|
|
||||||
|
def __init__(self, text):
|
||||||
|
self.text = text
|
||||||
|
self.pos = 0
|
||||||
|
|
||||||
|
def find(self, needle, start=0, end=None):
|
||||||
|
pos = self.pos
|
||||||
|
start += pos
|
||||||
|
if end is None:
|
||||||
|
index = self.text.find(needle, start)
|
||||||
|
else:
|
||||||
|
end += pos
|
||||||
|
index = self.text.find(needle, start, end)
|
||||||
|
if index != -1:
|
||||||
|
index -= pos
|
||||||
|
return index
|
||||||
|
|
||||||
|
def consume(self, count=1):
|
||||||
|
new_pos = self.pos + count
|
||||||
|
s = self.text[self.pos:new_pos]
|
||||||
|
self.pos = new_pos
|
||||||
|
return s
|
||||||
|
|
||||||
|
def next(self):
|
||||||
|
return self.text[self.pos:self.pos + 1]
|
||||||
|
|
||||||
|
def remaining(self):
|
||||||
|
return len(self.text) - self.pos
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return self.remaining()
|
||||||
|
|
||||||
|
def __getitem__(self, key):
|
||||||
|
if key < 0:
|
||||||
|
return self.text[key]
|
||||||
|
else:
|
||||||
|
return self.text[self.pos + key]
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.text[self.pos:]
|
||||||
|
|
||||||
|
|
||||||
|
# The following classes were copied from Jinja2, a BSD-licensed project,
|
||||||
|
# and slightly modified: Token, TokenStreamIterator, TokenStream.
|
||||||
|
|
||||||
|
class Token(tuple):
|
||||||
|
"""Token class."""
|
||||||
|
__slots__ = ()
|
||||||
|
lineno, type, value = (property(itemgetter(x)) for x in range(3))
|
||||||
|
|
||||||
|
def __new__(cls, lineno, type, value):
|
||||||
|
return tuple.__new__(cls, (lineno, intern(str(type)), value))
|
||||||
|
|
||||||
|
def hash(self):
|
||||||
|
string = self.value
|
||||||
|
return hash_string(string)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return 'Token(%r, %r, %r)' % (
|
||||||
|
self.lineno,
|
||||||
|
self.type,
|
||||||
|
self.value)
|
||||||
|
|
||||||
|
|
||||||
|
class TokenStreamIterator(object):
|
||||||
|
"""The iterator for tokenstreams. Iterate over the stream until the
|
||||||
|
stream is empty.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, stream):
|
||||||
|
self.stream = stream
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
token = self.stream.current
|
||||||
|
try:
|
||||||
|
next(self.stream)
|
||||||
|
except StopIteration:
|
||||||
|
self.stream.close()
|
||||||
|
raise StopIteration()
|
||||||
|
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
|
class TokenStream(object):
|
||||||
|
"""A token stream is an iterable that yields :class:`Token`s. The
|
||||||
|
current active token is stored as :attr:`current`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, generator):
|
||||||
|
self._iter = iter(generator)
|
||||||
|
self._pushed = queue.deque()
|
||||||
|
self.closed = False
|
||||||
|
self.current = Token(1, TOKEN_INITIAL, '')
|
||||||
|
next(self)
|
||||||
|
|
||||||
|
def __iter__(self):
|
||||||
|
return TokenStreamIterator(self)
|
||||||
|
|
||||||
|
def __bool__(self):
|
||||||
|
return bool(self._pushed)
|
||||||
|
__nonzero__ = __bool__ # py2
|
||||||
|
|
||||||
|
def push(self, token):
|
||||||
|
"""Push a token back to the stream."""
|
||||||
|
self._pushed.append(token)
|
||||||
|
|
||||||
|
def look(self):
|
||||||
|
"""Look at the next token."""
|
||||||
|
old_token = next(self)
|
||||||
|
result = self.current
|
||||||
|
self.push(result)
|
||||||
|
self.current = old_token
|
||||||
|
return result
|
||||||
|
|
||||||
|
def __next__(self):
|
||||||
|
"""Go one token ahead and return the old one."""
|
||||||
|
rv = self.current
|
||||||
|
if self._pushed:
|
||||||
|
self.current = self._pushed.popleft()
|
||||||
|
else:
|
||||||
|
if self.closed:
|
||||||
|
raise StopIteration("No token left.")
|
||||||
|
try:
|
||||||
|
self.current = next(self._iter)
|
||||||
|
except StopIteration:
|
||||||
|
self.close()
|
||||||
|
return rv
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""Close the stream."""
|
||||||
|
self._iter = None
|
||||||
|
self.closed = True
|
318
spacy/pattern/pattern.py
Normal file
318
spacy/pattern/pattern.py
Normal file
|
@ -0,0 +1,318 @@
|
||||||
|
# coding: utf-8
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class Tree(object):
|
||||||
|
def __init__(self):
|
||||||
|
self.adjacency = defaultdict(dict)
|
||||||
|
self.nodes = {}
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
return self.nodes[item]
|
||||||
|
|
||||||
|
def add_node(self, node, attr_dict=None):
|
||||||
|
attr_dict = attr_dict or {}
|
||||||
|
self.nodes[node] = attr_dict
|
||||||
|
|
||||||
|
def add_edge(self, u, v, dep=None):
|
||||||
|
if u not in self.nodes or v not in self.nodes:
|
||||||
|
raise ValueError("Each node must be defined before adding an edge.")
|
||||||
|
|
||||||
|
self.adjacency[u][v] = dep
|
||||||
|
|
||||||
|
def number_of_nodes(self):
|
||||||
|
return len(self)
|
||||||
|
|
||||||
|
def __len__(self):
|
||||||
|
return len(self.nodes)
|
||||||
|
|
||||||
|
def number_of_edges(self):
|
||||||
|
return sum(len(adj_dict) for adj_dict in self.adjacency.values())
|
||||||
|
|
||||||
|
def edges_iter(self, origin=None, data=True):
|
||||||
|
nbunch = (self.adjacency.items() if origin is None
|
||||||
|
else [(origin, self.adjacency[origin])])
|
||||||
|
|
||||||
|
for u, nodes in nbunch:
|
||||||
|
for v, dep in nodes.items():
|
||||||
|
if data:
|
||||||
|
yield (u, v, dep)
|
||||||
|
else:
|
||||||
|
yield (u, v)
|
||||||
|
|
||||||
|
def nodes_iter(self):
|
||||||
|
for node in self.nodes.keys():
|
||||||
|
yield node
|
||||||
|
|
||||||
|
def is_connected(self):
|
||||||
|
if len(self) == 0:
|
||||||
|
raise ValueError('Connectivity is undefined for the null graph.')
|
||||||
|
return len(set(self._plain_bfs(next(self.nodes_iter()),
|
||||||
|
undirected=True))) == len(self)
|
||||||
|
|
||||||
|
def _plain_bfs(self, source, undirected=False):
|
||||||
|
"""A fast BFS node generator.
|
||||||
|
:param: source: the source node
|
||||||
|
"""
|
||||||
|
seen = set()
|
||||||
|
next_level = {source}
|
||||||
|
while next_level:
|
||||||
|
this_level = next_level
|
||||||
|
next_level = set()
|
||||||
|
for v in this_level:
|
||||||
|
if v not in seen:
|
||||||
|
yield v
|
||||||
|
seen.add(v)
|
||||||
|
next_level.update(self.adjacency[v].keys())
|
||||||
|
|
||||||
|
if undirected:
|
||||||
|
for n, adj in self.adjacency.items():
|
||||||
|
if v in adj.keys():
|
||||||
|
next_level.add(n)
|
||||||
|
|
||||||
|
|
||||||
|
class DependencyPattern(Tree):
|
||||||
|
@property
|
||||||
|
def root_node(self):
|
||||||
|
if self.number_of_nodes() == 1:
|
||||||
|
# if the graph has a single node, it is the root
|
||||||
|
return next(iter(self.nodes.keys()))
|
||||||
|
|
||||||
|
if not self.is_connected():
|
||||||
|
return None
|
||||||
|
|
||||||
|
in_node = set()
|
||||||
|
out_node = set()
|
||||||
|
for u, v in self.edges_iter(data=False):
|
||||||
|
in_node.add(v)
|
||||||
|
out_node.add(u)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return list(out_node.difference(in_node))[0]
|
||||||
|
except IndexError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class DependencyTree(Tree):
|
||||||
|
def __init__(self, doc):
|
||||||
|
super(DependencyTree, self).__init__()
|
||||||
|
|
||||||
|
for token in doc:
|
||||||
|
self.nodes[token.i] = token
|
||||||
|
|
||||||
|
if token.head.i != token.i:
|
||||||
|
# inverse the dependency to have an actual tree
|
||||||
|
self.adjacency[token.head.i][token.i] = token.dep_
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
return self.nodes[item]
|
||||||
|
|
||||||
|
def match_nodes(self, attr_dict, **kwargs):
|
||||||
|
results = []
|
||||||
|
for token_idx, token in self.nodes.items():
|
||||||
|
if match_token(token, attr_dict, **kwargs):
|
||||||
|
results.append(token_idx)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def match(self, pattern):
|
||||||
|
"""Return a list of matches between the given
|
||||||
|
:class:`DependencyPattern` and `self` if any, or None.
|
||||||
|
|
||||||
|
:param pattern: a :class:`DependencyPattern`
|
||||||
|
"""
|
||||||
|
pattern_root_node = pattern.root_node
|
||||||
|
pattern_root_node_attr = pattern[pattern_root_node]
|
||||||
|
dep_root_nodes = self.match_nodes(pattern_root_node_attr)
|
||||||
|
|
||||||
|
if not dep_root_nodes:
|
||||||
|
logger.debug("No node matches the pattern root "
|
||||||
|
"'{}'".format(pattern_root_node_attr))
|
||||||
|
|
||||||
|
matches = []
|
||||||
|
for candidate_root_node in dep_root_nodes:
|
||||||
|
match_list = subtree_in_graph(candidate_root_node, self,
|
||||||
|
pattern_root_node, pattern)
|
||||||
|
for mapping in match_list:
|
||||||
|
match = PatternMatch(mapping, pattern, self)
|
||||||
|
matches.append(match)
|
||||||
|
|
||||||
|
return matches
|
||||||
|
|
||||||
|
|
||||||
|
class PatternMatch(object):
|
||||||
|
def __init__(self, mapping, pattern, tree):
|
||||||
|
for pattern_node_id, tree_node_id in mapping.items():
|
||||||
|
mapping[pattern_node_id] = tree[tree_node_id]
|
||||||
|
self.mapping = mapping
|
||||||
|
self.pattern = pattern
|
||||||
|
self.tree = tree
|
||||||
|
|
||||||
|
self.alias_map = {}
|
||||||
|
for pattern_node_id in self.mapping:
|
||||||
|
pattern_node = self.pattern[pattern_node_id]
|
||||||
|
|
||||||
|
alias = pattern_node.get('_alias')
|
||||||
|
if alias:
|
||||||
|
self.alias_map[alias] = self.mapping[pattern_node_id]
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "<Pattern Match: {} node>".format(len(self.mapping))
|
||||||
|
|
||||||
|
def __getitem__(self, item):
|
||||||
|
return self.alias_map[item]
|
||||||
|
|
||||||
|
|
||||||
|
def subtree_in_graph(dep_tree_node, dep_tree, pattern_node, pattern):
|
||||||
|
"""Return a list of matches of `pattern` as a subtree of `dep_tree`.
|
||||||
|
:param dep_tree_node: the token (identified by its index) to start from
|
||||||
|
(int)
|
||||||
|
:param dep_tree: a :class:`DependencyTree`
|
||||||
|
:param pattern_node: the pattern node to start from
|
||||||
|
:param pattern: a :class:`DependencyPattern`
|
||||||
|
:return: found matches (list)
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
association_dict = {pattern_node: dep_tree_node}
|
||||||
|
_subtree_in_graph(dep_tree_node, dep_tree, pattern_node,
|
||||||
|
pattern, results=results,
|
||||||
|
association_dict=association_dict)
|
||||||
|
results = results or []
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def _subtree_in_graph(dep_tree_node, dep_tree, pattern_node, pattern,
|
||||||
|
association_dict=None, results=None):
|
||||||
|
token = dep_tree[dep_tree_node]
|
||||||
|
logger.debug("Starting from token '{}'".format(token.orth_))
|
||||||
|
|
||||||
|
adjacent_edges = list(pattern.edges_iter(origin=pattern_node))
|
||||||
|
if adjacent_edges:
|
||||||
|
for (_, adjacent_pattern_node,
|
||||||
|
dep) in adjacent_edges:
|
||||||
|
adjacent_pattern_node_attr = pattern[adjacent_pattern_node]
|
||||||
|
logger.debug("Exploring relation {} -[{}]-> {} from "
|
||||||
|
"pattern".format(pattern[pattern_node],
|
||||||
|
dep,
|
||||||
|
adjacent_pattern_node_attr))
|
||||||
|
|
||||||
|
adjacent_nodes = find_adjacent_nodes(dep_tree,
|
||||||
|
dep_tree_node,
|
||||||
|
dep,
|
||||||
|
adjacent_pattern_node_attr)
|
||||||
|
|
||||||
|
if not adjacent_nodes:
|
||||||
|
logger.debug("No adjacent nodes in dep_tree satisfying these "
|
||||||
|
"conditions.")
|
||||||
|
return None
|
||||||
|
|
||||||
|
for adjacent_node in adjacent_nodes:
|
||||||
|
logger.debug("Found adjacent node '{}' in "
|
||||||
|
"dep_tree".format(dep_tree[adjacent_node].orth_))
|
||||||
|
association_dict[adjacent_pattern_node] = adjacent_node
|
||||||
|
recursive_return = _subtree_in_graph(adjacent_node,
|
||||||
|
dep_tree,
|
||||||
|
adjacent_pattern_node,
|
||||||
|
pattern,
|
||||||
|
association_dict,
|
||||||
|
results=results)
|
||||||
|
|
||||||
|
if recursive_return is None:
|
||||||
|
# No Match
|
||||||
|
return None
|
||||||
|
|
||||||
|
association_dict, results = recursive_return
|
||||||
|
|
||||||
|
else:
|
||||||
|
if len(association_dict) == pattern.number_of_nodes():
|
||||||
|
logger.debug("Add to results: {}".format(association_dict))
|
||||||
|
results.append(dict(association_dict))
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.debug("{} nodes in subgraph, only {} "
|
||||||
|
"mapped".format(pattern.number_of_nodes(),
|
||||||
|
len(association_dict)))
|
||||||
|
|
||||||
|
logger.debug("Return intermediate: {}".format(association_dict))
|
||||||
|
return association_dict, results
|
||||||
|
|
||||||
|
|
||||||
|
def find_adjacent_nodes(dep_tree, node, target_dep, node_attributes):
|
||||||
|
"""Find nodes adjacent to ``node`` that fulfill specified attributes
|
||||||
|
values on edge and node.
|
||||||
|
|
||||||
|
:param dep_tree: a :class:`DependencyTree`
|
||||||
|
:param node: initial node to search from
|
||||||
|
:param target_dep: edge attributes that must be fulfilled (pair-value)
|
||||||
|
:type target_dep: dict
|
||||||
|
:param node_attributes: node attributes that must be fulfilled (pair-value)
|
||||||
|
:type node_attributes: dict
|
||||||
|
:return: adjacent nodes that fulfill the given criteria (list)
|
||||||
|
"""
|
||||||
|
results = []
|
||||||
|
for _, adj_node, adj_dep in dep_tree.edges_iter(origin=node):
|
||||||
|
adj_token = dep_tree[adj_node]
|
||||||
|
if (match_edge(adj_dep, target_dep)
|
||||||
|
and match_token(adj_token, node_attributes)):
|
||||||
|
results.append(adj_node)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def match_edge(token_dep, target_dep):
|
||||||
|
if target_dep is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if hasattr(target_dep, 'match'):
|
||||||
|
return target_dep.match(token_dep) is not None
|
||||||
|
|
||||||
|
if token_dep == target_dep:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def match_token(token,
|
||||||
|
target_attributes,
|
||||||
|
ignore_special_key=True,
|
||||||
|
lower=True):
|
||||||
|
bind_map = {
|
||||||
|
'word': lambda t: t.orth_,
|
||||||
|
'lemma': lambda t: t.lemma_,
|
||||||
|
'ent': lambda t: t.ent_type_,
|
||||||
|
}
|
||||||
|
|
||||||
|
for target_key, target_value in target_attributes.items():
|
||||||
|
is_special_key = target_key[0] == '_'
|
||||||
|
|
||||||
|
if ignore_special_key and is_special_key:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if lower and hasattr(target_value, 'lower'):
|
||||||
|
target_value = target_value.lower()
|
||||||
|
|
||||||
|
if target_key in bind_map:
|
||||||
|
token_attr = bind_map[target_key](token)
|
||||||
|
|
||||||
|
if lower:
|
||||||
|
token_attr = token_attr.lower()
|
||||||
|
|
||||||
|
if hasattr(target_value, 'match'): # if it is a compiled regex
|
||||||
|
if target_value.match(token_attr) is None:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
if not token_attr == target_value:
|
||||||
|
break
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError("Unknown key: '{}'".format(target_key))
|
||||||
|
|
||||||
|
else: # the loop was not broken
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
1
spacy/tests/pattern/__init__.py
Normal file
1
spacy/tests/pattern/__init__.py
Normal file
|
@ -0,0 +1 @@
|
||||||
|
# coding: utf-8
|
76
spacy/tests/pattern/parser.py
Normal file
76
spacy/tests/pattern/parser.py
Normal file
|
@ -0,0 +1,76 @@
|
||||||
|
# coding: utf-8
|
||||||
|
|
||||||
|
|
||||||
|
import re
|
||||||
|
from ...pattern.parser import PatternParser
|
||||||
|
|
||||||
|
|
||||||
|
class TestPatternParser:
|
||||||
|
def test_empty_query(self):
|
||||||
|
assert PatternParser.parse('') is None
|
||||||
|
assert PatternParser.parse(' ') is None
|
||||||
|
|
||||||
|
def test_define_node(self):
|
||||||
|
query = "fox [lemma:fox,word:fox]=alias"
|
||||||
|
pattern = PatternParser.parse(query)
|
||||||
|
|
||||||
|
assert pattern is not None
|
||||||
|
assert pattern.number_of_nodes() == 1
|
||||||
|
assert pattern.number_of_edges() == 0
|
||||||
|
|
||||||
|
assert 'fox' in pattern.nodes
|
||||||
|
|
||||||
|
attrs = pattern['fox']
|
||||||
|
assert attrs.get('lemma') == 'fox'
|
||||||
|
assert attrs.get('word') == 'fox'
|
||||||
|
assert attrs.get('_name') == 'fox'
|
||||||
|
assert attrs.get('_alias') == 'alias'
|
||||||
|
|
||||||
|
for adj_list in pattern.adjacency.values():
|
||||||
|
assert not adj_list
|
||||||
|
|
||||||
|
def test_define_node_with_regex(self):
|
||||||
|
query = "fox [lemma:/fo.*/]"
|
||||||
|
pattern = PatternParser.parse(query)
|
||||||
|
|
||||||
|
attrs = pattern['fox']
|
||||||
|
assert attrs.get('lemma') == re.compile(r'fo.*', re.U)
|
||||||
|
|
||||||
|
def test_define_edge(self):
|
||||||
|
query = "[word:quick] >amod [word:fox]"
|
||||||
|
pattern = PatternParser.parse(query)
|
||||||
|
|
||||||
|
assert pattern is not None
|
||||||
|
assert pattern.number_of_nodes() == 2
|
||||||
|
assert pattern.number_of_edges() == 1
|
||||||
|
|
||||||
|
quick_id = [node_id for node_id, node_attr in pattern.nodes.items()
|
||||||
|
if node_attr['word'] == 'quick'][0]
|
||||||
|
|
||||||
|
fox_id = [node_id for node_id, node_attr in pattern.nodes.items()
|
||||||
|
if node_attr['word'] == 'fox'][0]
|
||||||
|
|
||||||
|
quick_map = pattern.adjacency[quick_id]
|
||||||
|
fox_map = pattern.adjacency[fox_id]
|
||||||
|
|
||||||
|
assert len(quick_map) == 0
|
||||||
|
assert len(fox_map) == 1
|
||||||
|
|
||||||
|
dep = fox_map[quick_id]
|
||||||
|
|
||||||
|
assert dep == 'amod'
|
||||||
|
|
||||||
|
def test_define_edge_with_regex(self):
|
||||||
|
query = "[word:quick] >/amod|nsubj/ [word:fox]"
|
||||||
|
pattern = PatternParser.parse(query)
|
||||||
|
|
||||||
|
quick_id = [node_id for node_id, node_attr in pattern.nodes.items()
|
||||||
|
if node_attr['word'] == 'quick'][0]
|
||||||
|
|
||||||
|
fox_id = [node_id for node_id, node_attr in pattern.nodes.items()
|
||||||
|
if node_attr['word'] == 'fox'][0]
|
||||||
|
|
||||||
|
fox_map = pattern.adjacency[fox_id]
|
||||||
|
dep = fox_map[quick_id]
|
||||||
|
|
||||||
|
assert dep == re.compile(r'amod|nsubj', re.U)
|
61
spacy/tests/pattern/pattern.py
Normal file
61
spacy/tests/pattern/pattern.py
Normal file
|
@ -0,0 +1,61 @@
|
||||||
|
# coding: utf-8
|
||||||
|
|
||||||
|
from ..util import get_doc
|
||||||
|
from ...pattern.pattern import Tree, DependencyTree
|
||||||
|
from ...pattern.parser import PatternParser
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import logging
|
||||||
|
logger = logging.getLogger()
|
||||||
|
logger.addHandler(logging.StreamHandler())
|
||||||
|
logger.setLevel(logging.DEBUG)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def doc(en_vocab):
|
||||||
|
words = ['I', "'m", 'going', 'to', 'the', 'zoo', 'next', 'week', '.']
|
||||||
|
doc = get_doc(en_vocab,
|
||||||
|
words=words,
|
||||||
|
deps=['nsubj', 'aux', 'ROOT', 'prep', 'det', 'pobj',
|
||||||
|
'amod', 'npadvmod', 'punct'],
|
||||||
|
heads=[2, 1, 0, -1, 1, -2, 1, -5, -6])
|
||||||
|
return doc
|
||||||
|
|
||||||
|
|
||||||
|
class TestTree:
|
||||||
|
def test_is_connected(self):
|
||||||
|
tree = Tree()
|
||||||
|
tree.add_node(1)
|
||||||
|
tree.add_node(2)
|
||||||
|
tree.add_edge(1, 2)
|
||||||
|
|
||||||
|
assert tree.is_connected()
|
||||||
|
|
||||||
|
tree.add_node(3)
|
||||||
|
assert not tree.is_connected()
|
||||||
|
|
||||||
|
|
||||||
|
class TestDependencyTree:
|
||||||
|
def test_from_doc(self, doc):
|
||||||
|
dep_tree = DependencyTree(doc)
|
||||||
|
|
||||||
|
assert len(dep_tree) == len(doc)
|
||||||
|
assert dep_tree.is_connected()
|
||||||
|
assert dep_tree.number_of_edges() == len(doc) - 1
|
||||||
|
|
||||||
|
def test_simple_matching(self, doc):
|
||||||
|
dep_tree = DependencyTree(doc)
|
||||||
|
pattern = PatternParser.parse("""root [word:going]
|
||||||
|
to [word:to]
|
||||||
|
[word:week]=date > root
|
||||||
|
[word:/zoo|park/]=place >pobj to
|
||||||
|
to >prep root
|
||||||
|
""")
|
||||||
|
assert pattern is not None
|
||||||
|
matches = dep_tree.match(pattern)
|
||||||
|
assert len(matches) == 1
|
||||||
|
|
||||||
|
match = matches[0]
|
||||||
|
assert match['place'] == doc[5]
|
||||||
|
assert match['date'] == doc[7]
|
Loading…
Reference in New Issue
Block a user