mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-29 11:26:28 +03:00
698b8b495f
* Clean up old Matcher call style related stuff In v2 Matcher.add was called with (key, on_match, *patterns). In v3 this was changed to (key, patterns, *, on_match=None), but there were various points where the old call syntax was documented or handled specially. This removes all those. The Matcher itself didn't need any code changes, as it just gives a generic type error. However the PhraseMatcher required some changes because it would automatically "fix" the old call style. Surprisingly, the tokenizer was still using the old call style in one place. After these changes tests failed in two places: 1. one test for the "new" call style, including the "old" call style. I removed this test. 2. deserializing the PhraseMatcher fails because the input docs are a set. I am not sure why 2 is happening - I guess it's a quirk of the serialization format? - so for now I just convert the set to a list when deserializing. The check that the input Docs are a List in the PhraseMatcher is a new check, but makes it parallel with the other Matchers, which seemed like the right thing to do. * Add notes related to input docs / deserialization type * Remove Typing import * Remove old note about call style change * Apply suggestions from code review Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Use separate method for setting internal doc representations In addition to the title change, this changes the internal dict to be a defaultdict, instead of a dict with frequent use of setdefault. * Add _add_from_arrays for unpickling * Cleanup around adding from arrays This moves adding to internal structures into the private batch method, and removes the single-add method. This has one behavioral change for `add`, in that if something is wrong with the list of input Docs (such as one of the items not being a Doc), valid items before the invalid one will not be added. Also the callback will not be updated if anything is invalid. This change should not be significant. 
This also adds a test to check failure when given a non-Doc. * Update spacy/matcher/phrasematcher.pyx Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
463 lines
17 KiB
Cython
463 lines
17 KiB
Cython
# cython: infer_types=True, profile=True
|
|
from typing import List
|
|
from collections import defaultdict
|
|
from itertools import product
|
|
|
|
import warnings
|
|
|
|
from .matcher cimport Matcher
|
|
from ..vocab cimport Vocab
|
|
from ..tokens.doc cimport Doc
|
|
|
|
from ..errors import Errors, Warnings
|
|
from ..tokens import Span
|
|
|
|
|
|
# Separator placed between the match key, the pattern index and the token
# index when DependencyMatcher._get_matcher_key builds the per-token key that
# is registered with the internal Matcher.
DELIMITER = "||"
|
|
# NOTE(review): INDEX_HEAD and INDEX_RELOP are not referenced by any code
# visible in this file; presumably leftovers from an older tuple-based pattern
# format -- confirm before relying on or removing them.
INDEX_HEAD = 1
|
|
INDEX_RELOP = 0
|
|
|
|
|
|
cdef class DependencyMatcher:
    """Match dependency parse tree based on pattern rules.

    Every node of every pattern is registered as its own single-token pattern
    with one internal `Matcher`. At call time the token-level matches are
    combined (cartesian product) and only the combinations whose tokens
    satisfy all `REL_OP` relations of the pattern are kept.
    """
    # Shared vocabulary; also used to turn string keys into hashes.
    cdef readonly Vocab vocab
    # Internal token matcher that holds one entry per pattern node.
    cdef readonly Matcher _matcher
    # key -> list of patterns, each a list of [RIGHT_ATTRS] token patterns.
    cdef public object _patterns
    # key -> list of raw patterns exactly as passed to `add`.
    cdef public object _raw_patterns
    # key -> list (per pattern) of internal matcher keys, one per node.
    cdef public object _tokens_to_key
    # key -> list (per pattern) of the index of the root node in the pattern.
    cdef public object _root
    # key -> on_match callback (or None).
    cdef public object _callbacks
    # key -> list (per pattern) of {left node index: [(REL_OP, right node index)]}.
    cdef public object _tree
    # REL_OP string -> bound method resolving that operator.
    cdef public object _ops

    def __init__(self, vocab, *, validate=False):
        """Create the DependencyMatcher.

        vocab (Vocab): The vocabulary object, which must be shared with the
            documents the matcher will operate on.
        validate (bool): Whether patterns should be validated, passed to
            Matcher as `validate`
        """
        self._matcher = Matcher(vocab, validate=validate)

        # Associates each key to the raw patterns list added by the user
        # e.g. {'key': [[{'RIGHT_ID': ..., 'RIGHT_ATTRS': ... }, ... ], ... ], ... }
        self._raw_patterns = defaultdict(list)

        # Associates each key to a list of lists of 'RIGHT_ATTRS'
        # e.g. {'key': [[{'POS': ... }, ... ], ... ], ... }
        self._patterns = defaultdict(list)

        # Associates each key to a list of lists where each list associates
        # a token position within the pattern to the key used by the internal matcher
        # e.g. {'key': [ ['token_key', ... ], ... ], ... }
        self._tokens_to_key = defaultdict(list)

        # Associates each key to a list of ints where each int corresponds
        # to the position of the root token within the pattern
        # e.g. {'key': [3, 1, 4, ... ], ... }
        self._root = defaultdict(list)

        # Associates each key to a list of dicts where each dict describes all
        # branches from a token (identified by its position) to other tokens as
        # a list of tuples: ('REL_OP', right token position)
        # e.g. {'key': [{2: [('<', 3), ...] ...}, ... ], ...}
        self._tree = defaultdict(list)

        # Associates each key to its on_match callback
        # e.g. {'key': on_match_callback, ...}
        self._callbacks = {}

        self.vocab = vocab

        # Dispatch table from REL_OP string to the method that, given
        # (doc, node index), returns the tokens related to that node.
        self._ops = {
            "<": self._dep,
            ">": self._gov,
            "<<": self._dep_chain,
            ">>": self._gov_chain,
            ".": self._imm_precede,
            ".*": self._precede,
            ";": self._imm_follow,
            ";*": self._follow,
            "$+": self._imm_right_sib,
            "$-": self._imm_left_sib,
            "$++": self._right_sib,
            "$--": self._left_sib,
            ">++": self._right_child,
            ">--": self._left_child,
            "<++": self._right_parent,
            "<--": self._left_parent,
        }

    def __reduce__(self):
        """Support pickling by rebuilding the matcher from its vocab, raw
        patterns and callbacks via `unpickle_matcher`.
        """
        data = (self.vocab, self._raw_patterns, self._callbacks)
        return (unpickle_matcher, data, None, None)

    def __len__(self):
        """Get the number of match IDs (keys) that have patterns registered
        with the dependency matcher.

        RETURNS (int): The number of keys.
        """
        return len(self._patterns)

    def __contains__(self, key):
        """Check whether the matcher contains rules for a match ID.

        key (str): The match ID.
        RETURNS (bool): Whether the matcher contains rules for this match ID.
        """
        return self.has_key(key)

    def _validate_input(self, pattern, key):
        """Validate one dependency pattern before it is added.

        The first node is the anchor: it must have a RIGHT_ID and RIGHT_ATTRS
        and must not define REL_OP or LEFT_ID. Every later node must define
        RIGHT_ID, RIGHT_ATTRS, REL_OP and LEFT_ID, must introduce a RIGHT_ID
        that was not seen before, must reference a LEFT_ID that was already
        seen, and must use a REL_OP known to this matcher. No node may use the
        quantifiers "?", "*" or "+" as its RIGHT_ATTRS "OP".

        pattern (list): The node dicts making up a single pattern.
        key (str / int): The match ID, used in error messages.
        RAISES (ValueError): If any of the above constraints is violated.
        """
        visited_nodes = set()
        for idx, relation in enumerate(pattern):
            if not isinstance(relation, dict):
                raise ValueError(Errors.E1008)
            if "RIGHT_ATTRS" not in relation and "RIGHT_ID" not in relation:
                raise ValueError(Errors.E098.format(key=key))
            if idx == 0:
                if not (
                    "RIGHT_ID" in relation
                    and "REL_OP" not in relation
                    and "LEFT_ID" not in relation
                ):
                    raise ValueError(Errors.E099.format(key=key))
                # The anchor is not covered by the required-keys check below,
                # so without this guard a missing RIGHT_ATTRS would surface as
                # a bare KeyError at the quantifier check.
                if "RIGHT_ATTRS" not in relation:
                    raise ValueError(Errors.E098.format(key=key))
                visited_nodes.add(relation["RIGHT_ID"])
            else:
                required_keys = {"RIGHT_ID", "RIGHT_ATTRS", "REL_OP", "LEFT_ID"}
                relation_keys = set(relation.keys())
                missing = required_keys - relation_keys
                if missing:
                    missing_txt = ", ".join(list(missing))
                    raise ValueError(Errors.E100.format(
                        required=required_keys,
                        missing=missing_txt
                    ))
                if (
                    relation["RIGHT_ID"] in visited_nodes
                    or relation["LEFT_ID"] not in visited_nodes
                ):
                    raise ValueError(Errors.E101.format(key=key))
                if relation["REL_OP"] not in self._ops:
                    raise ValueError(Errors.E1007.format(op=relation["REL_OP"]))
                visited_nodes.add(relation["RIGHT_ID"])
                visited_nodes.add(relation["LEFT_ID"])
            if relation["RIGHT_ATTRS"].get("OP", "") in ("?", "*", "+"):
                raise ValueError(Errors.E1016.format(node=relation))

    def _get_matcher_key(self, key, pattern_idx, token_idx):
        """
        Creates a token key to be used by the matcher

        key (str / int): The (normalized) match ID.
        pattern_idx (int): Index of the pattern within the key's patterns.
        token_idx (int): Index of the node within the pattern.
        RETURNS: The normalized form of "key||pattern_idx||token_idx".
        """
        return self._normalize_key(
            str(key) + DELIMITER +
            str(pattern_idx) + DELIMITER +
            str(token_idx)
        )

    def add(self, key, patterns, *, on_match=None):
        """Add a new matcher rule to the matcher.

        All patterns are validated before any internal state is modified.
        Patterns are appended to those already stored under `key`; the
        callback stored for `key` is replaced by `on_match`.

        key (str): The match ID.
        patterns (list): The patterns to add for the given key.
        on_match (callable): Optional callback executed on match.
        RAISES (ValueError): If `on_match` is not callable, `patterns` is not
            a list, a pattern is empty, or a pattern fails validation.
        """
        if on_match is not None and not hasattr(on_match, "__call__"):
            raise ValueError(Errors.E171.format(name="DependencyMatcher", arg_type=type(on_match)))
        if patterns is None or not isinstance(patterns, List):
            raise ValueError(Errors.E948.format(name="DependencyMatcher", arg_type=type(patterns)))
        for pattern in patterns:
            if len(pattern) == 0:
                raise ValueError(Errors.E012.format(key=key))
            self._validate_input(pattern, key)

        key = self._normalize_key(key)

        # Save raw patterns and on_match callback
        self._raw_patterns[key].extend(patterns)
        self._callbacks[key] = on_match

        # Add 'RIGHT_ATTRS' to self._patterns[key]. Each node becomes its own
        # one-token Matcher pattern: [RIGHT_ATTRS].
        _patterns = [[[pat["RIGHT_ATTRS"]] for pat in pattern] for pattern in patterns]
        # Patterns may already exist for this key; continue numbering after
        # them so the internal matcher keys stay unique.
        pattern_offset = len(self._patterns[key])
        self._patterns[key].extend(_patterns)

        # Add each node pattern of all the input patterns individually to the
        # matcher. This enables only a single instance of Matcher to be used.
        # Multiple adds are required to track each node pattern.
        tokens_to_key_list = []
        for i, current_patterns in enumerate(_patterns, start=pattern_offset):
            tokens_to_key = []
            # TODO: Better ways to hash edges in pattern?
            for j, _pattern in enumerate(current_patterns):
                k = self._get_matcher_key(key, i, j)
                self._matcher.add(k, [_pattern])
                tokens_to_key.append(k)
            tokens_to_key_list.append(tokens_to_key)

        self._tokens_to_key[key].extend(tokens_to_key_list)

        # Create an object tree to traverse later on. This data structure
        # enables easy tree pattern match. Doc-Token based tree cannot be
        # reused since it is memory-heavy and tightly coupled with the Doc.
        self._retrieve_tree(patterns, key)

    def _retrieve_tree(self, patterns, key):
        """Build and store the relation tree and root index of each pattern.

        For every pattern, a dict is created that maps the index of each left
        node to a list of (REL_OP, right node index) tuples; the node without
        a REL_OP is recorded as the root. The results are appended to
        `self._tree[key]` and `self._root[key]`.

        patterns (list): The raw patterns being added.
        key (int): The normalized match ID.
        """
        # New trees belonging to this key
        tree_list = []

        # List of indices to the root nodes
        root_list = []

        # Group token id to create trees
        for pattern in patterns:
            tree = defaultdict(list)
            root = -1

            # Map each node name to its position within the pattern
            right_id_to_token = {
                token["RIGHT_ID"]: j for j, token in enumerate(pattern)
            }

            for j, token in enumerate(pattern):
                if "REL_OP" in token:
                    # Add tree branch
                    tree[right_id_to_token[token["LEFT_ID"]]].append(
                        (token["REL_OP"], j),
                    )
                else:
                    # No 'REL_OP', this is the root
                    root = j

            tree_list.append(tree)
            root_list.append(root)

        self._tree[key].extend(tree_list)
        self._root[key].extend(root_list)

    def has_key(self, key):
        """Check whether the matcher has a rule with a given key.

        key (string or int): The key to check.
        RETURNS (bool): Whether the matcher has the rule.
        """
        return self._normalize_key(key) in self._patterns

    def get(self, key, default=None):
        """Retrieve the pattern stored for a key.

        key (str / int): The key to retrieve.
        default: Value returned when no patterns are stored for the key.
        RETURNS (tuple): The rule, as an (on_match, patterns) tuple.
        """
        key = self._normalize_key(key)
        if key not in self._raw_patterns:
            return default
        return (self._callbacks[key], self._raw_patterns[key])

    def remove(self, key):
        """Remove all patterns, and the callback, stored for a key.

        key (str / int): The key to remove.
        RAISES (ValueError): If no patterns are stored for the key.
        """
        key = self._normalize_key(key)
        if key not in self._patterns:
            raise ValueError(Errors.E175.format(key=key))
        self._patterns.pop(key)
        self._raw_patterns.pop(key)
        self._tree.pop(key)
        self._root.pop(key)
        # Drop the callback too, so a removed key leaves no stale state
        # behind (e.g. when the matcher is pickled afterwards).
        self._callbacks.pop(key, None)
        # Unregister every per-node pattern from the internal matcher
        for mklist in self._tokens_to_key.pop(key):
            for mkey in mklist:
                self._matcher.remove(mkey)

    def _get_keys_to_position_maps(self, doc):
        """
        Processes the doc and groups all matches by their root and match id.
        Returns a dict mapping each (root, match id) pair to the list of
        tokens indices which are descendants of root and match the token
        pattern identified by the given match id.

        A warning (W110) is emitted when the internal matcher returns a match
        longer than one token; only the first token of the match is used.

        e.g. keys_to_position_maps[root_index][match_id] = [...]
        """
        keys_to_position_maps = defaultdict(lambda: defaultdict(list))
        for match_id, start, end in self._matcher(doc):
            if start + 1 != end:
                warnings.warn(Warnings.W110.format(tokens=[t.text for t in doc[start:end]], pattern=self._matcher.get(match_id)[1][0][0]))
            token = doc[start]
            # The topmost ancestor, or the token itself if it has none
            root = ([token] + list(token.ancestors))[-1]
            keys_to_position_maps[root.i][match_id].append(start)
        return keys_to_position_maps

    def __call__(self, object doclike):
        """Find all sets of tokens matching the supplied patterns.

        doclike (Doc or Span): The document to match over. A Span is first
            converted with `as_doc`, and matching runs on that new Doc.
        RETURNS (list): A list of `(match_id, token_ids)` tuples describing
            the matches. `match_id` is the key of the matched pattern and
            `token_ids` is a list with one token index per pattern node, in
            the order the nodes appear in the pattern.
        RAISES (ValueError): If `doclike` is neither a Doc nor a Span.
        """
        if isinstance(doclike, Doc):
            doc = doclike
        elif isinstance(doclike, Span):
            doc = doclike.as_doc(copy_user_data=True)
        else:
            raise ValueError(Errors.E195.format(good="Doc or Span", got=type(doclike).__name__))

        matched_key_trees = []
        keys_to_position_maps = self._get_keys_to_position_maps(doc)
        # Shared across all patterns: (node index, REL_OP) -> related indices
        cache = {}

        for key in self._patterns:
            for root, tree, tokens_to_key in zip(
                self._root[key],
                self._tree[key],
                self._tokens_to_key[key],
            ):
                # Token matches are grouped by tree root, so candidates are
                # only combined within the same tree.
                for keys_to_position in keys_to_position_maps.values():
                    for matched_tree in self._get_matches(cache, doc, tree, tokens_to_key, keys_to_position):
                        matched_key_trees.append((key, matched_tree))

        # Callbacks run only after all matches have been collected, and each
        # receives the full list of matches.
        for i, (match_id, _) in enumerate(matched_key_trees):
            on_match = self._callbacks.get(match_id)
            if on_match is not None:
                on_match(self, doc, i, matched_key_trees)

        return matched_key_trees

    def _get_matches(self, cache, doc, tree, tokens_to_key, keys_to_position):
        """Yield every valid assignment of doc tokens to the pattern's nodes.

        cache (dict): Memo used by `_resolve_node_operator`.
        doc (Doc): The document being matched.
        tree (dict): Maps left node index to [(REL_OP, right node index)].
        tokens_to_key (list): Internal matcher key of each pattern node.
        keys_to_position (dict): Maps internal matcher key to the token
            indices that matched it.
        YIELDS (list): Token indices, one per pattern node.
        """
        cdef bint is_valid

        all_positions = [keys_to_position[key] for key in tokens_to_key]

        # Generate all potential matches by computing the cartesian product of all
        # position of the matched tokens
        for candidate_match in product(*all_positions):

            # A potential match is a valid match if all relationships between the
            # matched tokens are satisfied.
            is_valid = True
            for left_idx in range(len(candidate_match)):
                is_valid = self._check_relationships(cache, doc, candidate_match, left_idx, tree)

                if not is_valid:
                    break

            if is_valid:
                yield list(candidate_match)

    def _check_relationships(self, cache, doc, candidate_match, left_idx, tree):
        """Check all relations that have node `left_idx` as left operand.

        RETURNS (bool): True if, for every (REL_OP, right node) branch of the
            left node, the candidate's right token is among the tokens the
            operator yields for the candidate's left token.
        """
        # Position of the left operand within the document
        left_pos = candidate_match[left_idx]

        for relop, right_idx in tree[left_idx]:

            # Position of the right operand within the document
            right_pos = candidate_match[right_idx]

            # List of valid right matches
            valid_right_matches = self._resolve_node_operator(cache, doc, left_pos, relop)

            # If the proposed right token is not within the valid ones, fail
            if right_pos not in valid_right_matches:
                return False

        return True

    def _resolve_node_operator(self, cache, doc, node, operator):
        """
        Given a doc, a node (as a index in the doc) and a REL_OP operator
        returns the list of nodes from the doc that belong to node+operator.
        Results are memoized in `cache` under the key (node, operator).
        """
        key = (node, operator)
        if key not in cache:
            cache[key] = [token.i for token in self._ops[operator](doc, node)]
        return cache[key]

    def _dep(self, doc, node):
        """"<": the head of the node, or nothing if the node is its own head."""
        if doc[node].head == doc[node]:
            return []
        return [doc[node].head]

    def _gov(self, doc, node):
        """">": the children of the node."""
        return list(doc[node].children)

    def _dep_chain(self, doc, node):
        """"<<": the ancestors of the node."""
        return list(doc[node].ancestors)

    def _gov_chain(self, doc, node):
        """">>": the tokens in the node's subtree, excluding the node itself."""
        return [t for t in doc[node].subtree if t != doc[node]]

    def _imm_precede(self, doc, node):
        """".": the next token, if it lies within the span from `_get_sent`."""
        sent = self._get_sent(doc[node])
        if node < len(doc) - 1 and doc[node + 1] in sent:
            return [doc[node + 1]]
        return []

    def _precede(self, doc, node):
        """".*": all tokens after the node up to the end of its `_get_sent` span."""
        sent = self._get_sent(doc[node])
        return [doc[i] for i in range(node + 1, sent.end)]

    def _imm_follow(self, doc, node):
        """";": the previous token, if it lies within the span from `_get_sent`."""
        sent = self._get_sent(doc[node])
        if node > 0 and doc[node - 1] in sent:
            return [doc[node - 1]]
        return []

    def _follow(self, doc, node):
        """";*": all tokens from the start of the node's `_get_sent` span up to the node."""
        sent = self._get_sent(doc[node])
        return [doc[i] for i in range(sent.start, node)]

    def _imm_right_sib(self, doc, node):
        """"$+": the child of the node's head located directly after the node."""
        for child in doc[node].head.children:
            if child.i == node + 1:
                return [doc[child.i]]
        return []

    def _imm_left_sib(self, doc, node):
        """"$-": the child of the node's head located directly before the node."""
        for child in doc[node].head.children:
            if child.i == node - 1:
                return [doc[child.i]]
        return []

    def _right_sib(self, doc, node):
        """"$++": the children of the node's head located after the node."""
        return [doc[child.i] for child in doc[node].head.children if child.i > node]

    def _left_sib(self, doc, node):
        """"$--": the children of the node's head located before the node."""
        return [doc[child.i] for child in doc[node].head.children if child.i < node]

    def _right_child(self, doc, node):
        """">++": the children of the node located after it."""
        return [doc[child.i] for child in doc[node].children if child.i > node]

    def _left_child(self, doc, node):
        """">--": the children of the node located before it."""
        return [doc[child.i] for child in doc[node].children if child.i < node]

    def _right_parent(self, doc, node):
        """"<++": the head of the node, if the head is located after it."""
        if doc[node].head.i > node:
            return [doc[node].head]
        return []

    def _left_parent(self, doc, node):
        """"<--": the head of the node, if the head is located before it."""
        if doc[node].head.i < node:
            return [doc[node].head]
        return []

    def _normalize_key(self, key):
        """Convert a string key to the value returned by adding it to the
        vocab's string store; any other key is returned unchanged.
        """
        if isinstance(key, str):
            return self.vocab.strings.add(key)
        else:
            return key

    def _get_sent(self, token):
        """Return the span covered by the tree that contains `token`.

        The root is the token's last ancestor (or the token itself if it has
        no ancestors); the span runs from the root's left edge to its right
        edge, inclusive.
        """
        root = (list(token.ancestors) or [token])[-1]
        return token.doc[root.left_edge.i:root.right_edge.i + 1]
|
|
|
|
|
|
def unpickle_matcher(vocab, patterns, callbacks):
    """Rebuild a DependencyMatcher from the state produced by `__reduce__`.

    vocab (Vocab): The vocabulary the new matcher is created with.
    patterns (dict): Maps each key to the list of raw patterns to re-add.
    callbacks (dict): Maps keys to on_match callbacks; keys without an entry
        get no callback.
    RETURNS (DependencyMatcher): The reconstructed matcher.
    """
    restored = DependencyMatcher(vocab)
    for key in patterns:
        # Re-adding through the public API rebuilds all derived structures.
        restored.add(key, patterns[key], on_match=callbacks.get(key))
    return restored
|