diff --git a/spacy/errors.py b/spacy/errors.py index b17597110..6b45e4815 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -261,8 +261,13 @@ class Errors(object): "Span objects, or dicts if set to manual=True.") E097 = ("Invalid pattern: expected token pattern (list of dicts) or " "phrase pattern (string) but got:\n{pattern}") - - + E098 = ("Invalid pattern specified: expected both SPEC and PATTERN.") + E099 = ("First node of pattern should be a root node. The root should " + "only contain NODE_NAME.") + E100 = ("Nodes apart from the root should contain NODE_NAME, NBOR_NAME and " + "NBOR_RELOP.") + E101 = ("NODE_NAME should be a new node and NBOR_NAME should already have " + "have been declared in previous edges.") @add_codes class TempErrors(object): T001 = ("Max length currently 10 for phrase matching") diff --git a/spacy/matcher.pyx b/spacy/matcher.pyx index e8d567428..c5db8ac39 100644 --- a/spacy/matcher.pyx +++ b/spacy/matcher.pyx @@ -29,7 +29,8 @@ from .attrs import FLAG41 as I4_ENT DELIMITER = '||' - +INDEX_HEAD = 1 +INDEX_RELOP = 0 cdef enum action_t: REJECT = 0000 @@ -731,18 +732,31 @@ cdef class DependencyTreeMatcher: """ return self._normalize_key(key) in self._patterns + def validateInput(self, pattern, key): + idx = 0 + visitedNodes = {} + for relation in pattern: + if 'PATTERN' not in relation or 'SPEC' not in relation: + raise ValueError(Errors.E098.format(key=key)) + if idx == 0: + if not('NODE_NAME' in relation['SPEC'] and 'NBOR_RELOP' not in relation['SPEC'] and 'NBOR_NAME' not in relation['SPEC']): + raise ValueError(Errors.E099.format(key=key)) + visitedNodes[relation['SPEC']['NODE_NAME']] = True + else: + if not('NODE_NAME' in relation['SPEC'] and 'NBOR_RELOP' in relation['SPEC'] and 'NBOR_NAME' in relation['SPEC']): + raise ValueError(Errors.E100.format(key=key)) + if relation['SPEC']['NODE_NAME'] in visitedNodes or relation['SPEC']['NBOR_NAME'] not in visitedNodes: + raise ValueError(Errors.E101.format(key=key)) + visitedNodes[relation['SPEC']['NODE_NAME']] = True + visitedNodes[relation['SPEC']['NBOR_NAME']] = True + idx = idx + 1 def add(self, key, on_match, *patterns): - # TODO : validations - # 1. check if input pattern is connected - # 2. check if pattern format is correct - # 3. check if atleast one root node is present - # 4. check if node names are not repeated - # 5. check if each node has only one head for pattern in patterns: if len(pattern) == 0: raise ValueError(Errors.E012.format(key=key)) + self.validateInput(pattern,key) key = self._normalize_key(key) @@ -792,7 +806,6 @@ cdef class DependencyTreeMatcher: self.retrieve_tree(patterns,_nodes_list,key) def retrieve_tree(self,patterns,_nodes_list,key): - _heads_list = [] _root_list = [] for i in range(len(patterns)): @@ -801,16 +814,10 @@ cdef class DependencyTreeMatcher: for j in range(len(patterns[i])): token_pattern = patterns[i][j] if('NBOR_RELOP' not in token_pattern['SPEC']): - heads[j] = j + heads[j] = ('root',j) root = j else: - # TODO: Add semgrex rules - # 1. > - if(token_pattern['SPEC']['NBOR_RELOP'] == '>'): - heads[j] = _nodes_list[i][token_pattern['SPEC']['NBOR_NAME']] - # 2. < - if(token_pattern['SPEC']['NBOR_RELOP'] == '<'): - heads[_nodes_list[i][token_pattern['SPEC']['NBOR_NAME']]] = j + heads[j] = (token_pattern['SPEC']['NBOR_RELOP'],_nodes_list[i][token_pattern['SPEC']['NBOR_NAME']]) _heads_list.append(heads) _root_list.append(root) @@ -819,12 +826,13 @@ cdef class DependencyTreeMatcher: for i in range(len(patterns)): tree = {} for j in range(len(patterns[i])): - if(j == _heads_list[i][j]): + if(_heads_list[i][j][INDEX_HEAD] == j): continue - head = _heads_list[i][j] + + head = _heads_list[i][j][INDEX_HEAD] if(head not in tree): tree[head] = [] - tree[head].append(j) + tree[head].append( (_heads_list[i][j][INDEX_RELOP],j) ) _tree_list.append(tree) self._tree.setdefault(key, []) @@ -869,23 +877,25 @@ cdef class DependencyTreeMatcher: _root = _root_list[i] _tree = _tree_list[i] _nodes = _nodes_list[i] - id_to_position = {} + for i in range(len(_nodes)): + id_to_position[i]=[] # This could be taken outside to improve running time..? for match_id, start, end in matches: if match_id in _keys_to_token: - if _keys_to_token[match_id] not in id_to_position: - id_to_position[_keys_to_token[match_id]] = [] id_to_position[_keys_to_token[match_id]].append(start) + _node_operator_map = self.get_node_operator_map(doc,_tree,id_to_position,_nodes,_root) length = len(_nodes) if _root in id_to_position: candidates = id_to_position[_root] for candidate in candidates: isVisited = {} - self.dfs(candidate,_root,_tree,id_to_position,doc,isVisited) - # to check if the subtree pattern is completely identified + self.dfs(candidate,_root,_tree,id_to_position,doc,isVisited,_node_operator_map) + # To check if the subtree pattern is completely identified. This is a heuristic. + # This is done to reduce the complexity of exponential unordered subtree matching. + # Will give approximate matches in some cases. if(len(isVisited) == length): matched_trees.append((key,list(isVisited))) @@ -896,22 +906,110 @@ cdef class DependencyTreeMatcher: return matched_trees - def dfs(self,candidate,root,tree,id_to_position,doc,isVisited): + def dfs(self,candidate,root,tree,id_to_position,doc,isVisited,_node_operator_map): if(root in id_to_position and candidate in id_to_position[root]): # color the node since it is valid isVisited[candidate] = True - candidate_children = doc[candidate].children - for candidate_child in candidate_children: - if root in tree: - for root_child in tree[root]: - self.dfs( - candidate_child.i, - root_child, - tree, - id_to_position, - doc, - isVisited - ) + if root in tree: + for root_child in tree[root]: + if candidate in _node_operator_map and root_child[INDEX_RELOP] in _node_operator_map[candidate]: + candidate_children = _node_operator_map[candidate][root_child[INDEX_RELOP]] + for candidate_child in candidate_children: + result = self.dfs( + candidate_child.i, + root_child[INDEX_HEAD], + tree, + id_to_position, + doc, + isVisited, + _node_operator_map + ) + + # Given a node and an edge operator, to return the list of nodes + # from the doc that belong to node+operator. This is used to store + # all the results beforehand to prevent unnecessary computation while + # pattern matching + # _node_operator_map[node][operator] = [...] + def get_node_operator_map(self,doc,tree,id_to_position,nodes,root): + _node_operator_map = {} + all_node_indices = nodes.values() + all_operators = [] + for node in all_node_indices: + if node in tree: + for child in tree[node]: + all_operators.append(child[INDEX_RELOP]) + all_operators = list(set(all_operators)) + + all_nodes = [] + for node in all_node_indices: + all_nodes = all_nodes + id_to_position[node] + all_nodes = list(set(all_nodes)) + + for node in all_nodes: + _node_operator_map[node] = {} + for operator in all_operators: + _node_operator_map[node][operator] = [] + + # Used to invoke methods for each operator + switcher = { + '<':self.dep, + '>':self.gov, + '>>':self.dep_chain, + '<<':self.gov_chain, + '.':self.imm_precede, + '$+':self.imm_right_sib, + '$-':self.imm_left_sib, + '$++':self.right_sib, + '$--':self.left_sib + } + for operator in all_operators: + for node in all_nodes: + _node_operator_map[node][operator] = switcher.get(operator)(doc,node) + + return _node_operator_map + + def dep(self,doc,node): + return list(doc[node].head) + + def gov(self,doc,node): + return list(doc[node].children) + + def dep_chain(self,doc,node): + return list(doc[node].ancestors) + + def gov_chain(self,doc,node): + return list(doc[node].subtree) + + def imm_precede(self,doc,node): + if node>0: + return [doc[node-1]] + return [] + + def imm_right_sib(self,doc,node): + for idx in range(list(doc[node].head.children)): + if idx == node-1: + return [doc[idx]] + return [] + + def imm_left_sib(self,doc,node): + for idx in range(list(doc[node].head.children)): + if idx == node+1: + return [doc[idx]] + return [] + + def right_sib(self,doc,node): + candidate_children = [] + for idx in range(list(doc[node].head.children)): + if idx < node: + candidate_children.append(doc[idx]) + return candidate_children + + def left_sib(self,doc,node): + candidate_children = [] + for idx in range(list(doc[node].head.children)): + if idx > node: + candidate_children.append(doc[idx]) + return candidate_children def _normalize_key(self, key): if isinstance(key, basestring):