# cython: infer_types=True # cython: profile=True from __future__ import unicode_literals from cymem.cymem cimport Pool from preshed.maps cimport PreshMap from .matcher cimport Matcher from ..vocab cimport Vocab from ..tokens.doc cimport Doc from .matcher import unpickle_matcher from ..errors import Errors DELIMITER = '||' INDEX_HEAD = 1 INDEX_RELOP = 0 cdef class DependencyTreeMatcher: """Match dependency parse tree based on pattern rules.""" cdef Pool mem cdef readonly Vocab vocab cdef readonly Matcher token_matcher cdef public object _patterns cdef public object _keys_to_token cdef public object _root cdef public object _entities cdef public object _callbacks cdef public object _nodes cdef public object _tree def __init__(self, vocab): """Create the DependencyTreeMatcher. vocab (Vocab): The vocabulary object, which must be shared with the documents the matcher will operate on. RETURNS (DependencyTreeMatcher): The newly constructed object. """ size = 20 self.token_matcher = Matcher(vocab) self._keys_to_token = {} self._patterns = {} self._root = {} self._nodes = {} self._tree = {} self._entities = {} self._callbacks = {} self.vocab = vocab self.mem = Pool() def __reduce__(self): data = (self.vocab, self._patterns,self._tree, self._callbacks) return (unpickle_matcher, data, None, None) def __len__(self): """Get the number of rules, which are edges ,added to the dependency tree matcher. RETURNS (int): The number of rules. """ return len(self._patterns) def __contains__(self, key): """Check whether the matcher contains rules for a match ID. key (unicode): The match ID. RETURNS (bool): Whether the matcher contains rules for this match ID. """ return self._normalize_key(key) in self._patterns def validateInput(self, pattern, key): idx = 0 visitedNodes = {} for relation in pattern: if 'PATTERN' not in relation or 'SPEC' not in relation: raise ValueError(Errors.E098.format(key=key)) if idx == 0: if not('NODE_NAME' in relation['SPEC'] and 'NBOR_RELOP' not in relation['SPEC'] and 'NBOR_NAME' not in relation['SPEC']): raise ValueError(Errors.E099.format(key=key)) visitedNodes[relation['SPEC']['NODE_NAME']] = True else: if not('NODE_NAME' in relation['SPEC'] and 'NBOR_RELOP' in relation['SPEC'] and 'NBOR_NAME' in relation['SPEC']): raise ValueError(Errors.E100.format(key=key)) if relation['SPEC']['NODE_NAME'] in visitedNodes or relation['SPEC']['NBOR_NAME'] not in visitedNodes: raise ValueError(Errors.E101.format(key=key)) visitedNodes[relation['SPEC']['NODE_NAME']] = True visitedNodes[relation['SPEC']['NBOR_NAME']] = True idx = idx + 1 def add(self, key, on_match, *patterns): for pattern in patterns: if len(pattern) == 0: raise ValueError(Errors.E012.format(key=key)) self.validateInput(pattern,key) key = self._normalize_key(key) _patterns = [] for pattern in patterns: token_patterns = [] for i in range(len(pattern)): token_pattern = [pattern[i]['PATTERN']] token_patterns.append(token_pattern) # self.patterns.append(token_patterns) _patterns.append(token_patterns) self._patterns.setdefault(key, []) self._callbacks[key] = on_match self._patterns[key].extend(_patterns) # Add each node pattern of all the input patterns individually to the matcher. # This enables only a single instance of Matcher to be used. # Multiple adds are required to track each node pattern. _keys_to_token_list = [] for i in range(len(_patterns)): _keys_to_token = {} # TODO : Better ways to hash edges in pattern? for j in range(len(_patterns[i])): k = self._normalize_key(unicode(key)+DELIMITER+unicode(i)+DELIMITER+unicode(j)) self.token_matcher.add(k,None,_patterns[i][j]) _keys_to_token[k] = j _keys_to_token_list.append(_keys_to_token) self._keys_to_token.setdefault(key, []) self._keys_to_token[key].extend(_keys_to_token_list) _nodes_list = [] for pattern in patterns: nodes = {} for i in range(len(pattern)): nodes[pattern[i]['SPEC']['NODE_NAME']]=i _nodes_list.append(nodes) self._nodes.setdefault(key, []) self._nodes[key].extend(_nodes_list) # Create an object tree to traverse later on. # This datastructure enable easy tree pattern match. # Doc-Token based tree cannot be reused since it is memory heavy and # tightly coupled with doc self.retrieve_tree(patterns,_nodes_list,key) def retrieve_tree(self,patterns,_nodes_list,key): _heads_list = [] _root_list = [] for i in range(len(patterns)): heads = {} root = -1 for j in range(len(patterns[i])): token_pattern = patterns[i][j] if('NBOR_RELOP' not in token_pattern['SPEC']): heads[j] = ('root',j) root = j else: heads[j] = (token_pattern['SPEC']['NBOR_RELOP'],_nodes_list[i][token_pattern['SPEC']['NBOR_NAME']]) _heads_list.append(heads) _root_list.append(root) _tree_list = [] for i in range(len(patterns)): tree = {} for j in range(len(patterns[i])): if(_heads_list[i][j][INDEX_HEAD] == j): continue head = _heads_list[i][j][INDEX_HEAD] if(head not in tree): tree[head] = [] tree[head].append( (_heads_list[i][j][INDEX_RELOP],j) ) _tree_list.append(tree) self._tree.setdefault(key, []) self._tree[key].extend(_tree_list) self._root.setdefault(key, []) self._root[key].extend(_root_list) def has_key(self, key): """Check whether the matcher has a rule with a given key. key (string or int): The key to check. RETURNS (bool): Whether the matcher has the rule. """ key = self._normalize_key(key) return key in self._patterns def get(self, key, default=None): """Retrieve the pattern stored for a key. key (unicode or int): The key to retrieve. RETURNS (tuple): The rule, as an (on_match, patterns) tuple. """ key = self._normalize_key(key) if key not in self._patterns: return default return (self._callbacks[key], self._patterns[key]) def __call__(self, Doc doc): matched_trees = [] matches = self.token_matcher(doc) for key in list(self._patterns.keys()): _patterns_list = self._patterns[key] _keys_to_token_list = self._keys_to_token[key] _root_list = self._root[key] _tree_list = self._tree[key] _nodes_list = self._nodes[key] length = len(_patterns_list) for i in range(length): _keys_to_token = _keys_to_token_list[i] _root = _root_list[i] _tree = _tree_list[i] _nodes = _nodes_list[i] id_to_position = {} for i in range(len(_nodes)): id_to_position[i]=[] # This could be taken outside to improve running time..? for match_id, start, end in matches: if match_id in _keys_to_token: id_to_position[_keys_to_token[match_id]].append(start) _node_operator_map = self.get_node_operator_map(doc,_tree,id_to_position,_nodes,_root) length = len(_nodes) if _root in id_to_position: candidates = id_to_position[_root] for candidate in candidates: isVisited = {} self.dfs(candidate,_root,_tree,id_to_position,doc,isVisited,_node_operator_map) # To check if the subtree pattern is completely identified. This is a heuristic. # This is done to reduce the complexity of exponential unordered subtree matching. # Will give approximate matches in some cases. if(len(isVisited) == length): matched_trees.append((key,list(isVisited))) for i, (ent_id, nodes) in enumerate(matched_trees): on_match = self._callbacks.get(ent_id) if on_match is not None: on_match(self, doc, i, matches) return matched_trees def dfs(self,candidate,root,tree,id_to_position,doc,isVisited,_node_operator_map): if(root in id_to_position and candidate in id_to_position[root]): # color the node since it is valid isVisited[candidate] = True if root in tree: for root_child in tree[root]: if candidate in _node_operator_map and root_child[INDEX_RELOP] in _node_operator_map[candidate]: candidate_children = _node_operator_map[candidate][root_child[INDEX_RELOP]] for candidate_child in candidate_children: result = self.dfs( candidate_child.i, root_child[INDEX_HEAD], tree, id_to_position, doc, isVisited, _node_operator_map ) # Given a node and an edge operator, to return the list of nodes # from the doc that belong to node+operator. This is used to store # all the results beforehand to prevent unnecessary computation while # pattern matching # _node_operator_map[node][operator] = [...] def get_node_operator_map(self,doc,tree,id_to_position,nodes,root): _node_operator_map = {} all_node_indices = nodes.values() all_operators = [] for node in all_node_indices: if node in tree: for child in tree[node]: all_operators.append(child[INDEX_RELOP]) all_operators = list(set(all_operators)) all_nodes = [] for node in all_node_indices: all_nodes = all_nodes + id_to_position[node] all_nodes = list(set(all_nodes)) for node in all_nodes: _node_operator_map[node] = {} for operator in all_operators: _node_operator_map[node][operator] = [] # Used to invoke methods for each operator switcher = { '<':self.dep, '>':self.gov, '>>':self.dep_chain, '<<':self.gov_chain, '.':self.imm_precede, '$+':self.imm_right_sib, '$-':self.imm_left_sib, '$++':self.right_sib, '$--':self.left_sib } for operator in all_operators: for node in all_nodes: _node_operator_map[node][operator] = switcher.get(operator)(doc,node) return _node_operator_map def dep(self,doc,node): return list(doc[node].head) def gov(self,doc,node): return list(doc[node].children) def dep_chain(self,doc,node): return list(doc[node].ancestors) def gov_chain(self,doc,node): return list(doc[node].subtree) def imm_precede(self,doc,node): if node>0: return [doc[node-1]] return [] def imm_right_sib(self,doc,node): for idx in range(list(doc[node].head.children)): if idx == node-1: return [doc[idx]] return [] def imm_left_sib(self,doc,node): for idx in range(list(doc[node].head.children)): if idx == node+1: return [doc[idx]] return [] def right_sib(self,doc,node): candidate_children = [] for idx in range(list(doc[node].head.children)): if idx < node: candidate_children.append(doc[idx]) return candidate_children def left_sib(self,doc,node): candidate_children = [] for idx in range(list(doc[node].head.children)): if idx > node: candidate_children.append(doc[idx]) return candidate_children def _normalize_key(self, key): if isinstance(key, basestring): return self.vocab.strings.add(key) else: return key