mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-27 10:26:35 +03:00
5d2cb86c34
Fix `DependencyMatcher` so that the callback is called only once per match.
419 lines
16 KiB
Cython
419 lines
16 KiB
Cython
# cython: infer_types=True, profile=True
|
|
from typing import List
|
|
|
|
import numpy
|
|
|
|
from cymem.cymem cimport Pool
|
|
|
|
from .matcher cimport Matcher
|
|
from ..vocab cimport Vocab
|
|
from ..tokens.doc cimport Doc
|
|
|
|
from ..errors import Errors
|
|
from ..tokens import Span
|
|
|
|
|
|
# Separator used to build the inner Matcher's per-token rule names
# ("<key>||<pattern index>||<token index>").
DELIMITER = "||"
# Positions inside the (REL_OP, head/child index) tuples built for each edge.
INDEX_RELOP, INDEX_HEAD = 0, 1
|
|
|
|
|
|
cdef class DependencyMatcher:
    """Match dependency parse tree based on pattern rules."""
    cdef Pool mem
    cdef readonly Vocab vocab
    cdef readonly Matcher matcher
    cdef public object _patterns
    cdef public object _raw_patterns
    cdef public object _keys_to_token
    cdef public object _root
    cdef public object _callbacks
    cdef public object _nodes
    cdef public object _tree
    cdef public object _ops

    def __init__(self, vocab, *, validate=False):
        """Create the DependencyMatcher.

        vocab (Vocab): The vocabulary object, which must be shared with the
            documents the matcher will operate on.
        validate (bool): Whether patterns should be validated, passed to
            Matcher as `validate`.
        """
        self.matcher = Matcher(vocab, validate=validate)
        # Per-key state. Apart from `_callbacks` (one callback per key), each
        # value is a list with one entry per pattern added under that key.
        self._keys_to_token = {}
        self._patterns = {}
        self._raw_patterns = {}
        self._root = {}
        self._nodes = {}
        self._tree = {}
        self._callbacks = {}
        self.vocab = vocab
        self.mem = Pool()
        # REL_OP symbol -> method returning the candidate tokens for that
        # relation, given the index of the LEFT_ID token.
        self._ops = {
            "<": self.dep,
            ">": self.gov,
            "<<": self.dep_chain,
            ">>": self.gov_chain,
            ".": self.imm_precede,
            ".*": self.precede,
            ";": self.imm_follow,
            ";*": self.follow,
            "$+": self.imm_right_sib,
            "$-": self.imm_left_sib,
            "$++": self.right_sib,
            "$--": self.left_sib,
        }

    def __reduce__(self):
        data = (self.vocab, self._raw_patterns, self._callbacks)
        return (unpickle_matcher, data, None, None)

    def __len__(self):
        """Get the number of rules (match IDs) added to the matcher.

        RETURNS (int): The number of match IDs with at least one pattern.
        """
        return len(self._patterns)

    def __contains__(self, key):
        """Check whether the matcher contains rules for a match ID.

        key (str): The match ID.
        RETURNS (bool): Whether the matcher contains rules for this match ID.
        """
        return self.has_key(key)

    def validate_input(self, pattern, key):
        """Check that a single dependency pattern is well formed.

        The first relation must define the anchor node: it needs RIGHT_ID and
        must not have REL_OP or LEFT_ID. Every later relation needs RIGHT_ID,
        RIGHT_ATTRS, REL_OP and LEFT_ID, and must introduce a new RIGHT_ID. It
        must also refer to an already-defined LEFT_ID and use a supported REL_OP.

        pattern (list): The relation dicts of one pattern.
        key (str / int): The match ID, used only in error messages.
        RAISES ValueError: If the pattern is malformed.
        """
        visited_nodes = set()
        for idx, relation in enumerate(pattern):
            if not isinstance(relation, dict):
                raise ValueError(Errors.E1008)
            if "RIGHT_ATTRS" not in relation and "RIGHT_ID" not in relation:
                raise ValueError(Errors.E098.format(key=key))
            if idx == 0:
                if not (
                    "RIGHT_ID" in relation
                    and "REL_OP" not in relation
                    and "LEFT_ID" not in relation
                ):
                    raise ValueError(Errors.E099.format(key=key))
            else:
                if not (
                    "RIGHT_ID" in relation
                    and "RIGHT_ATTRS" in relation
                    and "REL_OP" in relation
                    and "LEFT_ID" in relation
                ):
                    raise ValueError(Errors.E100.format(key=key))
                if (
                    relation["RIGHT_ID"] in visited_nodes
                    or relation["LEFT_ID"] not in visited_nodes
                ):
                    raise ValueError(Errors.E101.format(key=key))
                if relation["REL_OP"] not in self._ops:
                    raise ValueError(Errors.E1007.format(op=relation["REL_OP"]))
            visited_nodes.add(relation["RIGHT_ID"])

    def add(self, key, patterns, *, on_match=None):
        """Add a new matcher rule to the matcher.

        Calling `add` repeatedly with the same key appends patterns to that key.
        The latest `on_match` replaces any earlier callback.

        key (str): The match ID.
        patterns (list): The patterns to add for the given key. Each pattern
            is a list of relation dicts.
        on_match (callable): Optional callback executed on match.
        RAISES ValueError: If `on_match` is not callable, `patterns` is not a
            list, or any pattern is empty or malformed.
        """
        if on_match is not None and not callable(on_match):
            raise ValueError(Errors.E171.format(arg_type=type(on_match)))
        if patterns is None or not isinstance(patterns, list):  # old API
            raise ValueError(Errors.E948.format(arg_type=type(patterns)))
        for pattern in patterns:
            if len(pattern) == 0:
                raise ValueError(Errors.E012.format(key=key))
            self.validate_input(pattern, key)
        key = self._normalize_key(key)
        # Number the inner-Matcher rules by each pattern's position among ALL
        # patterns stored under `key`. Otherwise a second add() call would
        # reuse the first call's rule names and mix their token candidates.
        pattern_offset = len(self._patterns.get(key, []))
        self._raw_patterns.setdefault(key, []).extend(patterns)
        self._callbacks[key] = on_match
        # One single-token Matcher pattern per node: [[RIGHT_ATTRS], ...].
        _patterns = [
            [[relation["RIGHT_ATTRS"]] for relation in pattern]
            for pattern in patterns
        ]
        self._patterns.setdefault(key, []).extend(_patterns)
        # Add each node pattern of all the input patterns individually to the
        # matcher. This enables only a single instance of Matcher to be used.
        # Multiple adds are required to track each node pattern.
        _keys_to_token_list = []
        for i, token_patterns in enumerate(_patterns, pattern_offset):
            _keys_to_token = {}
            # TODO: Better ways to hash edges in pattern?
            for j, token_pattern in enumerate(token_patterns):
                k = self._normalize_key(
                    unicode(key) + DELIMITER + unicode(i) + DELIMITER + unicode(j)
                )
                self.matcher.add(k, [token_pattern])
                _keys_to_token[k] = j
            _keys_to_token_list.append(_keys_to_token)
        self._keys_to_token.setdefault(key, []).extend(_keys_to_token_list)
        # RIGHT_ID -> node position within its pattern.
        _nodes_list = [
            {relation["RIGHT_ID"]: i for i, relation in enumerate(pattern)}
            for pattern in patterns
        ]
        self._nodes.setdefault(key, []).extend(_nodes_list)
        # Create an object tree to traverse later on. This data structure
        # enables easy tree pattern match. Doc-Token based tree cannot be
        # reused since it is memory-heavy and tightly coupled with the Doc.
        self.retrieve_tree(patterns, _nodes_list, key)

    def retrieve_tree(self, patterns, _nodes_list, key):
        """Build and store the edge tree and root index of each pattern.

        For every pattern this stores a dict under `self._tree[key]`. The dict
        maps a node position to the list of `(REL_OP, child position)` edges
        that start at that node. The position of the anchor node (the relation
        without REL_OP) is stored under `self._root[key]`, or -1 if none.

        patterns (list): The raw patterns passed to `add`.
        _nodes_list (list): Per pattern, a RIGHT_ID -> position mapping.
        key (int): The normalized match ID.
        """
        _tree_list = []
        _root_list = []
        for pattern, nodes in zip(patterns, _nodes_list):
            tree = {}
            root = -1
            for j, relation in enumerate(pattern):
                if "REL_OP" not in relation:
                    root = j
                    continue
                head = nodes[relation["LEFT_ID"]]
                # A node is never registered as its own child.
                if head == j:
                    continue
                tree.setdefault(head, []).append((relation["REL_OP"], j))
            _tree_list.append(tree)
            _root_list.append(root)
        self._tree.setdefault(key, []).extend(_tree_list)
        self._root.setdefault(key, []).extend(_root_list)

    def has_key(self, key):
        """Check whether the matcher has a rule with a given key.

        key (string or int): The key to check.
        RETURNS (bool): Whether the matcher has the rule.
        """
        return self._normalize_key(key) in self._patterns

    def get(self, key, default=None):
        """Retrieve the pattern stored for a key.

        key (str / int): The key to retrieve.
        default: Value returned when no rule exists for `key`.
        RETURNS (tuple): The rule, as an (on_match, patterns) tuple.
        """
        key = self._normalize_key(key)
        if key not in self._raw_patterns:
            return default
        return (self._callbacks[key], self._raw_patterns[key])

    def remove(self, key):
        """Remove a rule, including its callback and its inner Matcher rules.

        key (str / int): The match ID.
        RAISES ValueError: If no rule was added for `key`.
        """
        key = self._normalize_key(key)
        if key not in self._patterns:
            raise ValueError(Errors.E175.format(key=key))
        self._patterns.pop(key)
        self._raw_patterns.pop(key)
        self._nodes.pop(key)
        self._tree.pop(key)
        self._root.pop(key)
        self._callbacks.pop(key, None)
        # Drop the per-token rules from the inner Matcher too. Otherwise they
        # keep matching, and stale entries would be reused if `key` is re-added.
        for keys_to_token in self._keys_to_token.pop(key, []):
            for sub_key in keys_to_token:
                if sub_key in self.matcher:
                    self.matcher.remove(sub_key)

    def __call__(self, object doclike):
        """Find all subtrees of the parse that match the added patterns.

        doclike (Doc or Span): The document to match over. A Span is first
            converted with `Span.as_doc()`, so the token indices in the
            results are relative to the span.
        RETURNS (list): A list of `(match_id, token_ids)` tuples, one per
            matched subtree. `match_id` is the integer key the pattern was
            added under. `token_ids` gives the index of the token matched by
            each pattern node, in pattern order. Registered `on_match`
            callbacks are called once per match, after all matching is done,
            as `on_match(matcher, doc, i, matches)`.
        RAISES ValueError: If `doclike` is neither a Doc nor a Span.
        """
        if isinstance(doclike, Doc):
            doc = doclike
        elif isinstance(doclike, Span):
            doc = doclike.as_doc()
        else:
            raise ValueError(Errors.E195.format(good="Doc or Span", got=type(doclike).__name__))
        # Group the single-token matches by inner-Matcher rule once, instead of
        # rescanning the whole match list for every dependency pattern.
        starts_by_subkey = {}
        for match_id, start, _ in self.matcher(doc):
            starts_by_subkey.setdefault(match_id, []).append(start)
        matched_key_trees = []
        for key in list(self._patterns):
            for _keys_to_token, _root, _tree, _nodes in zip(
                self._keys_to_token[key],
                self._root[key],
                self._tree[key],
                self._nodes[key],
            ):
                # Node position -> indices of tokens matching that node's attrs.
                id_to_position = {position: [] for position in range(len(_nodes))}
                for sub_key, position in _keys_to_token.items():
                    id_to_position[position].extend(starts_by_subkey.get(sub_key, ()))
                _node_operator_map = self.get_node_operator_map(
                    doc, _tree, id_to_position, _nodes, _root
                )
                matched_trees = []
                self.recurse(_tree, id_to_position, _node_operator_map, 0, [], matched_trees)
                matched_key_trees.extend([(key, matched_tree) for matched_tree in matched_trees])
        # Run callbacks only after every key has been processed, so each
        # callback fires exactly once per match and sees the full match list.
        for i, (match_id, _) in enumerate(matched_key_trees):
            on_match = self._callbacks.get(match_id)
            if on_match is not None:
                on_match(self, doc, i, matched_key_trees)
        return matched_key_trees

    def recurse(self, tree, id_to_position, _node_operator_map, int patternLength, visited_nodes, matched_trees):
        """Depth-first search over candidate token assignments for a pattern.

        `visited_nodes[p]` is the token chosen for pattern node `p`. Once every
        node has a token, the assignment is kept only if every edge
        `(relop, child)` in `tree` holds, i.e. the child's token is among the
        tokens `_node_operator_map` lists for the head's token and `relop`.

        tree (dict): Node position -> list of (REL_OP, child position).
        id_to_position (dict): Node position -> candidate token indices.
        _node_operator_map (dict): Token index -> REL_OP -> candidate tokens.
        patternLength (int): Number of nodes assigned so far.
        visited_nodes (list): Token indices assigned to nodes 0..patternLength-1.
        matched_trees (list): Output list; valid assignments are appended.
        """
        cdef bint is_nbor
        if patternLength == len(id_to_position):
            for node in range(patternLength):
                for relop, nbor in tree.get(node, ()):
                    target = visited_nodes[nbor]
                    is_nbor = False
                    for computed_nbor in _node_operator_map[visited_nodes[node]][relop]:
                        if computed_nbor.i == target:
                            is_nbor = True
                            break
                    if not is_nbor:
                        # One failed edge invalidates the whole assignment.
                        return
            matched_trees.append(visited_nodes)
            return
        for position in id_to_position[patternLength]:
            self.recurse(
                tree,
                id_to_position,
                _node_operator_map,
                patternLength + 1,
                visited_nodes + [position],
                matched_trees,
            )

    def get_node_operator_map(self, doc, tree, id_to_position, nodes, root):
        """Precompute relation candidates for every candidate token.

        This avoids recomputing them while pattern matching. The result is
        `_node_operator_map[token_index][operator] = [Token, ...]`. It covers
        every token matching some node, and every REL_OP used by an edge of
        `tree`.

        doc (Doc): The document being matched.
        tree (dict): Node position -> list of (REL_OP, child position).
        id_to_position (dict): Node position -> candidate token indices.
        nodes (dict): RIGHT_ID -> node position for this pattern.
        root (int): Position of the anchor node (unused here).
        RETURNS (dict): The token index -> operator -> tokens map.
        """
        all_node_indices = list(nodes.values())
        all_operators = set()
        for node in all_node_indices:
            for child in tree.get(node, ()):
                all_operators.add(child[INDEX_RELOP])
        all_tokens = set()
        for node in all_node_indices:
            all_tokens.update(id_to_position[node])
        _node_operator_map = {}
        for token_i in all_tokens:
            _node_operator_map[token_i] = {
                operator: self._ops[operator](doc, token_i) for operator in all_operators
            }
        return _node_operator_map

    def dep(self, doc, node):
        """`<`: the head of `doc[node]`, or nothing if it is its own head."""
        token = doc[node]
        if token.head == token:
            return []
        return [token.head]

    def gov(self, doc, node):
        """`>`: the immediate children of `doc[node]`."""
        return list(doc[node].children)

    def dep_chain(self, doc, node):
        """`<<`: all ancestors of `doc[node]`."""
        return list(doc[node].ancestors)

    def gov_chain(self, doc, node):
        """`>>`: every token in the subtree of `doc[node]` except itself."""
        return [t for t in doc[node].subtree if t != doc[node]]

    def imm_precede(self, doc, node):
        """`.`: the next token, if it lies in the same `_get_sent` span."""
        sent = self._get_sent(doc[node])
        # `node` lies inside `sent`, so the next token is in it iff it is
        # before the span end (which is at most len(doc)).
        if node + 1 < sent.end:
            return [doc[node + 1]]
        return []

    def precede(self, doc, node):
        """`.*`: all later tokens in the same `_get_sent` span."""
        sent = self._get_sent(doc[node])
        return [doc[i] for i in range(node + 1, sent.end)]

    def imm_follow(self, doc, node):
        """`;`: the previous token, if it lies in the same `_get_sent` span."""
        sent = self._get_sent(doc[node])
        # `node` lies inside `sent`, so the previous token is in it iff it is
        # not before the span start (which is at least 0).
        if node - 1 >= sent.start:
            return [doc[node - 1]]
        return []

    def follow(self, doc, node):
        """`;*`: all earlier tokens in the same `_get_sent` span."""
        sent = self._get_sent(doc[node])
        return [doc[i] for i in range(sent.start, node)]

    # NOTE(review): for a root token, `head` is the token itself, so the four
    # sibling operators below report the root's own children as "siblings".
    # Confirm this is intended.
    def imm_right_sib(self, doc, node):
        """`$+`: the child of the same head located at index `node + 1`."""
        for child in doc[node].head.children:
            if child.i == node + 1:
                return [doc[child.i]]
        return []

    def imm_left_sib(self, doc, node):
        """`$-`: the child of the same head located at index `node - 1`."""
        for child in doc[node].head.children:
            if child.i == node - 1:
                return [doc[child.i]]
        return []

    def right_sib(self, doc, node):
        """`$++`: children of the same head located after `node`."""
        return [doc[child.i] for child in doc[node].head.children if child.i > node]

    def left_sib(self, doc, node):
        """`$--`: children of the same head located before `node`."""
        return [doc[child.i] for child in doc[node].head.children if child.i < node]

    def _normalize_key(self, key):
        """Map string keys to their StringStore hash; pass other keys through."""
        if isinstance(key, basestring):
            return self.vocab.strings.add(key)
        else:
            return key

    def _get_sent(self, token):
        """Return the span covered by the subtree of `token`'s topmost ancestor."""
        root = (list(token.ancestors) or [token])[-1]
        return token.doc[root.left_edge.i:root.right_edge.i + 1]
|
|
|
|
|
|
def unpickle_matcher(vocab, patterns, callbacks):
    """Rebuild a DependencyMatcher from the state produced by `__reduce__`.

    vocab (Vocab): The vocabulary to build the matcher with.
    patterns (dict): Raw patterns, keyed by match ID.
    callbacks (dict): `on_match` callbacks, keyed by match ID; missing keys
        mean no callback.
    RETURNS (DependencyMatcher): The reconstructed matcher.
    """
    rebuilt = DependencyMatcher(vocab)
    for match_key in patterns:
        rebuilt.add(match_key, patterns[match_key], on_match=callbacks.get(match_key))
    return rebuilt
|