mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-30 23:47:31 +03:00 
			
		
		
		
	Add official support for the `DependencyMatcher`. Redesign the pattern
specification. Fix and extend operator implementations. Update API docs
and add usage docs.
Patterns
--------
Refactor pattern structure to:
```
{
  "LEFT_ID": str,
  "REL_OP": str,
  "RIGHT_ID": str,
  "RIGHT_ATTRS": dict,
}
```
The first node contains only `RIGHT_ID` and `RIGHT_ATTRS` and all
subsequent nodes contain all four keys.
New operators
-------------
Because of the way patterns are constructed from left to right, it's
helpful to have `follows` operators along with `precedes` operators. Add
operators for simple precedes / follows alongside immediate precedes /
follows.
* `.*`: precedes
* `;`: immediately follows
* `;*`: follows
Operator fixes
--------------
* `<` and `<<` do not include the node itself
* Fix reversed order for all operators involving linear precedence (`.`,
  all sibling operators)
* Linear precedence operators only match nodes within the same parse tree
  (they never match across separate parses)
Additional fixes
----------------
* Use v3 Matcher API
* Support `get` and `remove`
* Support pickling
		
	
			
		
			
				
	
	
		
			419 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Cython
		
	
	
	
	
	
			
		
		
	
	
			419 lines
		
	
	
		
			16 KiB
		
	
	
	
		
			Cython
		
	
	
	
	
	
| # cython: infer_types=True, profile=True
 | |
| from typing import List
 | |
| 
 | |
| import numpy
 | |
| 
 | |
| from cymem.cymem cimport Pool
 | |
| 
 | |
| from .matcher cimport Matcher
 | |
| from ..vocab cimport Vocab
 | |
| from ..tokens.doc cimport Doc
 | |
| 
 | |
| from ..errors import Errors
 | |
| from ..tokens import Span
 | |
| 
 | |
| 
 | |
# Separator placed between the match key, the pattern index and the node index
# when building the per-node keys registered in the token Matcher.
DELIMITER = "||"

# Positions of the operator and of the node index inside the
# `(REL_OP, node_index)` tuples used to describe pattern trees.
INDEX_RELOP = 0
INDEX_HEAD = 1
 | |
| 
 | |
| 
 | |
| cdef class DependencyMatcher:
 | |
|     """Match dependency parse tree based on pattern rules."""
 | |
|     cdef Pool mem
 | |
|     cdef readonly Vocab vocab
 | |
|     cdef readonly Matcher matcher
 | |
|     cdef public object _patterns
 | |
|     cdef public object _raw_patterns
 | |
|     cdef public object _keys_to_token
 | |
|     cdef public object _root
 | |
|     cdef public object _callbacks
 | |
|     cdef public object _nodes
 | |
|     cdef public object _tree
 | |
|     cdef public object _ops
 | |
| 
 | |
|     def __init__(self, vocab, *, validate=False):
 | |
|         """Create the DependencyMatcher.
 | |
| 
 | |
|         vocab (Vocab): The vocabulary object, which must be shared with the
 | |
|             documents the matcher will operate on.
 | |
|         validate (bool): Whether patterns should be validated, passed to
 | |
|             Matcher as `validate`
 | |
|         """
 | |
|         size = 20
 | |
|         self.matcher = Matcher(vocab, validate=validate)
 | |
|         self._keys_to_token = {}
 | |
|         self._patterns = {}
 | |
|         self._raw_patterns = {}
 | |
|         self._root = {}
 | |
|         self._nodes = {}
 | |
|         self._tree = {}
 | |
|         self._callbacks = {}
 | |
|         self.vocab = vocab
 | |
|         self.mem = Pool()
 | |
|         self._ops = {
 | |
|             "<": self.dep,
 | |
|             ">": self.gov,
 | |
|             "<<": self.dep_chain,
 | |
|             ">>": self.gov_chain,
 | |
|             ".": self.imm_precede,
 | |
|             ".*": self.precede,
 | |
|             ";": self.imm_follow,
 | |
|             ";*": self.follow,
 | |
|             "$+": self.imm_right_sib,
 | |
|             "$-": self.imm_left_sib,
 | |
|             "$++": self.right_sib,
 | |
|             "$--": self.left_sib,
 | |
|         }
 | |
| 
 | |
|     def __reduce__(self):
 | |
|         data = (self.vocab, self._raw_patterns, self._callbacks)
 | |
