# spaCy/spacy/matcher/dependencymatcher.pyx
# 2020-07-29 11:36:42 +02:00
#
# 351 lines
# 13 KiB
# Cython

# cython: infer_types=True, profile=True
from cymem.cymem cimport Pool
from preshed.maps cimport PreshMap
from libcpp cimport bool
import numpy
from .matcher cimport Matcher
from ..vocab cimport Vocab
from ..tokens.doc cimport Doc
from .matcher import unpickle_matcher
from ..errors import Errors
# Separator used by DependencyMatcher.add when building the per-token keys
# it registers with the token Matcher: key, pattern index and token index
# are joined with this string.
DELIMITER = "||"
# Positions inside the (relation operator, head index) tuples that
# DependencyMatcher.retrieve_tree builds for every pattern token.
INDEX_HEAD = 1
INDEX_RELOP = 0
cdef class DependencyMatcher:
    """Match dependency parse tree based on pattern rules.

    A pattern is a list of dicts, each with a "PATTERN" entry (a token
    pattern handed to the token Matcher) and a "SPEC" entry. The first SPEC
    only names the root node ("NODE_NAME"); every later SPEC names a new node
    and attaches it to an already declared node through "NBOR_NAME" and the
    relation operator "NBOR_RELOP". A relation reads
    "NBOR_NAME NBOR_RELOP NODE_NAME", e.g. with operator ">" the neighbour is
    the immediate head of the node.
    """
    cdef Pool mem
    cdef readonly Vocab vocab
    cdef readonly Matcher token_matcher
    cdef public object _patterns
    cdef public object _keys_to_token
    cdef public object _root
    cdef public object _entities
    cdef public object _callbacks
    cdef public object _nodes
    cdef public object _tree

    def __init__(self, vocab):
        """Create the DependencyMatcher.

        vocab (Vocab): The vocabulary object, which must be shared with the
            documents the matcher will operate on.
        """
        # TODO: make matcher work with validation
        self.token_matcher = Matcher(vocab, validate=False)
        # All of the dicts below are keyed by the normalized match ID and
        # hold one list entry per pattern added under that ID.
        self._keys_to_token = {}
        self._patterns = {}
        self._root = {}
        self._nodes = {}
        self._tree = {}
        self._entities = {}
        self._callbacks = {}
        self.vocab = vocab
        self.mem = Pool()

    def __reduce__(self):
        # NOTE(review): four values are passed to unpickle_matcher, which is
        # defined in the matcher module and not visible here -- confirm that
        # it accepts this tuple and rebuilds a DependencyMatcher.
        data = (self.vocab, self._patterns, self._tree, self._callbacks)
        return (unpickle_matcher, data, None, None)

    def __len__(self):
        """Get the number of match IDs that have patterns registered.

        RETURNS (int): The number of match IDs.
        """
        return len(self._patterns)

    def __contains__(self, key):
        """Check whether the matcher contains rules for a match ID.

        key (str): The match ID.
        RETURNS (bool): Whether the matcher contains rules for this match ID.
        """
        return self._normalize_key(key) in self._patterns

    def validateInput(self, pattern, key):
        """Check that a single pattern is well-formed.

        pattern (list): The pattern, a list of dicts with "SPEC" and
            "PATTERN" entries.
        key (str / int): The match ID, only used in the error messages.
        RAISES (ValueError): If an entry lacks "SPEC" or "PATTERN" (E098), if
            the first entry is not a bare root node (E099), if a later entry
            lacks "NODE_NAME", "NBOR_RELOP" or "NBOR_NAME" (E100), or if a
            later entry re-declares a node or attaches to a node that has not
            been declared yet (E101).
        """
        visited_nodes = set()
        for idx, relation in enumerate(pattern):
            if "PATTERN" not in relation or "SPEC" not in relation:
                raise ValueError(Errors.E098.format(key=key))
            spec = relation["SPEC"]
            if idx == 0:
                # The first entry is the root: it has a name but no relation.
                if not (
                    "NODE_NAME" in spec
                    and "NBOR_RELOP" not in spec
                    and "NBOR_NAME" not in spec
                ):
                    raise ValueError(Errors.E099.format(key=key))
            else:
                if not (
                    "NODE_NAME" in spec
                    and "NBOR_RELOP" in spec
                    and "NBOR_NAME" in spec
                ):
                    raise ValueError(Errors.E100.format(key=key))
                if (
                    spec["NODE_NAME"] in visited_nodes
                    or spec["NBOR_NAME"] not in visited_nodes
                ):
                    raise ValueError(Errors.E101.format(key=key))
            visited_nodes.add(spec["NODE_NAME"])

    def add(self, key, patterns, *_patterns, on_match=None):
        """Add one or more patterns under a match ID.

        key (str / int): The match ID.
        patterns (list): The patterns to add. For backwards compatibility the
            second positional argument may instead be the callback (or None),
            in which case the patterns are taken from the remaining
            positional arguments.
        on_match (callable): Optional callback, invoked from __call__ as
            on_match(matcher, doc, i, matches). It replaces any callback
            previously stored for the key.
        RAISES (ValueError): If a pattern is empty (E012) or rejected by
            validateInput.
        """
        if patterns is None or hasattr(patterns, "__call__"):  # old API
            on_match = patterns
            patterns = _patterns
        for pattern in patterns:
            if len(pattern) == 0:
                raise ValueError(Errors.E012.format(key=key))
            self.validateInput(pattern, key)
        key = self._normalize_key(key)
        # One single-token Matcher pattern per node of every input pattern.
        new_patterns = []
        for pattern in patterns:
            new_patterns.append([[relation["PATTERN"]] for relation in pattern])
        self._patterns.setdefault(key, [])
        # Patterns already stored for this key keep their token-matcher keys,
        # so the new ones must be numbered after them. Restarting at zero
        # would make a second add() for the same key reuse existing keys and
        # mix up the token matches of unrelated patterns.
        offset = len(self._patterns[key])
        self._callbacks[key] = on_match
        self._patterns[key].extend(new_patterns)
        # Add each node pattern of all the input patterns individually to the
        # matcher. This enables only a single instance of Matcher to be used.
        # Multiple adds are required to track each node pattern.
        _keys_to_token_list = []
        for i, token_patterns in enumerate(new_patterns):
            _keys_to_token = {}
            # TODO: Better ways to hash edges in pattern?
            for j, token_pattern in enumerate(token_patterns):
                k = self._normalize_key(
                    unicode(key) + DELIMITER + unicode(offset + i) + DELIMITER + unicode(j)
                )
                self.token_matcher.add(k, [token_pattern])
                _keys_to_token[k] = j
            _keys_to_token_list.append(_keys_to_token)
        self._keys_to_token.setdefault(key, [])
        self._keys_to_token[key].extend(_keys_to_token_list)
        # Map every node name to the index of its entry in the pattern.
        _nodes_list = []
        for pattern in patterns:
            nodes = {}
            for i, relation in enumerate(pattern):
                nodes[relation["SPEC"]["NODE_NAME"]] = i
            _nodes_list.append(nodes)
        self._nodes.setdefault(key, [])
        self._nodes[key].extend(_nodes_list)
        # Create an object tree to traverse later on. This data structure
        # enables easy tree pattern match. Doc-Token based tree cannot be
        # reused since it is memory-heavy and tightly coupled with the Doc.
        self.retrieve_tree(patterns, _nodes_list, key)

    def retrieve_tree(self, patterns, _nodes_list, key):
        """Build and store the pattern trees and root indices for a key.

        patterns (list): The validated input patterns.
        _nodes_list (list): For each pattern, a dict mapping node names to
            node indices.
        key (int): The normalized match ID.
        """
        _heads_list = []
        _root_list = []
        for i, pattern in enumerate(patterns):
            heads = {}
            root = -1
            for j, token_pattern in enumerate(pattern):
                spec = token_pattern["SPEC"]
                if "NBOR_RELOP" not in spec:
                    # The root points at itself.
                    heads[j] = ("root", j)
                    root = j
                else:
                    heads[j] = (
                        spec["NBOR_RELOP"],
                        _nodes_list[i][spec["NBOR_NAME"]]
                    )
            _heads_list.append(heads)
            _root_list.append(root)
        # Invert the head links: tree[head] lists (operator, node) edges.
        _tree_list = []
        for heads in _heads_list:
            tree = {}
            for j in range(len(heads)):
                head = heads[j][INDEX_HEAD]
                if head == j:
                    continue
                if head not in tree:
                    tree[head] = []
                tree[head].append((heads[j][INDEX_RELOP], j))
            _tree_list.append(tree)
        self._tree.setdefault(key, [])
        self._tree[key].extend(_tree_list)
        self._root.setdefault(key, [])
        self._root[key].extend(_root_list)

    def has_key(self, key):
        """Check whether the matcher has a rule with a given key.

        key (string or int): The key to check.
        RETURNS (bool): Whether the matcher has the rule.
        """
        key = self._normalize_key(key)
        return key in self._patterns

    def get(self, key, default=None):
        """Retrieve the pattern stored for a key.

        key (str / int): The key to retrieve.
        default: The value returned if the key is unknown.
        RETURNS (tuple): The rule, as an (on_match, patterns) tuple.
        """
        key = self._normalize_key(key)
        if key not in self._patterns:
            return default
        return (self._callbacks[key], self._patterns[key])

    def __call__(self, Doc doc):
        """Find all dependency pattern matches in a document.

        doc (Doc): The document to match over.
        RETURNS (list): One (match_id, trees) tuple per registered pattern.
            `trees` is a list of matches, each a list of token indices
            ordered like the entries of the pattern; it is empty if the
            pattern did not match.
        """
        matched_key_trees = []
        matches = self.token_matcher(doc)
        for key in list(self._patterns.keys()):
            _keys_to_token_list = self._keys_to_token[key]
            _root_list = self._root[key]
            _tree_list = self._tree[key]
            _nodes_list = self._nodes[key]
            for i in range(len(self._patterns[key])):
                _keys_to_token = _keys_to_token_list[i]
                _root = _root_list[i]
                _tree = _tree_list[i]
                _nodes = _nodes_list[i]
                # Candidate token positions for every node of the pattern.
                id_to_position = {}
                for node_index in range(len(_nodes)):
                    id_to_position[node_index] = []
                # TODO: This could be taken outside to improve running time..?
                for match_id, start, end in matches:
                    if match_id in _keys_to_token:
                        id_to_position[_keys_to_token[match_id]].append(start)
                _node_operator_map = self.get_node_operator_map(
                    doc,
                    _tree,
                    id_to_position,
                    _nodes,
                    _root
                )
                matched_trees = []
                self.recurse(_tree, id_to_position, _node_operator_map, 0, [], matched_trees)
                matched_key_trees.append((key, matched_trees))
        # Run the callbacks once, after every key has been processed, so
        # that each callback receives the complete result list.
        for i, (ent_id, nodes) in enumerate(matched_key_trees):
            on_match = self._callbacks.get(ent_id)
            if on_match is not None:
                on_match(self, doc, i, matched_key_trees)
        return matched_key_trees

    def recurse(self, tree, id_to_position, _node_operator_map, int patternLength, visitedNodes, matched_trees):
        """Enumerate assignments of token positions to pattern nodes.

        Nodes are assigned in index order. Once every node has a position,
        the assignment is appended to `matched_trees` if all edges of the
        pattern tree hold.

        tree (dict): Maps a node index to its (operator, node index) edges.
        id_to_position (dict): Maps a node index to candidate positions.
        _node_operator_map (dict): Result of get_node_operator_map.
        patternLength (int): Number of nodes assigned so far.
        visitedNodes (list): Positions assigned so far, by node index.
        matched_trees (list): Output list, extended in place.
        """
        if patternLength == len(id_to_position):
            for node in range(patternLength):
                for relop, nbor in tree.get(node, []):
                    target = visitedNodes[nbor]
                    found = False
                    for candidate in _node_operator_map[visitedNodes[node]][relop]:
                        if candidate.i == target:
                            found = True
                            break
                    if not found:
                        # One violated edge rules out the whole assignment.
                        return
            matched_trees.append(visitedNodes)
            return
        for position in id_to_position[patternLength]:
            self.recurse(
                tree,
                id_to_position,
                _node_operator_map,
                patternLength + 1,
                visitedNodes + [position],
                matched_trees
            )

    # Given a node and an edge operator, to return the list of nodes
    # from the doc that belong to node+operator. This is used to store
    # all the results beforehand to prevent unnecessary computation while
    # pattern matching
    # _node_operator_map[node][operator] = [...]
    def get_node_operator_map(self, doc, tree, id_to_position, nodes, root):
        """Precompute the tokens related to each candidate position.

        doc (Doc): The document being matched.
        tree (dict): Maps a node index to its (operator, node index) edges.
        id_to_position (dict): Maps a node index to candidate positions.
        nodes (dict): Maps node names to node indices.
        root (int): Index of the root node (currently unused).
        RETURNS (dict): map[position][operator] is the list of tokens that
            stand in relation `operator` to the token at `position`.
        RAISES (ValueError): If the pattern uses an unknown operator.
        """
        # Used to invoke methods for each operator
        switcher = {
            "<": self.dep,
            ">": self.gov,
            "<<": self.dep_chain,
            ">>": self.gov_chain,
            ".": self.imm_precede,
            "$+": self.imm_right_sib,
            "$-": self.imm_left_sib,
            "$++": self.right_sib,
            "$--": self.left_sib
        }
        all_operators = set()
        all_nodes = set()
        for node in nodes.values():
            for child in tree.get(node, []):
                all_operators.add(child[INDEX_RELOP])
            all_nodes.update(id_to_position[node])
        for operator in all_operators:
            if operator not in switcher:
                raise ValueError(
                    "Unsupported relation operator '{}' in dependency "
                    "pattern. Supported operators: {}".format(
                        operator, ", ".join(sorted(switcher))
                    )
                )
        _node_operator_map = {}
        for node in all_nodes:
            related = {}
            for operator in all_operators:
                related[operator] = switcher[operator](doc, node)
            _node_operator_map[node] = related
        return _node_operator_map

    def dep(self, doc, node):
        """Operator "<": the immediate head of the token at `node`.

        The root of a parse is its own head, so it has no result.
        """
        token = doc[node]
        if token.head.i == token.i:
            return []
        return [token.head]

    def gov(self, doc, node):
        """Operator ">": the immediate children of the token at `node`."""
        return list(doc[node].children)

    def dep_chain(self, doc, node):
        """Operator "<<": all ancestors of the token at `node`."""
        return list(doc[node].ancestors)

    def gov_chain(self, doc, node):
        """Operator ">>": all descendants of the token at `node`.

        The subtree contains the token itself, which is filtered out.
        """
        return [token for token in doc[node].subtree if token.i != node]

    def imm_precede(self, doc, node):
        """Operator ".": the token that the token at `node` immediately
        precedes, i.e. the next token in the document."""
        if node < len(doc) - 1:
            return [doc[node + 1]]
        return []

    def imm_right_sib(self, doc, node):
        """Operator "$+": the sibling directly to the right of `node`."""
        for child in doc[node].head.children:
            if child.i == node + 1:
                return [child]
        return []

    def imm_left_sib(self, doc, node):
        """Operator "$-": the sibling directly to the left of `node`."""
        for child in doc[node].head.children:
            if child.i == node - 1:
                return [child]
        return []

    def right_sib(self, doc, node):
        """Operator "$++": all siblings to the right of `node`."""
        return [child for child in doc[node].head.children if child.i > node]

    def left_sib(self, doc, node):
        """Operator "$--": all siblings to the left of `node`."""
        return [child for child in doc[node].head.children if child.i < node]

    def _normalize_key(self, key):
        """Convert a string key to its hash in the vocab's string store.

        Keys that are not strings are returned unchanged.
        """
        if isinstance(key, basestring):
            return self.vocab.strings.add(key)
        else:
            return key