mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-25 16:54:24 +03:00
296446a1c8
<!--- Provide a general summary of your changes in the title. --> ## Description * tidy up and adjust Cython code to code style * improve docstrings and make calling `help()` nicer * add URLs to new docs pages to docstrings wherever possible, mostly to user-facing objects * fix various typos and inconsistencies in docs ### Types of change enhancement, docs ## Checklist <!--- Before you submit the PR, go over this checklist and make sure you can tick off all the boxes. [] -> [x] --> - [x] I have submitted the spaCy Contributor Agreement. - [x] I ran the tests, and all new and existing tests passed. - [x] My changes don't require a change to the documentation, or if they do, I've added all required information.
363 lines
14 KiB
Cython
363 lines
14 KiB
Cython
# cython: infer_types=True
|
|
# cython: profile=True
|
|
from __future__ import unicode_literals
|
|
|
|
from cymem.cymem cimport Pool
|
|
from preshed.maps cimport PreshMap
|
|
|
|
from .matcher cimport Matcher
|
|
from ..vocab cimport Vocab
|
|
from ..tokens.doc cimport Doc
|
|
|
|
from .matcher import unpickle_matcher
|
|
from ..errors import Errors
|
|
|
|
|
|
# Positions inside the (relation operator, node index) tuples that
# DependencyTreeMatcher.retrieve_tree stores for each pattern node.
INDEX_RELOP = 0
INDEX_HEAD = 1
# Separator used to build the per-node keys registered with the internal
# token Matcher ("<key>||<pattern index>||<node index>").
DELIMITER = "||"
|
|
|
|
|
|
cdef class DependencyTreeMatcher:
    """Match dependency parse tree based on pattern rules."""
    cdef Pool mem
    cdef readonly Vocab vocab
    # Single-token Matcher holding one pattern per tree node of every rule.
    cdef readonly Matcher token_matcher
    # Each of the following maps a normalized rule key to a list holding one
    # entry per pattern added under that key, in insertion order:
    #   _patterns[key][p]      -> list of single-token Matcher patterns
    #   _keys_to_token[key][p] -> {token-matcher key: node index}
    #   _root[key][p]          -> index of the pattern's root node
    #   _nodes[key][p]         -> {NODE_NAME: node index}
    #   _tree[key][p]          -> {parent node index: [(relop, child index)]}
    cdef public object _patterns
    cdef public object _keys_to_token
    cdef public object _root
    # Not read or written by any method besides __init__/pickling.
    cdef public object _entities
    # Maps a normalized rule key to its on_match callback (or None).
    cdef public object _callbacks
    cdef public object _nodes
    cdef public object _tree

    def __init__(self, vocab):
        """Create the DependencyTreeMatcher.

        vocab (Vocab): The vocabulary object, which must be shared with the
            documents the matcher will operate on.
        RETURNS (DependencyTreeMatcher): The newly constructed object.
        """
        self.token_matcher = Matcher(vocab)
        self._keys_to_token = {}
        self._patterns = {}
        self._root = {}
        self._nodes = {}
        self._tree = {}
        self._entities = {}
        self._callbacks = {}
        self.vocab = vocab
        self.mem = Pool()

    def __reduce__(self):
        # `unpickle_matcher` (from .matcher) is the token Matcher's
        # reconstructor and cannot rebuild this class from the tree tables.
        # Instead, re-run __init__ with the vocab and restore the internal
        # tables through __setstate__.
        state = (
            self.token_matcher,
            self._patterns,
            self._keys_to_token,
            self._root,
            self._nodes,
            self._tree,
            self._entities,
            self._callbacks,
        )
        return (self.__class__, (self.vocab,), state)

    def __setstate__(self, state):
        """Restore the internal tables produced by `__reduce__`.

        state (tuple): The state tuple returned as the third item of
            `__reduce__`.
        """
        (token_matcher, patterns, keys_to_token, root, nodes, tree,
         entities, callbacks) = state
        self.token_matcher = token_matcher
        self._patterns = patterns
        self._keys_to_token = keys_to_token
        self._root = root
        self._nodes = nodes
        self._tree = tree
        self._entities = entities
        self._callbacks = callbacks

    def __len__(self):
        """Get the number of rule keys added to the dependency tree matcher.

        RETURNS (int): The number of distinct rule keys.
        """
        return len(self._patterns)

    def __contains__(self, key):
        """Check whether the matcher contains rules for a match ID.

        key (unicode): The match ID.
        RETURNS (bool): Whether the matcher contains rules for this match ID.
        """
        return self._normalize_key(key) in self._patterns

    def validateInput(self, pattern, key):
        """Check that a dependency tree pattern is well-formed.

        The first relation must name the root node and carry no neighbour
        spec. Every later relation must name a new node and attach it, via
        NBOR_RELOP, to a node named earlier in the pattern.

        pattern (list): The relations (dicts with "PATTERN" and "SPEC").
        key (unicode or int): The rule key, used in error messages.
        RAISES ValueError: E098 if a relation lacks "PATTERN" or "SPEC",
            E099 if the first relation is not a bare root, E100 if a later
            relation lacks NODE_NAME/NBOR_RELOP/NBOR_NAME, and E101 if a node
            is redefined or attached to a node not yet defined.
        """
        seen_names = set()
        for position, relation in enumerate(pattern):
            if "PATTERN" not in relation or "SPEC" not in relation:
                raise ValueError(Errors.E098.format(key=key))
            spec = relation["SPEC"]
            if position == 0:
                is_bare_root = (
                    "NODE_NAME" in spec
                    and "NBOR_RELOP" not in spec
                    and "NBOR_NAME" not in spec
                )
                if not is_bare_root:
                    raise ValueError(Errors.E099.format(key=key))
            else:
                is_full_edge = (
                    "NODE_NAME" in spec
                    and "NBOR_RELOP" in spec
                    and "NBOR_NAME" in spec
                )
                if not is_full_edge:
                    raise ValueError(Errors.E100.format(key=key))
                if (
                    spec["NODE_NAME"] in seen_names
                    or spec["NBOR_NAME"] not in seen_names
                ):
                    raise ValueError(Errors.E101.format(key=key))
            seen_names.add(spec["NODE_NAME"])

    def add(self, key, on_match, *patterns):
        """Add a rule: an ID key, an optional callback and tree patterns.

        Calling `add` again with the same key appends further patterns and
        replaces the callback.

        key (unicode or int): The rule key.
        on_match (callable or None): Called as
            `on_match(matcher, doc, i, matches)` for each tree match of this
            key, where `matches` is the list returned by `__call__`.
        *patterns (list): Dependency tree patterns; see `validateInput`.
        RAISES ValueError: E012 for an empty pattern, or any error raised by
            `validateInput`.
        """
        for pattern in patterns:
            if len(pattern) == 0:
                raise ValueError(Errors.E012.format(key=key))
            self.validateInput(pattern, key)
        key = self._normalize_key(key)
        # Patterns may be added under the same key across several calls.
        # Offset the pattern index so the token-matcher keys built below stay
        # unique for every stored pattern of this key.
        offset = len(self._patterns.get(key, []))
        _patterns = [
            [[relation["PATTERN"]] for relation in pattern]
            for pattern in patterns
        ]
        self._patterns.setdefault(key, []).extend(_patterns)
        self._callbacks[key] = on_match
        # Add each node pattern of all the input patterns individually to the
        # matcher. This enables only a single instance of Matcher to be used.
        # Multiple adds are required to track each node pattern.
        _keys_to_token_list = []
        for i, token_patterns in enumerate(_patterns):
            _keys_to_token = {}
            # TODO: Better ways to hash edges in pattern?
            for j, token_pattern in enumerate(token_patterns):
                k = self._normalize_key(
                    unicode(key) + DELIMITER + unicode(offset + i)
                    + DELIMITER + unicode(j)
                )
                self.token_matcher.add(k, None, token_pattern)
                _keys_to_token[k] = j
            _keys_to_token_list.append(_keys_to_token)
        self._keys_to_token.setdefault(key, []).extend(_keys_to_token_list)
        _nodes_list = [
            {relation["SPEC"]["NODE_NAME"]: j
             for j, relation in enumerate(pattern)}
            for pattern in patterns
        ]
        self._nodes.setdefault(key, []).extend(_nodes_list)
        # Create an object tree to traverse later on. This data structure
        # enables easy tree pattern match. Doc-Token based tree cannot be
        # reused since it is memory-heavy and tightly coupled with the Doc.
        self.retrieve_tree(patterns, _nodes_list, key)

    def retrieve_tree(self, patterns, _nodes_list, key):
        """Build and store the root index and parent->children tree per pattern.

        patterns (tuple): The validated patterns passed to `add`.
        _nodes_list (list): One {NODE_NAME: node index} dict per pattern.
        key (int): The normalized rule key the results are stored under.
        """
        _heads_list = []
        _root_list = []
        for i, pattern in enumerate(patterns):
            heads = {}
            root = -1
            for j, relation in enumerate(pattern):
                spec = relation["SPEC"]
                if "NBOR_RELOP" not in spec:
                    # The root is recorded as its own head.
                    heads[j] = ("root", j)
                    root = j
                else:
                    heads[j] = (
                        spec["NBOR_RELOP"],
                        _nodes_list[i][spec["NBOR_NAME"]],
                    )
            _heads_list.append(heads)
            _root_list.append(root)
        _tree_list = []
        for heads in _heads_list:
            tree = {}
            for j in range(len(heads)):
                head = heads[j][INDEX_HEAD]
                if head == j:
                    continue
                tree.setdefault(head, []).append((heads[j][INDEX_RELOP], j))
            _tree_list.append(tree)
        self._tree.setdefault(key, []).extend(_tree_list)
        self._root.setdefault(key, []).extend(_root_list)

    def has_key(self, key):
        """Check whether the matcher has a rule with a given key.

        key (string or int): The key to check.
        RETURNS (bool): Whether the matcher has the rule.
        """
        key = self._normalize_key(key)
        return key in self._patterns

    def get(self, key, default=None):
        """Retrieve the pattern stored for a key.

        key (unicode or int): The key to retrieve.
        default: Returned when the key is unknown.
        RETURNS (tuple): The rule, as an (on_match, patterns) tuple, where
            `patterns` holds the internal single-token patterns built by
            `add`; or `default` if the key is unknown.
        """
        key = self._normalize_key(key)
        if key not in self._patterns:
            return default
        return (self._callbacks[key], self._patterns[key])

    def __call__(self, Doc doc):
        """Find all dependency tree pattern matches in a document.

        doc (Doc): The document to match over.
        RETURNS (list): A list of `(key, token_indices)` tuples, one per
            matched tree, where `token_indices` are the doc positions visited
            while matching (in visit order).
        """
        matched_trees = []
        matches = self.token_matcher(doc)
        # Group the single-token match starts by match ID once, instead of
        # rescanning every match for every stored pattern.
        starts_by_match_id = {}
        for match_id, start, end in matches:
            starts_by_match_id.setdefault(match_id, []).append(start)
        for key in list(self._patterns.keys()):
            for keys_to_token, root, tree, nodes in zip(
                self._keys_to_token[key],
                self._root[key],
                self._tree[key],
                self._nodes[key],
            ):
                # id_to_position[j]: doc positions matching node j's pattern.
                id_to_position = {
                    node_idx: [] for node_idx in range(len(nodes))
                }
                for token_key, node_idx in keys_to_token.items():
                    id_to_position[node_idx].extend(
                        starts_by_match_id.get(token_key, [])
                    )
                node_operator_map = self.get_node_operator_map(
                    doc, tree, id_to_position, nodes, root
                )
                n_nodes = len(nodes)
                for candidate in id_to_position.get(root, []):
                    is_visited = {}
                    self.dfs(
                        candidate,
                        root,
                        tree,
                        id_to_position,
                        doc,
                        is_visited,
                        node_operator_map,
                    )
                    # To check if the subtree pattern is completely
                    # identified. This is a heuristic. This is done to
                    # reduce the complexity of exponential unordered subtree
                    # matching. Will give approximate matches in some cases.
                    if len(is_visited) == n_nodes:
                        matched_trees.append((key, list(is_visited)))
        for i, (match_key, _token_indices) in enumerate(matched_trees):
            on_match = self._callbacks.get(match_key)
            if on_match is not None:
                # `i` indexes the tree matches, so pass those (not the
                # internal single-token matches) to the callback.
                on_match(self, doc, i, matched_trees)
        return matched_trees

    def dfs(self, candidate, root, tree, id_to_position, doc, isVisited,
            _node_operator_map):
        """Depth-first walk matching pattern node `root` at doc position
        `candidate`, then its pattern children at related tokens.

        Every doc position that matches its pattern node is recorded in
        `isVisited` (used as an insertion-ordered set). Positions are never
        un-recorded, so branches that later fail still contribute.

        candidate (int): Doc position tried for pattern node `root`.
        root (int): Index of the current pattern node.
        tree (dict): Parent node index -> list of (relop, child node index).
        id_to_position (dict): Node index -> matching doc positions.
        doc (Doc): The document (unused here; kept for the signature).
        isVisited (dict): Accumulator of visited doc positions.
        _node_operator_map (dict): See `get_node_operator_map`.
        """
        if root not in id_to_position or candidate not in id_to_position[root]:
            return
        # Color the node since it is valid
        isVisited[candidate] = True
        related_by_operator = _node_operator_map.get(candidate, {})
        for root_child in tree.get(root, []):
            candidate_children = related_by_operator.get(
                root_child[INDEX_RELOP]
            )
            if candidate_children is None:
                continue
            for candidate_child in candidate_children:
                self.dfs(
                    candidate_child.i,
                    root_child[INDEX_HEAD],
                    tree,
                    id_to_position,
                    doc,
                    isVisited,
                    _node_operator_map,
                )

    def get_node_operator_map(self, doc, tree, id_to_position, nodes, root):
        """Precompute, for each matched doc position and each operator used by
        the pattern, the tokens related to it by that operator.

        RETURNS (dict): `_node_operator_map[position][operator] = [Token, ...]`.
        RAISES ValueError: If the pattern uses an unknown operator.
        `root` is accepted for signature compatibility and is not used.
        """
        # Used to invoke methods for each operator. Tree operators read as
        # "NBOR <op> NODE": ">" yields children, "<" the head, ">>" the
        # descendants and "<<" the ancestors of the neighbour token.
        switcher = {
            "<": self.dep,
            ">": self.gov,
            "<<": self.dep_chain,
            ">>": self.gov_chain,
            ".": self.imm_precede,
            "$+": self.imm_right_sib,
            "$-": self.imm_left_sib,
            "$++": self.right_sib,
            "$--": self.left_sib,
        }
        all_node_indices = list(nodes.values())
        all_operators = set()
        for node in all_node_indices:
            for child in tree.get(node, []):
                all_operators.add(child[INDEX_RELOP])
        for operator in all_operators:
            if operator not in switcher:
                raise ValueError(
                    "Unknown dependency operator {op!r} in DependencyTreeMatcher "
                    "pattern. Supported operators: {ops}".format(
                        op=operator, ops=", ".join(sorted(switcher))
                    )
                )
        all_nodes = set()
        for node in all_node_indices:
            all_nodes.update(id_to_position[node])
        _node_operator_map = {}
        for node in all_nodes:
            _node_operator_map[node] = {
                operator: switcher[operator](doc, node)
                for operator in all_operators
            }
        return _node_operator_map

    def dep(self, doc, node):
        """`<`: the syntactic head of doc[node]; empty for a root token."""
        token = doc[node]
        # A root token is its own head; that is not a dependency relation.
        if token.head.i == token.i:
            return []
        return [token.head]

    def gov(self, doc, node):
        """`>`: the immediate syntactic children of doc[node]."""
        return list(doc[node].children)

    def dep_chain(self, doc, node):
        """`<<`: all ancestors of doc[node]."""
        return list(doc[node].ancestors)

    def gov_chain(self, doc, node):
        """`>>`: all descendants of doc[node] (the token itself excluded)."""
        return [token for token in doc[node].subtree if token.i != node]

    def imm_precede(self, doc, node):
        """`.`: the token directly before doc[node], if any."""
        if node > 0:
            return [doc[node - 1]]
        return []

    def imm_right_sib(self, doc, node):
        """`$+`: the sibling at position node - 1 (so doc[node] is its
        immediate right sibling), if any."""
        for child in doc[node].head.children:
            if child.i == node - 1:
                return [child]
        return []

    def imm_left_sib(self, doc, node):
        """`$-`: the sibling at position node + 1 (so doc[node] is its
        immediate left sibling), if any."""
        for child in doc[node].head.children:
            if child.i == node + 1:
                return [child]
        return []

    def right_sib(self, doc, node):
        """`$++`: siblings positioned before doc[node]."""
        return [child for child in doc[node].head.children if child.i < node]

    def left_sib(self, doc, node):
        """`$--`: siblings positioned after doc[node]."""
        return [child for child in doc[node].head.children if child.i > node]

    def _normalize_key(self, key):
        """Map string keys to their hash in the vocab's StringStore; return
        any other key unchanged."""
        if isinstance(key, basestring):
            return self.vocab.strings.add(key)
        else:
            return key
|