|         return (unpickle_matcher, data, None, None)
 | |
| 
 | |
|     def __len__(self):
 | |
|         """Get the number of rules, which are edges, added to the dependency
 | |
|         tree matcher.
 | |
| 
 | |
|         RETURNS (int): The number of rules.
 | |
|         """
 | |
|         return len(self._patterns)
 | |
| 
 | |
|     def __contains__(self, key):
 | |
|         """Check whether the matcher contains rules for a match ID.
 | |
| 
 | |
|         key (str): The match ID.
 | |
|         RETURNS (bool): Whether the matcher contains rules for this match ID.
 | |
|         """
 | |
|         return self.has_key(key)
 | |
| 
 | |
|     def validate_input(self, pattern, key):
 | |
|         idx = 0
 | |
|         visited_nodes = {}
 | |
|         for relation in pattern:
 | |
|             if not isinstance(relation, dict):
 | |
|                 raise ValueError(Errors.E1008)
 | |
|             if "RIGHT_ATTRS" not in relation and "RIGHT_ID" not in relation:
 | |
|                 raise ValueError(Errors.E098.format(key=key))
 | |
|             if idx == 0:
 | |
|                 if not(
 | |
|                     "RIGHT_ID" in relation
 | |
|                     and "REL_OP" not in relation
 | |
|                     and "LEFT_ID" not in relation
 | |
|                 ):
 | |
|                     raise ValueError(Errors.E099.format(key=key))
 | |
|                 visited_nodes[relation["RIGHT_ID"]] = True
 | |
|             else:
 | |
|                 if not(
 | |
|                     "RIGHT_ID" in relation
 | |
|                     and "RIGHT_ATTRS" in relation
 | |
|                     and "REL_OP" in relation
 | |
|                     and "LEFT_ID" in relation
 | |
|                 ):
 | |
|                     raise ValueError(Errors.E100.format(key=key))
 | |
|                 if (
 | |
|                     relation["RIGHT_ID"] in visited_nodes
 | |
|                     or relation["LEFT_ID"] not in visited_nodes
 | |
|                 ):
 | |
|                     raise ValueError(Errors.E101.format(key=key))
 | |
|                 if relation["REL_OP"] not in self._ops:
 | |
|                     raise ValueError(Errors.E1007.format(op=relation["REL_OP"]))
 | |
|                 visited_nodes[relation["RIGHT_ID"]] = True
 | |
|                 visited_nodes[relation["LEFT_ID"]] = True
 | |
|             idx = idx + 1
 | |
| 
 | |
|     def add(self, key, patterns, *, on_match=None):
 | |
|         """Add a new matcher rule to the matcher.
 | |
| 
 | |
|         key (str): The match ID.
 | |
|         patterns (list): The patterns to add for the given key.
 | |
|         on_match (callable): Optional callback executed on match.
 | |
|         """
 | |
|         if on_match is not None and not hasattr(on_match, "__call__"):
 | |
|             raise ValueError(Errors.E171.format(arg_type=type(on_match)))
 | |
|         if patterns is None or not isinstance(patterns, List):  # old API
 | |
|             raise ValueError(Errors.E948.format(arg_type=type(patterns)))
 | |
|         for pattern in patterns:
 | |
|             if len(pattern) == 0:
 | |
|                 raise ValueError(Errors.E012.format(key=key))
 | |
|             self.validate_input(pattern, key)
 | |
|         key = self._normalize_key(key)
 | |
|         self._raw_patterns.setdefault(key, [])
 | |
|         self._raw_patterns[key].extend(patterns)
 | |
|         _patterns = []
 | |
|         for pattern in patterns:
 | |
|             token_patterns = []
 | |
|             for i in range(len(pattern)):
 | |
|                 token_pattern = [pattern[i]["RIGHT_ATTRS"]]
 | |
|                 token_patterns.append(token_pattern)
 | |
|             _patterns.append(token_patterns)
 | |
|         self._patterns.setdefault(key, [])
 | |
|         self._callbacks[key] = on_match
 | |
|         self._patterns[key].extend(_patterns)
 | |
|         # Add each node pattern of all the input patterns individually to the
 | |
|         # matcher. This enables only a single instance of Matcher to be used.
 | |
|         # Multiple adds are required to track each node pattern.
 | |
|         _keys_to_token_list = []
 | |
