mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-03 22:06:37 +03:00
495 lines
18 KiB
Cython
495 lines
18 KiB
Cython
# cython: infer_types=True
|
|
import warnings
|
|
from collections import defaultdict
|
|
from itertools import product
|
|
from typing import List
|
|
|
|
from ..tokens.doc cimport Doc
|
|
from ..vocab cimport Vocab
|
|
from .matcher cimport Matcher
|
|
|
|
from ..errors import Errors, Warnings
|
|
from ..tokens import Span
|
|
|
|
DELIMITER = "||"
|
|
INDEX_HEAD = 1
|
|
INDEX_RELOP = 0
|
|
|
|
|
|
cdef class DependencyMatcher:
|
|
"""Match dependency parse tree based on pattern rules."""
|
|
cdef readonly Vocab vocab
|
|
cdef readonly Matcher _matcher
|
|
cdef public object _patterns
|
|
cdef public object _raw_patterns
|
|
cdef public object _tokens_to_key
|
|
cdef public object _root
|
|
cdef public object _callbacks
|
|
cdef public object _tree
|
|
cdef public object _ops
|
|
|
|
def __init__(self, vocab, *, validate=False):
|
|
"""Create the DependencyMatcher.
|
|
|
|
vocab (Vocab): The vocabulary object, which must be shared with the
|
|
documents the matcher will operate on.
|
|
validate (bool): Whether patterns should be validated, passed to
|
|
Matcher as `validate`
|
|
"""
|
|
self._matcher = Matcher(vocab, validate=validate)
|
|
|
|
# Associates each key to the raw patterns list added by the user
|
|
# e.g. {'key': [[{'RIGHT_ID': ..., 'RIGHT_ATTRS': ... }, ... ], ... ], ... }
|
|
self._raw_patterns = defaultdict(list)
|
|
|
|
# Associates each key to a list of lists of 'RIGHT_ATTRS'
|
|
# e.g. {'key': [[{'POS': ... }, ... ], ... ], ... }
|
|
self._patterns = defaultdict(list)
|
|
|
|
# Associates each key to a list of lists where each list associates
|
|
# a token position within the pattern to the key used by the internal matcher)
|
|
# e.g. {'key': [ ['token_key', ... ], ... ], ... }
|
|
self._tokens_to_key = defaultdict(list)
|
|
|
|
# Associates each key to a list of ints where each int corresponds
|
|
# to the position of the root token within the pattern
|
|
# e.g. {'key': [3, 1, 4, ... ], ... }
|
|
self._root = defaultdict(list)
|
|
|
|
# Associates each key to a list of dicts where each dict describes all
|
|
# branches from a token (identifed by its position) to other tokens as
|
|
# a list of tuples: ('REL_OP', 'LEFT_ID')
|
|
# e.g. {'key': [{2: [('<', 'left_id'), ...] ...}, ... ], ...}
|
|
self._tree = defaultdict(list)
|
|
|
|
# Associates each key to its on_match callback
|
|
# e.g. {'key': on_match_callback, ...}
|
|
self._callbacks = {}
|
|
|
|
self.vocab = vocab
|
|
self._ops = {
|
|
"<": self._dep,
|
|
">": self._gov,
|
|
"<<": self._dep_chain,
|
|
">>": self._gov_chain,
|
|
".": self._imm_precede,
|
|
".*": self._precede,
|
|
";": self._imm_follow,
|
|
";*": self._follow,
|
|
"$+": self._imm_right_sib,
|
|
"$-": self._imm_left_sib,
|
|
"$++": self._right_sib,
|
|
"$--": self._left_sib,
|
|
">+": self._imm_right_child,
|
|
">-": self._imm_left_child,
|
|
">++": self._right_child,
|
|
">--": self._left_child,
|
|
"<+": self._imm_right_parent,
|
|
"<-": self._imm_left_parent,
|
|
"<++": self._right_parent,
|
|
"<--": self._left_parent,
|
|
}
|
|
|
|
def __reduce__(self):
|
|
data = (self.vocab, self._raw_patterns, self._callbacks)
|
|
return (unpickle_matcher, data, None, None)
|
|
|
|
def __len__(self):
|
|
"""Get the number of rules, which are edges, added to the dependency
|
|
tree matcher.
|
|
|
|
RETURNS (int): The number of rules.
|
|
"""
|
|
return len(self._patterns)
|
|
|
|
def __contains__(self, key):
|
|
"""Check whether the matcher contains rules for a match ID.
|
|
|
|
key (str): The match ID.
|
|
RETURNS (bool): Whether the matcher contains rules for this match ID.
|
|
"""
|
|
return self.has_key(key) # no-cython-lint: W601
|
|
|
|
def _validate_input(self, pattern, key):
|
|
idx = 0
|
|
visited_nodes = {}
|
|
for relation in pattern:
|
|
if not isinstance(relation, dict):
|
|
raise ValueError(Errors.E1008)
|
|
if "RIGHT_ATTRS" not in relation and "RIGHT_ID" not in relation:
|
|
raise ValueError(Errors.E098.format(key=key))
|
|
if idx == 0:
|
|
if not(
|
|
"RIGHT_ID" in relation
|
|
and "REL_OP" not in relation
|
|
and "LEFT_ID" not in relation
|
|
):
|
|
raise ValueError(Errors.E099.format(key=key))
|
|
visited_nodes[relation["RIGHT_ID"]] = True
|
|
else:
|
|
required_keys = {"RIGHT_ID", "RIGHT_ATTRS", "REL_OP", "LEFT_ID"}
|
|
relation_keys = set(relation.keys())
|
|
# Identify required keys that have not been specified
|
|
missing = required_keys - relation_keys
|
|
if missing:
|
|
missing_txt = ", ".join(list(missing))
|
|
raise ValueError(Errors.E100.format(
|
|
required=required_keys,
|
|
missing=missing_txt
|
|
))
|
|
# Identify additional, unsupported keys
|
|
unsupported = relation_keys - required_keys
|
|
if unsupported:
|
|
unsupported_txt = ", ".join(list(unsupported))
|
|
warnings.warn(Warnings.W126.format(
|
|
unsupported=unsupported_txt
|
|
))
|
|
if (
|
|
relation["RIGHT_ID"] in visited_nodes
|
|
or relation["LEFT_ID"] not in visited_nodes
|
|
):
|
|
raise ValueError(Errors.E101.format(key=key))
|
|
if relation["REL_OP"] not in self._ops:
|
|
raise ValueError(Errors.E1007.format(op=relation["REL_OP"]))
|
|
visited_nodes[relation["RIGHT_ID"]] = True
|
|
visited_nodes[relation["LEFT_ID"]] = True
|
|
if relation["RIGHT_ATTRS"].get("OP", "") in ("?", "*", "+"):
|
|
raise ValueError(Errors.E1016.format(node=relation))
|
|
idx = idx + 1
|
|
|
|
def _get_matcher_key(self, key, pattern_idx, token_idx):
|
|
"""
|
|
Creates a token key to be used by the matcher
|
|
"""
|
|
return self._normalize_key(
|
|
str(key) + DELIMITER +
|
|
str(pattern_idx) + DELIMITER +
|
|
str(token_idx)
|
|
)
|
|
|
|
def add(self, key, patterns, *, on_match=None):
|
|
"""Add a new matcher rule to the matcher.
|
|
|
|
key (str): The match ID.
|
|
patterns (list): The patterns to add for the given key.
|
|
on_match (callable): Optional callback executed on match.
|
|
"""
|
|
if on_match is not None and not hasattr(on_match, "__call__"):
|
|
raise ValueError(Errors.E171.format(name="DependencyMatcher", arg_type=type(on_match)))
|
|
if patterns is None or not isinstance(patterns, List):
|
|
raise ValueError(Errors.E948.format(name="DependencyMatcher", arg_type=type(patterns)))
|
|
for pattern in patterns:
|
|
if len(pattern) == 0:
|
|
raise ValueError(Errors.E012.format(key=key))
|
|
self._validate_input(pattern, key)
|
|
|
|
key = self._normalize_key(key)
|
|
|
|
# Save raw patterns and on_match callback
|
|
self._raw_patterns[key].extend(patterns)
|
|
self._callbacks[key] = on_match
|
|
|
|
# Add 'RIGHT_ATTRS' to self._patterns[key]
|
|
_patterns = [[[pat["RIGHT_ATTRS"]] for pat in pattern] for pattern in patterns]
|
|
pattern_offset = len(self._patterns[key])
|
|
self._patterns[key].extend(_patterns)
|
|
|
|
# Add each node pattern of all the input patterns individually to the
|
|
# matcher. This enables only a single instance of Matcher to be used.
|
|
# Multiple adds are required to track each node pattern.
|
|
tokens_to_key_list = []
|
|
for i, current_patterns in enumerate(_patterns, start=pattern_offset):
|
|
|
|
# Preallocate list space
|
|
tokens_to_key = [None] * len(current_patterns)
|
|
|
|
# TODO: Better ways to hash edges in pattern?
|
|
for j, _pattern in enumerate(current_patterns):
|
|
k = self._get_matcher_key(key, i, j)
|
|
self._matcher.add(k, [_pattern])
|
|
tokens_to_key[j] = k
|
|
|
|
tokens_to_key_list.append(tokens_to_key)
|
|
|
|
self._tokens_to_key[key].extend(tokens_to_key_list)
|
|
|
|
# Create an object tree to traverse later on. This data structure
|
|
# enables easy tree pattern match. Doc-Token based tree cannot be
|
|
# reused since it is memory-heavy and tightly coupled with the Doc.
|
|
self._retrieve_tree(patterns, key)
|
|
|
|
def _retrieve_tree(self, patterns, key):
|
|
|
|
# New trees belonging to this key
|
|
tree_list = []
|
|
|
|
# List of indices to the root nodes
|
|
root_list = []
|
|
|
|
# Group token id to create trees
|
|
for i, pattern in enumerate(patterns):
|
|
tree = defaultdict(list)
|
|
root = -1
|
|
|
|
right_id_to_token = {}
|
|
for j, token in enumerate(pattern):
|
|
right_id_to_token[token["RIGHT_ID"]] = j
|
|
|
|
for j, token in enumerate(pattern):
|
|
if "REL_OP" in token:
|
|
# Add tree branch
|
|
tree[right_id_to_token[token["LEFT_ID"]]].append(
|
|
(token["REL_OP"], j),
|
|
)
|
|
else:
|
|
# No 'REL_OP', this is the root
|
|
root = j
|
|
|
|
tree_list.append(tree)
|
|
root_list.append(root)
|
|
|
|
self._tree[key].extend(tree_list)
|
|
self._root[key].extend(root_list)
|
|
|
|
def has_key(self, key):
|
|
"""Check whether the matcher has a rule with a given key.
|
|
|
|
key (string or int): The key to check.
|
|
RETURNS (bool): Whether the matcher has the rule.
|
|
"""
|
|
return self._normalize_key(key) in self._patterns
|
|
|
|
def get(self, key, default=None):
|
|
"""Retrieve the pattern stored for a key.
|
|
|
|
key (str / int): The key to retrieve.
|
|
RETURNS (tuple): The rule, as an (on_match, patterns) tuple.
|
|
"""
|
|
key = self._normalize_key(key)
|
|
if key not in self._raw_patterns:
|
|
return default
|
|
return (self._callbacks[key], self._raw_patterns[key])
|
|
|
|
def remove(self, key):
|
|
key = self._normalize_key(key)
|
|
if key not in self._patterns:
|
|
raise ValueError(Errors.E175.format(key=key))
|
|
self._patterns.pop(key)
|
|
self._raw_patterns.pop(key)
|
|
self._tree.pop(key)
|
|
self._root.pop(key)
|
|
for mklist in self._tokens_to_key.pop(key):
|
|
for mkey in mklist:
|
|
self._matcher.remove(mkey)
|
|
|
|
def _get_keys_to_position_maps(self, doc):
|
|
"""
|
|
Processes the doc and groups all matches by their root and match id.
|
|
Returns a dict mapping each (root, match id) pair to the list of
|
|
tokens indices which are descendants of root and match the token
|
|
pattern identified by the given match id.
|
|
|
|
e.g. keys_to_position_maps[root_index][match_id] = [...]
|
|
"""
|
|
keys_to_position_maps = defaultdict(lambda: defaultdict(list))
|
|
for match_id, start, end in self._matcher(doc):
|
|
if start + 1 != end:
|
|
warnings.warn(Warnings.W110.format(tokens=[t.text for t in doc[start:end]], pattern=self._matcher.get(match_id)[1][0][0]))
|
|
token = doc[start]
|
|
root = ([token] + list(token.ancestors))[-1]
|
|
keys_to_position_maps[root.i][match_id].append(start)
|
|
return keys_to_position_maps
|
|
|
|
def __call__(self, object doclike):
|
|
"""Find all token sequences matching the supplied pattern.
|
|
|
|
doclike (Doc or Span): The document to match over.
|
|
RETURNS (list): A list of `(key, start, end)` tuples,
|
|
describing the matches. A match tuple describes a span
|
|
`doc[start:end]`. The `label_id` and `key` are both integers.
|
|
"""
|
|
if isinstance(doclike, Doc):
|
|
doc = doclike
|
|
elif isinstance(doclike, Span):
|
|
doc = doclike.as_doc(copy_user_data=True)
|
|
else:
|
|
raise ValueError(Errors.E195.format(good="Doc or Span", got=type(doclike).__name__))
|
|
|
|
matched_key_trees = []
|
|
keys_to_position_maps = self._get_keys_to_position_maps(doc)
|
|
cache = {}
|
|
|
|
for key in self._patterns.keys():
|
|
for root, tree, tokens_to_key in zip(
|
|
self._root[key],
|
|
self._tree[key],
|
|
self._tokens_to_key[key],
|
|
):
|
|
for keys_to_position in keys_to_position_maps.values():
|
|
for matched_tree in self._get_matches(cache, doc, tree, tokens_to_key, keys_to_position):
|
|
matched_key_trees.append((key, matched_tree))
|
|
|
|
for i, (match_id, nodes) in enumerate(matched_key_trees):
|
|
on_match = self._callbacks.get(match_id)
|
|
if on_match is not None:
|
|
on_match(self, doc, i, matched_key_trees)
|
|
|
|
return matched_key_trees
|
|
|
|
def _get_matches(self, cache, doc, tree, tokens_to_key, keys_to_position):
|
|
cdef bint is_valid
|
|
|
|
all_positions = [keys_to_position[key] for key in tokens_to_key]
|
|
|
|
# Generate all potential matches by computing the cartesian product of all
|
|
# position of the matched tokens
|
|
for candidate_match in product(*all_positions):
|
|
|
|
# A potential match is a valid match if all relationships between the
|
|
# matched tokens are satisfied.
|
|
is_valid = True
|
|
for left_idx in range(len(candidate_match)):
|
|
is_valid = self._check_relationships(cache, doc, candidate_match, left_idx, tree)
|
|
|
|
if not is_valid:
|
|
break
|
|
|
|
if is_valid:
|
|
yield list(candidate_match)
|
|
|
|
def _check_relationships(self, cache, doc, candidate_match, left_idx, tree):
|
|
# Position of the left operand within the document
|
|
left_pos = candidate_match[left_idx]
|
|
|
|
for relop, right_idx in tree[left_idx]:
|
|
|
|
# Position of the left operand within the document
|
|
right_pos = candidate_match[right_idx]
|
|
|
|
# List of valid right matches
|
|
valid_right_matches = self._resolve_node_operator(cache, doc, left_pos, relop)
|
|
|
|
# If the proposed right token is not within the valid ones, fail
|
|
if right_pos not in valid_right_matches:
|
|
return False
|
|
|
|
return True
|
|
|
|
def _resolve_node_operator(self, cache, doc, node, operator):
|
|
"""
|
|
Given a doc, a node (as a index in the doc) and a REL_OP operator
|
|
returns the list of nodes from the doc that belong to node+operator.
|
|
"""
|
|
key = (node, operator)
|
|
if key not in cache:
|
|
cache[key] = [token.i for token in self._ops[operator](doc, node)]
|
|
return cache[key]
|
|
|
|
def _dep(self, doc, node):
|
|
if doc[node].head == doc[node]:
|
|
return []
|
|
return [doc[node].head]
|
|
|
|
def _gov(self, doc, node):
|
|
return list(doc[node].children)
|
|
|
|
def _dep_chain(self, doc, node):
|
|
return list(doc[node].ancestors)
|
|
|
|
def _gov_chain(self, doc, node):
|
|
return [t for t in doc[node].subtree if t != doc[node]]
|
|
|
|
def _imm_precede(self, doc, node):
|
|
sent = self._get_sent(doc[node])
|
|
if node < len(doc) - 1 and doc[node + 1] in sent:
|
|
return [doc[node + 1]]
|
|
return []
|
|
|
|
def _precede(self, doc, node):
|
|
sent = self._get_sent(doc[node])
|
|
return [doc[i] for i in range(node + 1, sent.end)]
|
|
|
|
def _imm_follow(self, doc, node):
|
|
sent = self._get_sent(doc[node])
|
|
if node > 0 and doc[node - 1] in sent:
|
|
return [doc[node - 1]]
|
|
return []
|
|
|
|
def _follow(self, doc, node):
|
|
sent = self._get_sent(doc[node])
|
|
return [doc[i] for i in range(sent.start, node)]
|
|
|
|
def _imm_right_sib(self, doc, node):
|
|
for child in list(doc[node].head.children):
|
|
if child.i == node + 1:
|
|
return [doc[child.i]]
|
|
return []
|
|
|
|
def _imm_left_sib(self, doc, node):
|
|
for child in list(doc[node].head.children):
|
|
if child.i == node - 1:
|
|
return [doc[child.i]]
|
|
return []
|
|
|
|
def _right_sib(self, doc, node):
|
|
return [doc[child.i] for child in doc[node].head.children if child.i > node]
|
|
|
|
def _left_sib(self, doc, node):
|
|
return [doc[child.i] for child in doc[node].head.children if child.i < node]
|
|
|
|
def _imm_right_child(self, doc, node):
|
|
for child in doc[node].rights:
|
|
if child.i == node + 1:
|
|
return [doc[child.i]]
|
|
return []
|
|
|
|
def _imm_left_child(self, doc, node):
|
|
for child in doc[node].lefts:
|
|
if child.i == node - 1:
|
|
return [doc[child.i]]
|
|
return []
|
|
|
|
def _right_child(self, doc, node):
|
|
return [child for child in doc[node].rights]
|
|
|
|
def _left_child(self, doc, node):
|
|
return [child for child in doc[node].lefts]
|
|
|
|
def _imm_right_parent(self, doc, node):
|
|
if doc[node].head.i == node + 1:
|
|
return [doc[node].head]
|
|
return []
|
|
|
|
def _imm_left_parent(self, doc, node):
|
|
if doc[node].head.i == node - 1:
|
|
return [doc[node].head]
|
|
return []
|
|
|
|
def _right_parent(self, doc, node):
|
|
if doc[node].head.i > node:
|
|
return [doc[node].head]
|
|
return []
|
|
|
|
def _left_parent(self, doc, node):
|
|
if doc[node].head.i < node:
|
|
return [doc[node].head]
|
|
return []
|
|
|
|
def _normalize_key(self, key):
|
|
if isinstance(key, str):
|
|
return self.vocab.strings.add(key)
|
|
else:
|
|
return key
|
|
|
|
def _get_sent(self, token):
|
|
root = (list(token.ancestors) or [token])[-1]
|
|
return token.doc[root.left_edge.i:root.right_edge.i + 1]
|
|
|
|
|
|
def unpickle_matcher(vocab, patterns, callbacks):
|
|
matcher = DependencyMatcher(vocab)
|
|
for key, pattern in patterns.items():
|
|
callback = callbacks.get(key, None)
|
|
matcher.add(key, pattern, on_match=callback)
|
|
return matcher
|