mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 00:46:28 +03:00
Added more constucts for dependency tree matcher (#2836)
This commit is contained in:
parent
817e1fc5e5
commit
0bf14082a4
|
@ -261,8 +261,13 @@ class Errors(object):
|
|||
"Span objects, or dicts if set to manual=True.")
|
||||
E097 = ("Invalid pattern: expected token pattern (list of dicts) or "
|
||||
"phrase pattern (string) but got:\n{pattern}")
|
||||
|
||||
|
||||
E098 = ("Invalid pattern specified: expected both SPEC and PATTERN.")
|
||||
E099 = ("First node of pattern should be a root node. The root should "
|
||||
"only contain NODE_NAME.")
|
||||
E100 = ("Nodes apart from the root should contain NODE_NAME, NBOR_NAME and "
|
||||
"NBOR_RELOP.")
|
||||
E101 = ("NODE_NAME should be a new node and NBOR_NAME should already have "
|
||||
"have been declared in previous edges.")
|
||||
@add_codes
|
||||
class TempErrors(object):
|
||||
T001 = ("Max length currently 10 for phrase matching")
|
||||
|
|
|
@ -29,7 +29,8 @@ from .attrs import FLAG41 as I4_ENT
|
|||
|
||||
|
||||
DELIMITER = '||'
|
||||
|
||||
INDEX_HEAD = 1
|
||||
INDEX_RELOP = 0
|
||||
|
||||
cdef enum action_t:
|
||||
REJECT = 0000
|
||||
|
@ -731,18 +732,31 @@ cdef class DependencyTreeMatcher:
|
|||
"""
|
||||
return self._normalize_key(key) in self._patterns
|
||||
|
||||
def validateInput(self, pattern, key):
|
||||
idx = 0
|
||||
visitedNodes = {}
|
||||
for relation in pattern:
|
||||
if 'PATTERN' not in relation or 'SPEC' not in relation:
|
||||
raise ValueError(Errors.E098.format(key=key))
|
||||
if idx == 0:
|
||||
if not('NODE_NAME' in relation['SPEC'] and 'NBOR_RELOP' not in relation['SPEC'] and 'NBOR_NAME' not in relation['SPEC']):
|
||||
raise ValueError(Errors.E099.format(key=key))
|
||||
visitedNodes[relation['SPEC']['NODE_NAME']] = True
|
||||
else:
|
||||
if not('NODE_NAME' in relation['SPEC'] and 'NBOR_RELOP' in relation['SPEC'] and 'NBOR_NAME' in relation['SPEC']):
|
||||
raise ValueError(Errors.E100.format(key=key))
|
||||
if relation['SPEC']['NODE_NAME'] in visitedNodes or relation['SPEC']['NBOR_NAME'] not in visitedNodes:
|
||||
raise ValueError(Errors.E101.format(key=key))
|
||||
visitedNodes[relation['SPEC']['NODE_NAME']] = True
|
||||
visitedNodes[relation['SPEC']['NBOR_NAME']] = True
|
||||
idx = idx + 1
|
||||
|
||||
def add(self, key, on_match, *patterns):
|
||||
# TODO : validations
|
||||
# 1. check if input pattern is connected
|
||||
# 2. check if pattern format is correct
|
||||
# 3. check if atleast one root node is present
|
||||
# 4. check if node names are not repeated
|
||||
# 5. check if each node has only one head
|
||||
|
||||
for pattern in patterns:
|
||||
if len(pattern) == 0:
|
||||
raise ValueError(Errors.E012.format(key=key))
|
||||
self.validateInput(pattern,key)
|
||||
|
||||
key = self._normalize_key(key)
|
||||
|
||||
|
@ -792,7 +806,6 @@ cdef class DependencyTreeMatcher:
|
|||
self.retrieve_tree(patterns,_nodes_list,key)
|
||||
|
||||
def retrieve_tree(self,patterns,_nodes_list,key):
|
||||
|
||||
_heads_list = []
|
||||
_root_list = []
|
||||
for i in range(len(patterns)):
|
||||
|
@ -801,16 +814,10 @@ cdef class DependencyTreeMatcher:
|
|||
for j in range(len(patterns[i])):
|
||||
token_pattern = patterns[i][j]
|
||||
if('NBOR_RELOP' not in token_pattern['SPEC']):
|
||||
heads[j] = j
|
||||
heads[j] = ('root',j)
|
||||
root = j
|
||||
else:
|
||||
# TODO: Add semgrex rules
|
||||
# 1. >
|
||||
if(token_pattern['SPEC']['NBOR_RELOP'] == '>'):
|
||||
heads[j] = _nodes_list[i][token_pattern['SPEC']['NBOR_NAME']]
|
||||
# 2. <
|
||||
if(token_pattern['SPEC']['NBOR_RELOP'] == '<'):
|
||||
heads[_nodes_list[i][token_pattern['SPEC']['NBOR_NAME']]] = j
|
||||
heads[j] = (token_pattern['SPEC']['NBOR_RELOP'],_nodes_list[i][token_pattern['SPEC']['NBOR_NAME']])
|
||||
|
||||
_heads_list.append(heads)
|
||||
_root_list.append(root)
|
||||
|
@ -819,12 +826,13 @@ cdef class DependencyTreeMatcher:
|
|||
for i in range(len(patterns)):
|
||||
tree = {}
|
||||
for j in range(len(patterns[i])):
|
||||
if(j == _heads_list[i][j]):
|
||||
if(_heads_list[i][j][INDEX_HEAD] == j):
|
||||
continue
|
||||
head = _heads_list[i][j]
|
||||
|
||||
head = _heads_list[i][j][INDEX_HEAD]
|
||||
if(head not in tree):
|
||||
tree[head] = []
|
||||
tree[head].append(j)
|
||||
tree[head].append( (_heads_list[i][j][INDEX_RELOP],j) )
|
||||
_tree_list.append(tree)
|
||||
|
||||
self._tree.setdefault(key, [])
|
||||
|
@ -869,23 +877,25 @@ cdef class DependencyTreeMatcher:
|
|||
_root = _root_list[i]
|
||||
_tree = _tree_list[i]
|
||||
_nodes = _nodes_list[i]
|
||||
|
||||
id_to_position = {}
|
||||
for i in range(len(_nodes)):
|
||||
id_to_position[i]=[]
|
||||
|
||||
# This could be taken outside to improve running time..?
|
||||
for match_id, start, end in matches:
|
||||
if match_id in _keys_to_token:
|
||||
if _keys_to_token[match_id] not in id_to_position:
|
||||
id_to_position[_keys_to_token[match_id]] = []
|
||||
id_to_position[_keys_to_token[match_id]].append(start)
|
||||
|
||||
_node_operator_map = self.get_node_operator_map(doc,_tree,id_to_position,_nodes,_root)
|
||||
length = len(_nodes)
|
||||
if _root in id_to_position:
|
||||
candidates = id_to_position[_root]
|
||||
for candidate in candidates:
|
||||
isVisited = {}
|
||||
self.dfs(candidate,_root,_tree,id_to_position,doc,isVisited)
|
||||
# to check if the subtree pattern is completely identified
|
||||
self.dfs(candidate,_root,_tree,id_to_position,doc,isVisited,_node_operator_map)
|
||||
# To check if the subtree pattern is completely identified. This is a heuristic.
|
||||
# This is done to reduce the complexity of exponential unordered subtree matching.
|
||||
# Will give approximate matches in some cases.
|
||||
if(len(isVisited) == length):
|
||||
matched_trees.append((key,list(isVisited)))
|
||||
|
||||
|
@ -896,22 +906,110 @@ cdef class DependencyTreeMatcher:
|
|||
|
||||
return matched_trees
|
||||
|
||||
def dfs(self,candidate,root,tree,id_to_position,doc,isVisited):
|
||||
def dfs(self,candidate,root,tree,id_to_position,doc,isVisited,_node_operator_map):
|
||||
if(root in id_to_position and candidate in id_to_position[root]):
|
||||
# color the node since it is valid
|
||||
isVisited[candidate] = True
|
||||
candidate_children = doc[candidate].children
|
||||
for candidate_child in candidate_children:
|
||||
if root in tree:
|
||||
for root_child in tree[root]:
|
||||
self.dfs(
|
||||
candidate_child.i,
|
||||
root_child,
|
||||
tree,
|
||||
id_to_position,
|
||||
doc,
|
||||
isVisited
|
||||
)
|
||||
if root in tree:
|
||||
for root_child in tree[root]:
|
||||
if candidate in _node_operator_map and root_child[INDEX_RELOP] in _node_operator_map[candidate]:
|
||||
candidate_children = _node_operator_map[candidate][root_child[INDEX_RELOP]]
|
||||
for candidate_child in candidate_children:
|
||||
result = self.dfs(
|
||||
candidate_child.i,
|
||||
root_child[INDEX_HEAD],
|
||||
tree,
|
||||
id_to_position,
|
||||
doc,
|
||||
isVisited,
|
||||
_node_operator_map
|
||||
)
|
||||
|
||||
# Given a node and an edge operator, to return the list of nodes
|
||||
# from the doc that belong to node+operator. This is used to store
|
||||
# all the results beforehand to prevent unnecessary computation while
|
||||
# pattern matching
|
||||
# _node_operator_map[node][operator] = [...]
|
||||
def get_node_operator_map(self,doc,tree,id_to_position,nodes,root):
|
||||
_node_operator_map = {}
|
||||
all_node_indices = nodes.values()
|
||||
all_operators = []
|
||||
for node in all_node_indices:
|
||||
if node in tree:
|
||||
for child in tree[node]:
|
||||
all_operators.append(child[INDEX_RELOP])
|
||||
all_operators = list(set(all_operators))
|
||||
|
||||
all_nodes = []
|
||||
for node in all_node_indices:
|
||||
all_nodes = all_nodes + id_to_position[node]
|
||||
all_nodes = list(set(all_nodes))
|
||||
|
||||
for node in all_nodes:
|
||||
_node_operator_map[node] = {}
|
||||
for operator in all_operators:
|
||||
_node_operator_map[node][operator] = []
|
||||
|
||||
# Used to invoke methods for each operator
|
||||
switcher = {
|
||||
'<':self.dep,
|
||||
'>':self.gov,
|
||||
'>>':self.dep_chain,
|
||||
'<<':self.gov_chain,
|
||||
'.':self.imm_precede,
|
||||
'$+':self.imm_right_sib,
|
||||
'$-':self.imm_left_sib,
|
||||
'$++':self.right_sib,
|
||||
'$--':self.left_sib
|
||||
}
|
||||
for operator in all_operators:
|
||||
for node in all_nodes:
|
||||
_node_operator_map[node][operator] = switcher.get(operator)(doc,node)
|
||||
|
||||
return _node_operator_map
|
||||
|
||||
def dep(self,doc,node):
|
||||
return list(doc[node].head)
|
||||
|
||||
def gov(self,doc,node):
|
||||
return list(doc[node].children)
|
||||
|
||||
def dep_chain(self,doc,node):
|
||||
return list(doc[node].ancestors)
|
||||
|
||||
def gov_chain(self,doc,node):
|
||||
return list(doc[node].subtree)
|
||||
|
||||
def imm_precede(self,doc,node):
|
||||
if node>0:
|
||||
return [doc[node-1]]
|
||||
return []
|
||||
|
||||
def imm_right_sib(self,doc,node):
|
||||
for idx in range(list(doc[node].head.children)):
|
||||
if idx == node-1:
|
||||
return [doc[idx]]
|
||||
return []
|
||||
|
||||
def imm_left_sib(self,doc,node):
|
||||
for idx in range(list(doc[node].head.children)):
|
||||
if idx == node+1:
|
||||
return [doc[idx]]
|
||||
return []
|
||||
|
||||
def right_sib(self,doc,node):
|
||||
candidate_children = []
|
||||
for idx in range(list(doc[node].head.children)):
|
||||
if idx < node:
|
||||
candidate_children.append(doc[idx])
|
||||
return candidate_children
|
||||
|
||||
def left_sib(self,doc,node):
|
||||
candidate_children = []
|
||||
for idx in range(list(doc[node].head.children)):
|
||||
if idx > node:
|
||||
candidate_children.append(doc[idx])
|
||||
return candidate_children
|
||||
|
||||
def _normalize_key(self, key):
|
||||
if isinstance(key, basestring):
|
||||
|
|
Loading…
Reference in New Issue
Block a user