|         for i in range(len(_patterns)):
 | |
|             _keys_to_token = {}
 | |
|             # TODO: Better ways to hash edges in pattern?
 | |
|             for j in range(len(_patterns[i])):
 | |
|                 k = self._normalize_key(unicode(key) + DELIMITER + unicode(i) + DELIMITER + unicode(j))
 | |
|                 self.matcher.add(k, [_patterns[i][j]])
 | |
|                 _keys_to_token[k] = j
 | |
|             _keys_to_token_list.append(_keys_to_token)
 | |
|         self._keys_to_token.setdefault(key, [])
 | |
|         self._keys_to_token[key].extend(_keys_to_token_list)
 | |
|         _nodes_list = []
 | |
|         for pattern in patterns:
 | |
|             nodes = {}
 | |
|             for i in range(len(pattern)):
 | |
|                 nodes[pattern[i]["RIGHT_ID"]] = i
 | |
|             _nodes_list.append(nodes)
 | |
|         self._nodes.setdefault(key, [])
 | |
|         self._nodes[key].extend(_nodes_list)
 | |
|         # Create an object tree to traverse later on. This data structure
 | |
|         # enables easy tree pattern match. Doc-Token based tree cannot be
 | |
|         # reused since it is memory-heavy and tightly coupled with the Doc.
 | |
|         self.retrieve_tree(patterns, _nodes_list, key)
 | |
| 
 | |
|     def retrieve_tree(self, patterns, _nodes_list, key):
 | |
|         _heads_list = []
 | |
|         _root_list = []
 | |
|         for i in range(len(patterns)):
 | |
|             heads = {}
 | |
|             root = -1
 | |
|             for j in range(len(patterns[i])):
 | |
|                 token_pattern = patterns[i][j]
 | |
|                 if ("REL_OP" not in token_pattern):
 | |
|                     heads[j] = ('root', j)
 | |
|                     root = j
 | |
|                 else:
 | |
|                     heads[j] = (
 | |
|                         token_pattern["REL_OP"],
 | |
|                         _nodes_list[i][token_pattern["LEFT_ID"]]
 | |
|                     )
 | |
|             _heads_list.append(heads)
 | |
|             _root_list.append(root)
 | |
|         _tree_list = []
 | |
|         for i in range(len(patterns)):
 | |
|             tree = {}
 | |
|             for j in range(len(patterns[i])):
 | |
|                 if(_heads_list[i][j][INDEX_HEAD] == j):
 | |
|                     continue
 | |
|                 head = _heads_list[i][j][INDEX_HEAD]
 | |
|                 if(head not in tree):
 | |
|                     tree[head] = []
 | |
|                 tree[head].append((_heads_list[i][j][INDEX_RELOP], j))
 | |
|             _tree_list.append(tree)
 | |
|         self._tree.setdefault(key, [])
 | |
|         self._tree[key].extend(_tree_list)
 | |
|         self._root.setdefault(key, [])
 | |
|         self._root[key].extend(_root_list)
 | |
| 
 | |
|     def has_key(self, key):
 | |
|         """Check whether the matcher has a rule with a given key.
 | |
| 
 | |
|         key (string or int): The key to check.
 | |
|         RETURNS (bool): Whether the matcher has the rule.
 | |
|         """
 | |
|         return self._normalize_key(key) in self._patterns
 | |
| 
 | |
|     def get(self, key, default=None):
 | |
|         """Retrieve the pattern stored for a key.
 | |
| 
 | |
|         key (str / int): The key to retrieve.
 | |
|         RETURNS (tuple): The rule, as an (on_match, patterns) tuple.
 | |
|         """
 | |
|         key = self._normalize_key(key)
 | |
|         if key not in self._raw_patterns:
 | |
|             return default
 | |
|         return (self._callbacks[key], self._raw_patterns[key])
 | |
| 
 | |
|     def remove(self, key):
 | |
|         key = self._normalize_key(key)
 | |
|         if not key in self._patterns:
 | |
|             raise ValueError(Errors.E175.format(key=key))
 | |
|         self._patterns.pop(key)
 | |
|         self._raw_patterns.pop(key)
 | |
|         self._nodes.pop(key)
 | |
|         self._tree.pop(key)
 | |
|         self._root.pop(key)
 | |
| 
 | |
|     def __call__(self, object doclike):
 | |
