diff --git a/spacy/compat.py b/spacy/compat.py
index 1ca8a59fe..8d962976b 100644
--- a/spacy/compat.py
+++ b/spacy/compat.py
@@ -16,6 +16,10 @@ try:
 except ImportError:
     import copyreg as copy_reg
 
+try:
+    import Queue as queue
+except ImportError:
+    import queue
 
 is_python2 = six.PY2
 is_python3 = six.PY3
@@ -32,6 +36,7 @@ if is_python2:
     basestring_ = basestring
     input_ = raw_input
     json_dumps = lambda data: ujson.dumps(data, indent=2).decode('utf8')
+    intern = intern
 
 elif is_python3:
     bytes_ = bytes
@@ -39,6 +44,7 @@ elif is_python3:
     basestring_ = str
     input_ = input
     json_dumps = lambda data: ujson.dumps(data, indent=2)
+    intern = sys.intern
 
 
 def symlink_to(orig, dest):
diff --git a/spacy/pattern/__init__.py b/spacy/pattern/__init__.py
new file mode 100644
index 000000000..325ba04ea
--- /dev/null
+++ b/spacy/pattern/__init__.py
@@ -0,0 +1,4 @@
+# coding: utf-8
+
+from .pattern import DependencyTree
+from .parser import PatternParser
diff --git a/spacy/pattern/parser.py b/spacy/pattern/parser.py
new file mode 100644
index 000000000..122d2b8f3
--- /dev/null
+++ b/spacy/pattern/parser.py
@@ -0,0 +1,377 @@
+# coding: utf-8
+
+from collections import deque
+from spacy.compat import intern
+from spacy.strings import hash_string
+from operator import itemgetter
+import re
+import json
+
+from .pattern import DependencyPattern
+
+TOKEN_INITIAL = intern('initial')
+
+
+class PatternParser(object):
+    """Compile a pattern query into a :class:`DependencyPattern` that can
+    be used to match a :class:`DependencyTree`."""
+    whitespace_re = re.compile(r'\s+', re.U)
+    newline_re = re.compile(r'(\r\n|\r|\n)')
+    name_re = re.compile(r'\w+', re.U)
+
+    TOKEN_BLOCK_BEGIN = '['
+    TOKEN_BLOCK_END = ']'
+    EDGE_BLOCK_BEGIN = '>'
+    WHITESPACE = ' '
+
+    @classmethod
+    def parse(cls, query):
+        """Parse the given `query` and compile it into a
+        :class:`DependencyPattern`. Return None if the query is empty."""
+        pattern = DependencyPattern()
+
+        for lineno, token_stream in enumerate(cls.tokenize(query)):
+            try:
+                cls._parse_line(token_stream, pattern, lineno+1)
+            except StopIteration:
+                raise SyntaxError("line %d: unexpected end of line, a token "
+                                  "is missing." % (lineno+1))
+
+        if not pattern.nodes:
+            return
+
+        cls.check_pattern(pattern)
+        return pattern
+
+    @staticmethod
+    def check_pattern(pattern):
+        if not pattern.is_connected():
+            raise ValueError("The pattern must form a connected tree.")
+
+        if pattern.root_node is None:
+            raise ValueError("The root node of the tree could not be found.")
+
+    @classmethod
+    def _parse_line(cls, stream, pattern, lineno):
+        while not stream.closed:
+            token = stream.current
+
+            if token.type == 'name':
+                next_token = stream.look()
+
+                if next_token.type == 'node':
+                    cls.parse_node_def(stream, pattern)
+
+                elif next_token.type == 'edge':
+                    cls.parse_edge_def(stream, pattern)
+
+                else:
+                    raise SyntaxError("line %d: a 'node' or 'edge' token must "
+                                      "follow a 'name' token." % lineno)
+
+            elif token.type == 'node':
+                next_token = stream.look()
+
+                if next_token.type == 'edge':
+                    cls.parse_edge_def(stream, pattern)
+                else:
+                    raise SyntaxError("line %d: an 'edge' token is "
+                                      "expected." % lineno)
+
+            if not stream.closed:
+                next(stream)
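+
+    # Informal sketch of the accepted query grammar, inferred from the
+    # tokenizer and the parsing methods below (not an authoritative
+    # specification):
+    #
+    #   def_line  := name ' ' node               e.g. "fox [word:fox]"
+    #   edge_line := (name | node) ' ' edge ' ' (name | node)
+    #                                            e.g. "[word:quick] >amod fox"
+    #   node      := '[' attrs ']' ('=' alias)?
+    #   edge      := '>' (label | '/' regex '/')?
+    #
+    # Edges are written dependent-first: "X >dep Y" declares X as a
+    # dependent of Y. parse_edge_def() inverts the relation so that the
+    # stored pattern tree points from head to dependent.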
+
+    @classmethod
+    def parse_node_def(cls, stream, pattern):
+        name_token = stream.current
+        next(stream)
+        node_token = stream.current
+        cls.add_node(node_token, pattern, name_token)
+
+    @classmethod
+    def add_node(cls, node_token, pattern, name_token=None):
+        token_name = None
+        if name_token is not None:
+            token_id = name_token.value
+            token_name = name_token.value
+        else:
+            token_id = node_token.hash()
+
+        if token_id in pattern.nodes:
+            raise SyntaxError("Token with ID '{}' already registered.".format(
+                token_id))
+
+        token_attr = cls.parse_node_attributes(node_token.value)
+        token_attr['_name'] = token_name
+        pattern.add_node(token_id, token_attr)
+
+    @classmethod
+    def parse_edge_def(cls, stream, pattern):
+        token = stream.current
+
+        if token.type == 'name':
+            token_id = token.value
+            if token_id not in pattern.nodes:
+                raise SyntaxError("Token '{}' with ID '{}' is not "
+                                  "defined.".format(token, token_id))
+
+        elif token.type == 'node':
+            token_id = token.hash()
+            cls.add_node(token, pattern)
+
+        next(stream)
+        edge_attr = cls.parse_edge_attributes(stream.current.value)
+        next(stream)
+
+        head_token = stream.current
+        if head_token.type == 'name':
+            head_token_id = head_token.value
+            if head_token_id not in pattern.nodes:
+                raise SyntaxError("Token '{}' with ID '{}' is not "
+                                  "defined.".format(head_token, head_token_id))
+        elif head_token.type == 'node':
+            head_token_id = head_token.hash()
+            cls.add_node(head_token, pattern)
+        else:
+            raise SyntaxError("A 'node' or 'name' token was expected.")
+
+        # invert the dependency so that the tree points from head to
+        # dependent
+        pattern.add_edge(head_token_id, token_id, edge_attr)
+
+    @classmethod
+    def parse_node_attributes(cls, string):
+        string = string[1:]  # remove the leading '['
+        end_delimiter_idx = string.find(']')
+
+        attr_str = string[:end_delimiter_idx]
+        attr = {}
+
+        try:
+            attr = json.loads(attr_str)
+        except ValueError:
+            # not valid JSON: fall back to the 'key:value,...' shorthand
+            for pair in attr_str.split(","):
+                key, value = pair.split(':', 1)
+                attr[key] = value
+
+        for key, value in attr.items():
+            attr[key] = cls.compile_expression(value)
+
+        alias = string[end_delimiter_idx+2:]
+
+        if alias:
+            attr['_alias'] = alias
+
+        return attr
+
+    @classmethod
+    def parse_edge_attributes(cls, string):
+        string = string[1:]  # remove the leading '>'
+
+        if not string:
+            return None
+
+        return cls.compile_expression(string)
+
+    @staticmethod
+    def compile_expression(expr):
+        if expr.startswith('/') and expr.endswith('/'):
+            string = expr[1:-1]
+            return re.compile(string, re.U)
+
+        return expr
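+
+    # Illustrative results of the attribute parsing above (inputs mirror
+    # the unit tests; the dicts shown are what the current code returns):
+    #
+    #   parse_node_attributes('[lemma:fox,word:fox]=alias')
+    #       -> {'lemma': 'fox', 'word': 'fox', '_alias': 'alias'}
+    #   parse_node_attributes('[lemma:/fo.*/]')
+    #       -> {'lemma': re.compile('fo.*', re.U)}
+    #   parse_edge_attributes('>amod')  -> 'amod'
+    #   parse_edge_attributes('>')      -> None (matches any label)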
+
+    @classmethod
+    def tokenize(cls, text):
+        lines = text.splitlines()
+
+        for lineno, line in enumerate(lines):
+            yield TokenStream(cls._tokenize_line(line, lineno+1))
+
+    @classmethod
+    def _tokenize_line(cls, line, lineno):
+        reader = Reader(line)
+
+        while reader.remaining():
+            char = reader.peek()
+
+            if char == cls.TOKEN_BLOCK_BEGIN:
+                token = 'node'
+                idx = reader.find(cls.TOKEN_BLOCK_END)
+
+                if idx == -1:
+                    raise SyntaxError("A token block end ']' was expected.")
+
+                idx += 1
+                if len(reader) > idx and reader[idx] == '=':
+                    # the node carries an alias appended after '='
+                    idx = reader.find(cls.WHITESPACE, start=idx)
+
+                    if idx == -1:
+                        idx = reader.remaining()
+
+            elif char == cls.EDGE_BLOCK_BEGIN:
+                token = 'edge'
+                idx = reader.find(cls.WHITESPACE)
+
+            elif cls.name_re.match(char):
+                token = 'name'
+                idx = reader.find(cls.WHITESPACE)
+
+                if idx == -1:
+                    whole_name_match = cls.name_re.match(str(reader))
+                    idx = whole_name_match.end()
+
+            elif cls.newline_re.match(char) or cls.whitespace_re.match(char):
+                # skip the whitespace
+                reader.consume()
+                continue
+
+            else:
+                raise SyntaxError("Unrecognized token begin character: "
+                                  "'{}'".format(char))
+
+            if idx == -1:
+                raise SyntaxError("Ending character of token '{}' not "
+                                  "found.".format(token))
+            value = reader.consume(idx)
+
+            yield Token(lineno, token, value)
+
+
+class Reader(object):
+    """A cursor over a line of text, used by :class:`PatternParser` to
+    tokenize it."""
+    __slots__ = ('text', 'pos')
+
+    def __init__(self, text):
+        self.text = text
+        self.pos = 0
+
+    def find(self, needle, start=0, end=None):
+        """Return the index of `needle` relative to the current position,
+        or -1 if it cannot be found."""
+        pos = self.pos
+        start += pos
+        if end is None:
+            index = self.text.find(needle, start)
+        else:
+            end += pos
+            index = self.text.find(needle, start, end)
+        if index != -1:
+            index -= pos
+        return index
+
+    def consume(self, count=1):
+        """Advance the position by `count` characters and return them."""
+        new_pos = self.pos + count
+        s = self.text[self.pos:new_pos]
+        self.pos = new_pos
+        return s
+
+    def peek(self):
+        """Return the next character without consuming it."""
+        return self.text[self.pos:self.pos + 1]
+
+    def remaining(self):
+        return len(self.text) - self.pos
+
+    def __len__(self):
+        return self.remaining()
+
+    def __getitem__(self, key):
+        if key < 0:
+            return self.text[key]
+        else:
+            return self.text[self.pos + key]
+
+    def __str__(self):
+        return self.text[self.pos:]
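+
+
+# A rough usage sketch of Reader (illustrative values):
+#
+#     reader = Reader('>amod [word:fox]')
+#     reader.find(' ')     # -> 5, relative to the current position
+#     reader.consume(5)    # -> '>amod', advances the cursor
+#     str(reader)          # -> ' [word:fox]'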
+ """ + + def __init__(self, generator): + self._iter = iter(generator) + self._pushed = queue.deque() + self.closed = False + self.current = Token(1, TOKEN_INITIAL, '') + next(self) + + def __iter__(self): + return TokenStreamIterator(self) + + def __bool__(self): + return bool(self._pushed) + __nonzero__ = __bool__ # py2 + + def push(self, token): + """Push a token back to the stream.""" + self._pushed.append(token) + + def look(self): + """Look at the next token.""" + old_token = next(self) + result = self.current + self.push(result) + self.current = old_token + return result + + def __next__(self): + """Go one token ahead and return the old one.""" + rv = self.current + if self._pushed: + self.current = self._pushed.popleft() + else: + if self.closed: + raise StopIteration("No token left.") + try: + self.current = next(self._iter) + except StopIteration: + self.close() + return rv + + def close(self): + """Close the stream.""" + self._iter = None + self.closed = True diff --git a/spacy/pattern/pattern.py b/spacy/pattern/pattern.py new file mode 100644 index 000000000..552283066 --- /dev/null +++ b/spacy/pattern/pattern.py @@ -0,0 +1,318 @@ +# coding: utf-8 + +import logging +from collections import defaultdict + + +logger = logging.getLogger(__name__) + + +class Tree(object): + def __init__(self): + self.adjacency = defaultdict(dict) + self.nodes = {} + + def __getitem__(self, item): + return self.nodes[item] + + def add_node(self, node, attr_dict=None): + attr_dict = attr_dict or {} + self.nodes[node] = attr_dict + + def add_edge(self, u, v, dep=None): + if u not in self.nodes or v not in self.nodes: + raise ValueError("Each node must be defined before adding an edge.") + + self.adjacency[u][v] = dep + + def number_of_nodes(self): + return len(self) + + def __len__(self): + return len(self.nodes) + + def number_of_edges(self): + return sum(len(adj_dict) for adj_dict in self.adjacency.values()) + + def edges_iter(self, origin=None, data=True): + nbunch = (self.adjacency.items() if origin is None + else [(origin, self.adjacency[origin])]) + + for u, nodes in nbunch: + for v, dep in nodes.items(): + if data: + yield (u, v, dep) + else: + yield (u, v) + + def nodes_iter(self): + for node in self.nodes.keys(): + yield node + + def is_connected(self): + if len(self) == 0: + raise ValueError('Connectivity is undefined for the null graph.') + return len(set(self._plain_bfs(next(self.nodes_iter()), + undirected=True))) == len(self) + + def _plain_bfs(self, source, undirected=False): + """A fast BFS node generator. 
+
+    def is_connected(self):
+        if len(self) == 0:
+            raise ValueError('Connectivity is undefined for the null graph.')
+        return len(set(self._plain_bfs(next(self.nodes_iter()),
+                                       undirected=True))) == len(self)
+
+    def _plain_bfs(self, source, undirected=False):
+        """A fast BFS node generator.
+
+        :param source: the source node
+        """
+        seen = set()
+        next_level = {source}
+        while next_level:
+            this_level = next_level
+            next_level = set()
+            for v in this_level:
+                if v not in seen:
+                    yield v
+                    seen.add(v)
+                    next_level.update(self.adjacency[v].keys())
+
+                    if undirected:
+                        for n, adj in self.adjacency.items():
+                            if v in adj.keys():
+                                next_level.add(n)
+
+
+class DependencyPattern(Tree):
+    @property
+    def root_node(self):
+        """Return the node that has no incoming edge, or None."""
+        if self.number_of_nodes() == 1:
+            # if the graph has a single node, it is the root
+            return next(iter(self.nodes.keys()))
+
+        if not self.is_connected():
+            return None
+
+        in_node = set()
+        out_node = set()
+        for u, v in self.edges_iter(data=False):
+            in_node.add(v)
+            out_node.add(u)
+
+        try:
+            return list(out_node.difference(in_node))[0]
+        except IndexError:
+            return None
+
+
+class DependencyTree(Tree):
+    def __init__(self, doc):
+        super(DependencyTree, self).__init__()
+
+        for token in doc:
+            self.nodes[token.i] = token
+
+            if token.head.i != token.i:
+                # point the edge from head to dependent to obtain an
+                # actual tree
+                self.adjacency[token.head.i][token.i] = token.dep_
+
+    def match_nodes(self, attr_dict, **kwargs):
+        results = []
+        for token_idx, token in self.nodes.items():
+            if match_token(token, attr_dict, **kwargs):
+                results.append(token_idx)
+
+        return results
+
+    def match(self, pattern):
+        """Return the list of matches between the given
+        :class:`DependencyPattern` and `self` (an empty list if there is
+        none).
+
+        :param pattern: a :class:`DependencyPattern`
+        """
+        pattern_root_node = pattern.root_node
+        pattern_root_node_attr = pattern[pattern_root_node]
+        dep_root_nodes = self.match_nodes(pattern_root_node_attr)
+
+        if not dep_root_nodes:
+            logger.debug("No node matches the pattern root "
+                         "'{}'".format(pattern_root_node_attr))
+
+        matches = []
+        for candidate_root_node in dep_root_nodes:
+            match_list = subtree_in_graph(candidate_root_node, self,
+                                          pattern_root_node, pattern)
+            for mapping in match_list:
+                match = PatternMatch(mapping, pattern, self)
+                matches.append(match)
+
+        return matches
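+
+
+# Minimal usage sketch (assumes a loaded `nlp` pipeline; the query and the
+# variable names are illustrative):
+#
+#     doc = nlp(u'The quick brown fox jumps over the lazy dog.')
+#     tree = DependencyTree(doc)
+#     pattern = PatternParser.parse(u'[word:quick] >amod [word:fox]=animal')
+#     for match in tree.match(pattern):
+#         print(match['animal'])  # the 'fox' token, looked up by its alias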
+
+
+class PatternMatch(object):
+    def __init__(self, mapping, pattern, tree):
+        # replace the matched node indices with the actual tokens
+        for pattern_node_id, tree_node_id in mapping.items():
+            mapping[pattern_node_id] = tree[tree_node_id]
+        self.mapping = mapping
+        self.pattern = pattern
+        self.tree = tree
+
+        self.alias_map = {}
+        for pattern_node_id in self.mapping:
+            pattern_node = self.pattern[pattern_node_id]
+
+            alias = pattern_node.get('_alias')
+            if alias:
+                self.alias_map[alias] = self.mapping[pattern_node_id]
+
+    def __repr__(self):
+        return "<PatternMatch: {} tokens>".format(len(self.mapping))
+
+    def __getitem__(self, item):
+        return self.alias_map[item]
+
+
+def subtree_in_graph(dep_tree_node, dep_tree, pattern_node, pattern):
+    """Return a list of matches of `pattern` as a subtree of `dep_tree`.
+
+    :param dep_tree_node: the token (identified by its index) to start from
+        (int)
+    :param dep_tree: a :class:`DependencyTree`
+    :param pattern_node: the pattern node to start from
+    :param pattern: a :class:`DependencyPattern`
+    :return: found matches (list)
+    """
+    results = []
+    association_dict = {pattern_node: dep_tree_node}
+    _subtree_in_graph(dep_tree_node, dep_tree, pattern_node,
+                      pattern, results=results,
+                      association_dict=association_dict)
+    return results
+
+
+def _subtree_in_graph(dep_tree_node, dep_tree, pattern_node, pattern,
+                      association_dict=None, results=None):
+    token = dep_tree[dep_tree_node]
+    logger.debug("Starting from token '{}'".format(token.orth_))
+
+    adjacent_edges = list(pattern.edges_iter(origin=pattern_node))
+    if adjacent_edges:
+        for (_, adjacent_pattern_node,
+             dep) in adjacent_edges:
+            adjacent_pattern_node_attr = pattern[adjacent_pattern_node]
+            logger.debug("Exploring relation {} -[{}]-> {} from "
+                         "pattern".format(pattern[pattern_node],
+                                          dep,
+                                          adjacent_pattern_node_attr))
+
+            adjacent_nodes = find_adjacent_nodes(dep_tree,
+                                                 dep_tree_node,
+                                                 dep,
+                                                 adjacent_pattern_node_attr)
+
+            if not adjacent_nodes:
+                logger.debug("No adjacent nodes in dep_tree satisfying these "
+                             "conditions.")
+                return None
+
+            for adjacent_node in adjacent_nodes:
+                logger.debug("Found adjacent node '{}' in "
+                             "dep_tree".format(dep_tree[adjacent_node].orth_))
+                association_dict[adjacent_pattern_node] = adjacent_node
+                recursive_return = _subtree_in_graph(adjacent_node,
+                                                     dep_tree,
+                                                     adjacent_pattern_node,
+                                                     pattern,
+                                                     association_dict,
+                                                     results=results)
+
+                if recursive_return is None:
+                    # no match in this branch
+                    return None
+
+                association_dict, results = recursive_return
+
+    else:
+        if len(association_dict) == pattern.number_of_nodes():
+            logger.debug("Add to results: {}".format(association_dict))
+            results.append(dict(association_dict))
+
+        else:
+            logger.debug("{} nodes in subgraph, only {} "
+                         "mapped".format(pattern.number_of_nodes(),
+                                         len(association_dict)))
+
+    logger.debug("Return intermediate: {}".format(association_dict))
+    return association_dict, results
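+
+
+# Walkthrough (illustrative): matching "[word:quick] >amod [word:fox]"
+# starts from every token that matches the pattern root (the fox node,
+# which has no incoming pattern edge). For each candidate, the recursion
+# follows the 'amod' pattern edge; if no adjacent token satisfies both the
+# edge and the node constraints, the candidate root is abandoned, and a
+# mapping is recorded in `results` only once every pattern node is bound.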
+
+
+def find_adjacent_nodes(dep_tree, node, target_dep, node_attributes):
+    """Find nodes adjacent to `node` that fulfill the given constraints on
+    the connecting edge and on the node itself.
+
+    :param dep_tree: a :class:`DependencyTree`
+    :param node: the node to search from
+    :param target_dep: the dependency label the edge must carry (a string,
+        a compiled regex, or None for no constraint)
+    :param node_attributes: node attributes that must be fulfilled (dict)
+    :return: adjacent nodes that fulfill the given criteria (list)
+    """
+    results = []
+    for _, adj_node, adj_dep in dep_tree.edges_iter(origin=node):
+        adj_token = dep_tree[adj_node]
+        if (match_edge(adj_dep, target_dep)
+                and match_token(adj_token, node_attributes)):
+            results.append(adj_node)
+
+    return results
+
+
+def match_edge(token_dep, target_dep):
+    if target_dep is None:
+        return True
+
+    if hasattr(target_dep, 'match'):  # a compiled regex
+        return target_dep.match(token_dep) is not None
+
+    return token_dep == target_dep
+
+
+def match_token(token,
+                target_attributes,
+                ignore_special_key=True,
+                lower=True):
+    bind_map = {
+        'word': lambda t: t.orth_,
+        'lemma': lambda t: t.lemma_,
+        'ent': lambda t: t.ent_type_,
+    }
+
+    for target_key, target_value in target_attributes.items():
+        is_special_key = target_key.startswith('_')
+
+        if ignore_special_key and is_special_key:
+            continue
+
+        if lower and hasattr(target_value, 'lower'):
+            target_value = target_value.lower()
+
+        if target_key in bind_map:
+            token_attr = bind_map[target_key](token)
+
+            if lower:
+                token_attr = token_attr.lower()
+
+            if hasattr(target_value, 'match'):  # a compiled regex
+                if target_value.match(token_attr) is None:
+                    break
+            else:
+                if token_attr != target_value:
+                    break
+
+        else:
+            raise ValueError("Unknown key: '{}'".format(target_key))
+
+    else:  # the loop completed without a break: all attributes matched
+        return True
+
+    return False
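+
+
+# Illustrative behaviour of the edge predicate (hypothetical labels):
+#
+#   match_edge('amod', None)                      -> True  (no constraint)
+#   match_edge('amod', 'amod')                    -> True
+#   match_edge('amod', re.compile('amod|nsubj'))  -> True
+#   match_edge('amod', 'nsubj')                   -> False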
diff --git a/spacy/tests/pattern/__init__.py b/spacy/tests/pattern/__init__.py
new file mode 100644
index 000000000..57d631c3f
--- /dev/null
+++ b/spacy/tests/pattern/__init__.py
@@ -0,0 +1 @@
+# coding: utf-8
diff --git a/spacy/tests/pattern/test_parser.py b/spacy/tests/pattern/test_parser.py
new file mode 100644
index 000000000..50dd3ac60
--- /dev/null
+++ b/spacy/tests/pattern/test_parser.py
@@ -0,0 +1,76 @@
+# coding: utf-8
+
+
+import re
+from ...pattern.parser import PatternParser
+
+
+class TestPatternParser:
+    def test_empty_query(self):
+        assert PatternParser.parse('') is None
+        assert PatternParser.parse(' ') is None
+
+    def test_define_node(self):
+        query = "fox [lemma:fox,word:fox]=alias"
+        pattern = PatternParser.parse(query)
+
+        assert pattern is not None
+        assert pattern.number_of_nodes() == 1
+        assert pattern.number_of_edges() == 0
+
+        assert 'fox' in pattern.nodes
+
+        attrs = pattern['fox']
+        assert attrs.get('lemma') == 'fox'
+        assert attrs.get('word') == 'fox'
+        assert attrs.get('_name') == 'fox'
+        assert attrs.get('_alias') == 'alias'
+
+        for adj_list in pattern.adjacency.values():
+            assert not adj_list
+
+    def test_define_node_with_regex(self):
+        query = "fox [lemma:/fo.*/]"
+        pattern = PatternParser.parse(query)
+
+        attrs = pattern['fox']
+        assert attrs.get('lemma') == re.compile(r'fo.*', re.U)
+
+    def test_define_edge(self):
+        query = "[word:quick] >amod [word:fox]"
+        pattern = PatternParser.parse(query)
+
+        assert pattern is not None
+        assert pattern.number_of_nodes() == 2
+        assert pattern.number_of_edges() == 1
+
+        quick_id = [node_id for node_id, node_attr in pattern.nodes.items()
+                    if node_attr['word'] == 'quick'][0]
+
+        fox_id = [node_id for node_id, node_attr in pattern.nodes.items()
+                  if node_attr['word'] == 'fox'][0]
+
+        quick_map = pattern.adjacency[quick_id]
+        fox_map = pattern.adjacency[fox_id]
+
+        assert len(quick_map) == 0
+        assert len(fox_map) == 1
+
+        dep = fox_map[quick_id]
+
+        assert dep == 'amod'
+
+    def test_define_edge_with_regex(self):
+        query = "[word:quick] >/amod|nsubj/ [word:fox]"
+        pattern = PatternParser.parse(query)
+
+        quick_id = [node_id for node_id, node_attr in pattern.nodes.items()
+                    if node_attr['word'] == 'quick'][0]
+
+        fox_id = [node_id for node_id, node_attr in pattern.nodes.items()
+                  if node_attr['word'] == 'fox'][0]
+
+        fox_map = pattern.adjacency[fox_id]
+        dep = fox_map[quick_id]
+
+        assert dep == re.compile(r'amod|nsubj', re.U)
diff --git a/spacy/tests/pattern/test_pattern.py b/spacy/tests/pattern/test_pattern.py
new file mode 100644
index 000000000..a476f92f7
--- /dev/null
+++ b/spacy/tests/pattern/test_pattern.py
@@ -0,0 +1,61 @@
+# coding: utf-8
+
+from ..util import get_doc
+from ...pattern.pattern import Tree, DependencyTree
+from ...pattern.parser import PatternParser
+
+import pytest
+
+import logging
+logger = logging.getLogger()
+logger.addHandler(logging.StreamHandler())
+logger.setLevel(logging.DEBUG)
+
+
+@pytest.fixture
+def doc(en_vocab):
+    words = ['I', "'m", 'going', 'to', 'the', 'zoo', 'next', 'week', '.']
+    doc = get_doc(en_vocab,
+                  words=words,
+                  deps=['nsubj', 'aux', 'ROOT', 'prep', 'det', 'pobj',
+                        'amod', 'npadvmod', 'punct'],
+                  heads=[2, 1, 0, -1, 1, -2, 1, -5, -6])
+    return doc
+
+
+class TestTree:
+    def test_is_connected(self):
+        tree = Tree()
+        tree.add_node(1)
+        tree.add_node(2)
+        tree.add_edge(1, 2)
+
+        assert tree.is_connected()
+
+        tree.add_node(3)
+        assert not tree.is_connected()
+
+
+class TestDependencyTree:
+    def test_from_doc(self, doc):
+        dep_tree = DependencyTree(doc)
+
+        assert len(dep_tree) == len(doc)
+        assert dep_tree.is_connected()
+        assert dep_tree.number_of_edges() == len(doc) - 1
+
+    def test_simple_matching(self, doc):
+        dep_tree = DependencyTree(doc)
+        pattern = PatternParser.parse("""root [word:going]
+                                         to [word:to]
+                                         [word:week]=date > root
+                                         [word:/zoo|park/]=place >pobj to
+                                         to >prep root
+                                      """)
+        assert pattern is not None
+        matches = dep_tree.match(pattern)
+        assert len(matches) == 1
+
+        match = matches[0]
+        assert match['place'] == doc[5]
+        assert match['date'] == doc[7]