spaCy/spacy/matcher/dependencymatcher.pyx
Adriane Boyd 31de700b0f
Fix on_match callback and remove empty patterns (#6312)
For the `DependencyMatcher`:

* Fix on_match callback so that it is called once per matched pattern
* Fix results so that patterns with empty match lists are not returned
2020-11-05 09:16:26 +01:00

419 lines
16 KiB
Cython

# cython: infer_types=True, profile=True
from typing import List
import numpy
from cymem.cymem cimport Pool
from .matcher cimport Matcher
from ..vocab cimport Vocab
from ..tokens.doc cimport Doc
from ..errors import Errors
from ..tokens import Span
DELIMITER = "||"
INDEX_HEAD = 1
INDEX_RELOP = 0
cdef class DependencyMatcher:
"""Match dependency parse tree based on pattern rules."""
cdef Pool mem
cdef readonly Vocab vocab
cdef readonly Matcher matcher
cdef public object _patterns
cdef public object _raw_patterns
cdef public object _keys_to_token
cdef public object _root
cdef public object _callbacks
cdef public object _nodes
cdef public object _tree
cdef public object _ops
def __init__(self, vocab, *, validate=False):
"""Create the DependencyMatcher.
vocab (Vocab): The vocabulary object, which must be shared with the
documents the matcher will operate on.
validate (bool): Whether patterns should be validated, passed to
Matcher as `validate`
"""
size = 20
self.matcher = Matcher(vocab, validate=validate)
self._keys_to_token = {}
self._patterns = {}
self._raw_patterns = {}
self._root = {}
self._nodes = {}
self._tree = {}
self._callbacks = {}
self.vocab = vocab
self.mem = Pool()
self._ops = {
"<": self.dep,
">": self.gov,
"<<": self.dep_chain,
">>": self.gov_chain,
".": self.imm_precede,
".*": self.precede,
";": self.imm_follow,
";*": self.follow,
"$+": self.imm_right_sib,
"$-": self.imm_left_sib,
"$++": self.right_sib,
"$--": self.left_sib,
}
def __reduce__(self):
data = (self.vocab, self._raw_patterns, self._callbacks)
return (unpickle_matcher, data, None, None)
def __len__(self):
"""Get the number of rules, which are edges, added to the dependency
tree matcher.
RETURNS (int): The number of rules.
"""
return len(self._patterns)
def __contains__(self, key):
"""Check whether the matcher contains rules for a match ID.
key (str): The match ID.
RETURNS (bool): Whether the matcher contains rules for this match ID.
"""
return self.has_key(key)
def validate_input(self, pattern, key):
idx = 0
visited_nodes = {}
for relation in pattern:
if not isinstance(relation, dict):
raise ValueError(Errors.E1008)
if "RIGHT_ATTRS" not in relation and "RIGHT_ID" not in relation:
raise ValueError(Errors.E098.format(key=key))
if idx == 0:
if not(
"RIGHT_ID" in relation
and "REL_OP" not in relation
and "LEFT_ID" not in relation
):
raise ValueError(Errors.E099.format(key=key))
visited_nodes[relation["RIGHT_ID"]] = True
else:
if not(
"RIGHT_ID" in relation
and "RIGHT_ATTRS" in relation
and "REL_OP" in relation
and "LEFT_ID" in relation
):
raise ValueError(Errors.E100.format(key=key))
if (
relation["RIGHT_ID"] in visited_nodes
or relation["LEFT_ID"] not in visited_nodes
):
raise ValueError(Errors.E101.format(key=key))
if relation["REL_OP"] not in self._ops:
raise ValueError(Errors.E1007.format(op=relation["REL_OP"]))
visited_nodes[relation["RIGHT_ID"]] = True
visited_nodes[relation["LEFT_ID"]] = True
idx = idx + 1
def add(self, key, patterns, *, on_match=None):
"""Add a new matcher rule to the matcher.
key (str): The match ID.
patterns (list): The patterns to add for the given key.
on_match (callable): Optional callback executed on match.
"""
if on_match is not None and not hasattr(on_match, "__call__"):
raise ValueError(Errors.E171.format(arg_type=type(on_match)))
if patterns is None or not isinstance(patterns, List): # old API
raise ValueError(Errors.E948.format(arg_type=type(patterns)))
for pattern in patterns:
if len(pattern) == 0:
raise ValueError(Errors.E012.format(key=key))
self.validate_input(pattern, key)
key = self._normalize_key(key)
self._raw_patterns.setdefault(key, [])
self._raw_patterns[key].extend(patterns)
_patterns = []
for pattern in patterns:
token_patterns = []
for i in range(len(pattern)):
token_pattern = [pattern[i]["RIGHT_ATTRS"]]
token_patterns.append(token_pattern)
_patterns.append(token_patterns)
self._patterns.setdefault(key, [])
self._callbacks[key] = on_match
self._patterns[key].extend(_patterns)
# Add each node pattern of all the input patterns individually to the
# matcher. This enables only a single instance of Matcher to be used.
# Multiple adds are required to track each node pattern.
_keys_to_token_list = []
for i in range(len(_patterns)):
_keys_to_token = {}
# TODO: Better ways to hash edges in pattern?
for j in range(len(_patterns[i])):
k = self._normalize_key(unicode(key) + DELIMITER + unicode(i) + DELIMITER + unicode(j))
self.matcher.add(k, [_patterns[i][j]])
_keys_to_token[k] = j
_keys_to_token_list.append(_keys_to_token)
self._keys_to_token.setdefault(key, [])
self._keys_to_token[key].extend(_keys_to_token_list)
_nodes_list = []
for pattern in patterns:
nodes = {}
for i in range(len(pattern)):
nodes[pattern[i]["RIGHT_ID"]] = i
_nodes_list.append(nodes)
self._nodes.setdefault(key, [])
self._nodes[key].extend(_nodes_list)
# Create an object tree to traverse later on. This data structure
# enables easy tree pattern match. Doc-Token based tree cannot be
# reused since it is memory-heavy and tightly coupled with the Doc.
self.retrieve_tree(patterns, _nodes_list, key)
def retrieve_tree(self, patterns, _nodes_list, key):
_heads_list = []
_root_list = []
for i in range(len(patterns)):
heads = {}
root = -1
for j in range(len(patterns[i])):
token_pattern = patterns[i][j]
if ("REL_OP" not in token_pattern):
heads[j] = ('root', j)
root = j
else:
heads[j] = (
token_pattern["REL_OP"],
_nodes_list[i][token_pattern["LEFT_ID"]]
)
_heads_list.append(heads)
_root_list.append(root)
_tree_list = []
for i in range(len(patterns)):
tree = {}
for j in range(len(patterns[i])):
if(_heads_list[i][j][INDEX_HEAD] == j):
continue
head = _heads_list[i][j][INDEX_HEAD]
if(head not in tree):
tree[head] = []
tree[head].append((_heads_list[i][j][INDEX_RELOP], j))
_tree_list.append(tree)
self._tree.setdefault(key, [])
self._tree[key].extend(_tree_list)
self._root.setdefault(key, [])
self._root[key].extend(_root_list)
def has_key(self, key):
"""Check whether the matcher has a rule with a given key.
key (string or int): The key to check.
RETURNS (bool): Whether the matcher has the rule.
"""
return self._normalize_key(key) in self._patterns
def get(self, key, default=None):
"""Retrieve the pattern stored for a key.
key (str / int): The key to retrieve.
RETURNS (tuple): The rule, as an (on_match, patterns) tuple.
"""
key = self._normalize_key(key)
if key not in self._raw_patterns:
return default
return (self._callbacks[key], self._raw_patterns[key])
def remove(self, key):
key = self._normalize_key(key)
if not key in self._patterns:
raise ValueError(Errors.E175.format(key=key))
self._patterns.pop(key)
self._raw_patterns.pop(key)
self._nodes.pop(key)
self._tree.pop(key)
self._root.pop(key)
def __call__(self, object doclike):
"""Find all token sequences matching the supplied pattern.
doclike (Doc or Span): The document to match over.
RETURNS (list): A list of `(key, start, end)` tuples,
describing the matches. A match tuple describes a span
`doc[start:end]`. The `label_id` and `key` are both integers.
"""
if isinstance(doclike, Doc):
doc = doclike
elif isinstance(doclike, Span):
doc = doclike.as_doc()
else:
raise ValueError(Errors.E195.format(good="Doc or Span", got=type(doclike).__name__))
matched_key_trees = []
matches = self.matcher(doc)
for key in list(self._patterns.keys()):
_patterns_list = self._patterns[key]
_keys_to_token_list = self._keys_to_token[key]
_root_list = self._root[key]
_tree_list = self._tree[key]
_nodes_list = self._nodes[key]
length = len(_patterns_list)
for i in range(length):
_keys_to_token = _keys_to_token_list[i]
_root = _root_list[i]
_tree = _tree_list[i]
_nodes = _nodes_list[i]
id_to_position = {}
for i in range(len(_nodes)):
id_to_position[i]=[]
# TODO: This could be taken outside to improve running time..?
for match_id, start, end in matches:
if match_id in _keys_to_token:
id_to_position[_keys_to_token[match_id]].append(start)
_node_operator_map = self.get_node_operator_map(
doc,
_tree,
id_to_position,
_nodes,_root
)
length = len(_nodes)
matched_trees = []
self.recurse(_tree, id_to_position, _node_operator_map, 0, [], matched_trees)
for matched_tree in matched_trees:
matched_key_trees.append((key, matched_tree))
for i, (match_id, nodes) in enumerate(matched_key_trees):
on_match = self._callbacks.get(match_id)
if on_match is not None:
on_match(self, doc, i, matched_key_trees)
return matched_key_trees
def recurse(self, tree, id_to_position, _node_operator_map, int patternLength, visited_nodes, matched_trees):
cdef bint isValid;
if patternLength == len(id_to_position.keys()):
isValid = True
for node in range(patternLength):
if node in tree:
for idx, (relop,nbor) in enumerate(tree[node]):
computed_nbors = numpy.asarray(_node_operator_map[visited_nodes[node]][relop])
isNbor = False
for computed_nbor in computed_nbors:
if computed_nbor.i == visited_nodes[nbor]:
isNbor = True
isValid = isValid & isNbor
if(isValid):
matched_trees.append(visited_nodes)
return
allPatternNodes = numpy.asarray(id_to_position[patternLength])
for patternNode in allPatternNodes:
self.recurse(tree, id_to_position, _node_operator_map, patternLength+1, visited_nodes+[patternNode], matched_trees)
# Given a node and an edge operator, to return the list of nodes
# from the doc that belong to node+operator. This is used to store
# all the results beforehand to prevent unnecessary computation while
# pattern matching
# _node_operator_map[node][operator] = [...]
def get_node_operator_map(self, doc, tree, id_to_position, nodes, root):
_node_operator_map = {}
all_node_indices = nodes.values()
all_operators = []
for node in all_node_indices:
if node in tree:
for child in tree[node]:
all_operators.append(child[INDEX_RELOP])
all_operators = list(set(all_operators))
all_nodes = []
for node in all_node_indices:
all_nodes = all_nodes + id_to_position[node]
all_nodes = list(set(all_nodes))
for node in all_nodes:
_node_operator_map[node] = {}
for operator in all_operators:
_node_operator_map[node][operator] = []
for operator in all_operators:
for node in all_nodes:
_node_operator_map[node][operator] = self._ops.get(operator)(doc, node)
return _node_operator_map
def dep(self, doc, node):
if doc[node].head == doc[node]:
return []
return [doc[node].head]
def gov(self,doc,node):
return list(doc[node].children)
def dep_chain(self, doc, node):
return list(doc[node].ancestors)
def gov_chain(self, doc, node):
return [t for t in doc[node].subtree if t != doc[node]]
def imm_precede(self, doc, node):
sent = self._get_sent(doc[node])
if node < len(doc) - 1 and doc[node + 1] in sent:
return [doc[node + 1]]
return []
def precede(self, doc, node):
sent = self._get_sent(doc[node])
return [doc[i] for i in range(node + 1, sent.end)]
def imm_follow(self, doc, node):
sent = self._get_sent(doc[node])
if node > 0 and doc[node - 1] in sent:
return [doc[node - 1]]
return []
def follow(self, doc, node):
sent = self._get_sent(doc[node])
return [doc[i] for i in range(sent.start, node)]
def imm_right_sib(self, doc, node):
for child in list(doc[node].head.children):
if child.i == node + 1:
return [doc[child.i]]
return []
def imm_left_sib(self, doc, node):
for child in list(doc[node].head.children):
if child.i == node - 1:
return [doc[child.i]]
return []
def right_sib(self, doc, node):
candidate_children = []
for child in list(doc[node].head.children):
if child.i > node:
candidate_children.append(doc[child.i])
return candidate_children
def left_sib(self, doc, node):
candidate_children = []
for child in list(doc[node].head.children):
if child.i < node:
candidate_children.append(doc[child.i])
return candidate_children
def _normalize_key(self, key):
if isinstance(key, basestring):
return self.vocab.strings.add(key)
else:
return key
def _get_sent(self, token):
root = (list(token.ancestors) or [token])[-1]
return token.doc[root.left_edge.i:root.right_edge.i + 1]
def unpickle_matcher(vocab, patterns, callbacks):
matcher = DependencyMatcher(vocab)
for key, pattern in patterns.items():
callback = callbacks.get(key, None)
matcher.add(key, pattern, on_match=callback)
return matcher