|         """Find all token sequences matching the supplied pattern.
 | |
| 
 | |
|         doclike (Doc or Span): The document to match over.
 | |
|         RETURNS (list): A list of `(key, start, end)` tuples,
 | |
|             describing the matches. A match tuple describes a span
 | |
|             `doc[start:end]`. The `label_id` and `key` are both integers.
 | |
|         """
 | |
|         if isinstance(doclike, Doc):
 | |
|             doc = doclike
 | |
|         elif isinstance(doclike, Span):
 | |
|             doc = doclike.as_doc()
 | |
|         else:
 | |
|             raise ValueError(Errors.E195.format(good="Doc or Span", got=type(doclike).__name__))
 | |
|         matched_key_trees = []
 | |
|         matches = self.matcher(doc)
 | |
|         for key in list(self._patterns.keys()):
 | |
|             _patterns_list = self._patterns[key]
 | |
|             _keys_to_token_list = self._keys_to_token[key]
 | |
|             _root_list = self._root[key]
 | |
|             _tree_list = self._tree[key]
 | |
|             _nodes_list = self._nodes[key]
 | |
|             length = len(_patterns_list)
 | |
|             for i in range(length):
 | |
|                 _keys_to_token = _keys_to_token_list[i]
 | |
|                 _root = _root_list[i]
 | |
|                 _tree = _tree_list[i]
 | |
|                 _nodes = _nodes_list[i]
 | |
|                 id_to_position = {}
 | |
|                 for i in range(len(_nodes)):
 | |
|                     id_to_position[i]=[]
 | |
|                 # TODO: This could be taken outside to improve running time..?
 | |
|                 for match_id, start, end in matches:
 | |
|                     if match_id in _keys_to_token:
 | |
|                         id_to_position[_keys_to_token[match_id]].append(start)
 | |
|                 _node_operator_map = self.get_node_operator_map(
 | |
|                     doc,
 | |
|                     _tree,
 | |
|                     id_to_position,
 | |
|                     _nodes,_root
 | |
|                 )
 | |
|                 length = len(_nodes)
 | |
| 
 | |
|                 matched_trees = []
 | |
|                 self.recurse(_tree, id_to_position, _node_operator_map, 0, [], matched_trees)
 | |
|                 for matched_tree in matched_trees:
 | |
|                     matched_key_trees.append((key, matched_tree))
 | |
|             for i, (match_id, nodes) in enumerate(matched_key_trees):
 | |
|                 on_match = self._callbacks.get(match_id)
 | |
|                 if on_match is not None:
 | |
|                     on_match(self, doc, i, matched_key_trees)
 | |
|         return matched_key_trees
 | |
| 
 | |
|     def recurse(self, tree, id_to_position, _node_operator_map, int patternLength, visited_nodes, matched_trees):
 | |
|         cdef bint isValid;
 | |
|         if patternLength == len(id_to_position.keys()):
 | |
|             isValid = True
 | |
|             for node in range(patternLength):
 | |
|                 if node in tree:
 | |
|                     for idx, (relop,nbor) in enumerate(tree[node]):
 | |
|                         computed_nbors = numpy.asarray(_node_operator_map[visited_nodes[node]][relop])
 | |
|                         isNbor = False
 | |
|                         for computed_nbor in computed_nbors:
 | |
|                             if computed_nbor.i == visited_nodes[nbor]:
 | |
|                                 isNbor = True
 | |
|                         isValid = isValid & isNbor
 | |
|             if(isValid):
 | |
|                 matched_trees.append(visited_nodes)
 | |
|             return
 | |
|         allPatternNodes = numpy.asarray(id_to_position[patternLength])
 | |
|         for patternNode in allPatternNodes:
 | |
|             self.recurse(tree, id_to_position, _node_operator_map, patternLength+1, visited_nodes+[patternNode], matched_trees)
 | |
| 
 | |
|     # Given a node and an edge operator, to return the list of nodes
 | |
|     # from the doc that belong to node+operator. This is used to store
 | |
|     # all the results beforehand to prevent unnecessary computation while
 | |
|     # pattern matching
 | |
|     # _node_operator_map[node][operator] = [...]
 | |
|     def get_node_operator_map(self, doc, tree, id_to_position, nodes, root):
 | |
|         _node_operator_map = {}
 | |
|         all_node_indices = nodes.values()
 | |
|         all_operators = []
 | |
|         for node in all_node_indices:
 | |
|             if node in tree:
 | |
|                 for child in tree[node]:
 | |
|                     all_operators.append(child[INDEX_RELOP])
 | |
|         all_operators = list(set(all_operators))
 | |
|         all_nodes = []
 | |
|         for node in all_node_indices:
 | |
|             all_nodes = all_nodes + id_to_position[node]
 | |
|         all_nodes = list(set(all_nodes))
 | |
|         for node in all_nodes:
 | |
|             _node_operator_map[node] = {}
 | |
|             for operator in all_operators:
 | |
|                 _node_operator_map[node][operator] = []
 | |
|         for operator in all_operators:
 | |
|             for node in all_nodes:
 | |
|                 _node_operator_map[node][operator] = self._ops.get(operator)(doc, node)
 | |
|         return _node_operator_map
 | |
| 
 | |
|     def dep(self, doc, node):
 | |
|         if doc[node].head == doc[node]:
 | |
|             return []
 | |
|         return [doc[node].head]
 | |
| 
 | |
|     def gov(self,doc,node):
 | |
|         return list(doc[node].children)
 | |
| 
 | |
|     def dep_chain(self, doc, node):
 | |
|         return list(doc[node].ancestors)
 | |
| 
 | |
|     def gov_chain(self, doc, node):
 | |
|         return [t for t in doc[node].subtree if t != doc[node]]
 | |
| 
 | |
|     def imm_precede(self, doc, node):
 | |
|         sent = self._get_sent(doc[node])
 | |
|         if node < len(doc) - 1 and doc[node + 1] in sent:
 | |
|             return [doc[node + 1]]
 | |
|         return []
 | |
| 
 | |
|     def precede(self, doc, node):
 | |
|         sent = self._get_sent(doc[node])
 | |
|         return [doc[i] for i in range(node + 1, sent.end)]
 | |
| 
 | |
|     def imm_follow(self, doc, node):
 | |
|         sent = self._get_sent(doc[node])
 | |
|         if node > 0 and doc[node - 1] in sent:
 | |
|             return [doc[node - 1]]
 | |
|         return []
 | |
| 
 | |
|     def follow(self, doc, node):
 | |
|         sent = self._get_sent(doc[node])
 | |
|         return [doc[i] for i in range(sent.start, node)]
 | |
| 
 | |
|     def imm_right_sib(self, doc, node):
 | |
|         for child in list(doc[node].head.children):
 | |
|             if child.i == node + 1:
 | |
|                 return [doc[child.i]]
 | |
|         return []
 | |
| 
 | |
|     def imm_left_sib(self, doc, node):
 | |
|         for child in list(doc[node].head.children):
 | |
|             if child.i == node - 1:
 | |
|                 return [doc[child.i]]
 | |
|         return []
 | |
| 
 | |
|     def right_sib(self, doc, node):
 | |
|         candidate_children = []
 | |
|         for child in list(doc[node].head.children):
 | |
|             if child.i > node:
 | |
|                 candidate_children.append(doc[child.i])
 | |
|         return candidate_children
 | |
| 
 | |
|     def left_sib(self, doc, node):
 | |
|         candidate_children = []
 | |
|         for child in list(doc[node].head.children):
 | |
|             if child.i < node:
 | |
|                 candidate_children.append(doc[child.i])
 | |
|         return candidate_children
 | |
| 
 | |
|     def _normalize_key(self, key):
 | |
|         if isinstance(key, basestring):
 | |
|             return self.vocab.strings.add(key)
 | |
|         else:
 | |
|             return key
 | |
| 
 | |
|     def _get_sent(self, token):
 | |
|         root = (list(token.ancestors) or [token])[-1]
 | |
|         return token.doc[root.left_edge.i:root.right_edge.i + 1]
 | |
| 
 | |
| 
 | |
def unpickle_matcher(vocab, patterns, callbacks):
    """Rebuild a DependencyMatcher from the data produced by `__reduce__`.

    vocab (Vocab): The vocabulary for the new matcher.
    patterns (dict): Match ID -> list of raw patterns.
    callbacks (dict): Match ID -> on_match callback; missing IDs get None.
    RETURNS (DependencyMatcher): A matcher with every pattern re-added.
    """
    matcher = DependencyMatcher(vocab)
    for match_id, raw_patterns in patterns.items():
        matcher.add(match_id, raw_patterns, on_match=callbacks.get(match_id))
    return matcher
 |