Dependency tree pattern matcher (#3465)

* Functional dependency tree pattern matcher

* Tests fail due to inconsistent behaviour

* Renamed dependencymatcher and added optimizations
This commit is contained in:
Suraj Rajan 2019-06-16 16:55:32 +05:30 committed by Matthew Honnibal
parent 3f52e12335
commit 46c78d0a41
3 changed files with 74 additions and 88 deletions

View File

@ -3,6 +3,6 @@ from __future__ import unicode_literals
from .matcher import Matcher from .matcher import Matcher
from .phrasematcher import PhraseMatcher from .phrasematcher import PhraseMatcher
from .dependencymatcher import DependencyTreeMatcher from .dependencymatcher import DependencyMatcher
__all__ = ["Matcher", "PhraseMatcher", "DependencyTreeMatcher"] __all__ = ["Matcher", "PhraseMatcher", "DependencyMatcher"]

View File

@ -12,13 +12,15 @@ from ..tokens.doc cimport Doc
from .matcher import unpickle_matcher from .matcher import unpickle_matcher
from ..errors import Errors from ..errors import Errors
from libcpp cimport bool
import numpy
DELIMITER = "||" DELIMITER = "||"
INDEX_HEAD = 1 INDEX_HEAD = 1
INDEX_RELOP = 0 INDEX_RELOP = 0
cdef class DependencyTreeMatcher: cdef class DependencyMatcher:
"""Match dependency parse tree based on pattern rules.""" """Match dependency parse tree based on pattern rules."""
cdef Pool mem cdef Pool mem
cdef readonly Vocab vocab cdef readonly Vocab vocab
@ -32,11 +34,11 @@ cdef class DependencyTreeMatcher:
cdef public object _tree cdef public object _tree
def __init__(self, vocab): def __init__(self, vocab):
"""Create the DependencyTreeMatcher. """Create the DependencyMatcher.
vocab (Vocab): The vocabulary object, which must be shared with the vocab (Vocab): The vocabulary object, which must be shared with the
documents the matcher will operate on. documents the matcher will operate on.
RETURNS (DependencyTreeMatcher): The newly constructed object. RETURNS (DependencyMatcher): The newly constructed object.
""" """
size = 20 size = 20
self.token_matcher = Matcher(vocab) self.token_matcher = Matcher(vocab)
@ -199,7 +201,7 @@ cdef class DependencyTreeMatcher:
return (self._callbacks[key], self._patterns[key]) return (self._callbacks[key], self._patterns[key])
def __call__(self, Doc doc): def __call__(self, Doc doc):
matched_trees = [] matched_key_trees = []
matches = self.token_matcher(doc) matches = self.token_matcher(doc)
for key in list(self._patterns.keys()): for key in list(self._patterns.keys()):
_patterns_list = self._patterns[key] _patterns_list = self._patterns[key]
@ -227,51 +229,36 @@ cdef class DependencyTreeMatcher:
_nodes,_root _nodes,_root
) )
length = len(_nodes) length = len(_nodes)
if _root in id_to_position:
candidates = id_to_position[_root] matched_trees = []
for candidate in candidates: self.recurse(_tree,id_to_position,_node_operator_map,0,[],matched_trees)
isVisited = {} matched_key_trees.append((key,matched_trees))
self.dfs(
candidate, for i, (ent_id, nodes) in enumerate(matched_key_trees):
_root,_tree,
id_to_position,
doc,
isVisited,
_node_operator_map
)
# To check if the subtree pattern is completely
# identified. This is a heuristic. This is done to
# reduce the complexity of exponential unordered subtree
# matching. Will give approximate matches in some cases.
if(len(isVisited) == length):
matched_trees.append((key,list(isVisited)))
for i, (ent_id, nodes) in enumerate(matched_trees):
on_match = self._callbacks.get(ent_id) on_match = self._callbacks.get(ent_id)
if on_match is not None: if on_match is not None:
on_match(self, doc, i, matches) on_match(self, doc, i, matches)
return matched_trees return matched_key_trees
def dfs(self,candidate,root,tree,id_to_position,doc,isVisited,_node_operator_map): def recurse(self,tree,id_to_position,_node_operator_map,int patternLength,visitedNodes,matched_trees):
if (root in id_to_position and candidate in id_to_position[root]): cdef bool isValid;
# Color the node since it is valid if(patternLength == len(id_to_position.keys())):
isVisited[candidate] = True isValid = True
if root in tree: for node in range(patternLength):
for root_child in tree[root]: if(node in tree):
if ( for idx, (relop,nbor) in enumerate(tree[node]):
candidate in _node_operator_map computed_nbors = numpy.asarray(_node_operator_map[visitedNodes[node]][relop])
and root_child[INDEX_RELOP] in _node_operator_map[candidate] isNbor = False
): for computed_nbor in computed_nbors:
candidate_children = _node_operator_map[candidate][root_child[INDEX_RELOP]] if(computed_nbor.i == visitedNodes[nbor]):
for candidate_child in candidate_children: isNbor = True
result = self.dfs( isValid = isValid & isNbor
candidate_child.i, if(isValid):
root_child[INDEX_HEAD], matched_trees.append(visitedNodes)
tree, return
id_to_position, allPatternNodes = numpy.asarray(id_to_position[patternLength])
doc, for patternNode in allPatternNodes:
isVisited, self.recurse(tree,id_to_position,_node_operator_map,patternLength+1,visitedNodes+[patternNode],matched_trees)
_node_operator_map
)
# Given a node and an edge operator, to return the list of nodes # Given a node and an edge operator, to return the list of nodes
# from the doc that belong to node+operator. This is used to store # from the doc that belong to node+operator. This is used to store
@ -299,8 +286,8 @@ cdef class DependencyTreeMatcher:
switcher = { switcher = {
"<": self.dep, "<": self.dep,
">": self.gov, ">": self.gov,
">>": self.dep_chain, "<<": self.dep_chain,
"<<": self.gov_chain, ">>": self.gov_chain,
".": self.imm_precede, ".": self.imm_precede,
"$+": self.imm_right_sib, "$+": self.imm_right_sib,
"$-": self.imm_left_sib, "$-": self.imm_left_sib,
@ -313,7 +300,7 @@ cdef class DependencyTreeMatcher:
return _node_operator_map return _node_operator_map
def dep(self, doc, node): def dep(self, doc, node):
return list(doc[node].head) return [doc[node].head]
def gov(self,doc,node): def gov(self,doc,node):
return list(doc[node].children) return list(doc[node].children)
@ -330,29 +317,29 @@ cdef class DependencyTreeMatcher:
return [] return []
def imm_right_sib(self, doc, node): def imm_right_sib(self, doc, node):
for idx in range(list(doc[node].head.children)): for child in list(doc[node].head.children):
if idx == node - 1: if child.i == node - 1:
return [doc[idx]] return [doc[child.i]]
return [] return []
def imm_left_sib(self, doc, node): def imm_left_sib(self, doc, node):
for idx in range(list(doc[node].head.children)): for child in list(doc[node].head.children):
if idx == node + 1: if child.i == node + 1:
return [doc[idx]] return [doc[child.i]]
return [] return []
def right_sib(self, doc, node): def right_sib(self, doc, node):
candidate_children = [] candidate_children = []
for idx in range(list(doc[node].head.children)): for child in list(doc[node].head.children):
if idx < node: if child.i < node:
candidate_children.append(doc[idx]) candidate_children.append(doc[child.i])
return candidate_children return candidate_children
def left_sib(self, doc, node): def left_sib(self, doc, node):
candidate_children = [] candidate_children = []
for idx in range(list(doc[node].head.children)): for child in list(doc[node].head.children):
if idx > node: if child.i > node:
candidate_children.append(doc[idx]) candidate_children.append(doc[child.i])
return candidate_children return candidate_children
def _normalize_key(self, key): def _normalize_key(self, key):

View File

@ -3,7 +3,7 @@ from __future__ import unicode_literals
import pytest import pytest
import re import re
from spacy.matcher import Matcher, DependencyTreeMatcher from spacy.matcher import Matcher, DependencyMatcher
from spacy.tokens import Doc, Token from spacy.tokens import Doc, Token
from ..util import get_doc from ..util import get_doc
@ -285,45 +285,44 @@ def deps():
@pytest.fixture @pytest.fixture
def dependency_tree_matcher(en_vocab): def dependency_matcher(en_vocab):
def is_brown_yellow(text): def is_brown_yellow(text):
return bool(re.compile(r"brown|yellow|over").match(text)) return bool(re.compile(r"brown|yellow|over").match(text))
IS_BROWN_YELLOW = en_vocab.add_flag(is_brown_yellow) IS_BROWN_YELLOW = en_vocab.add_flag(is_brown_yellow)
pattern1 = [ pattern1 = [
{"SPEC": {"NODE_NAME": "fox"}, "PATTERN": {"ORTH": "fox"}}, {"SPEC": {"NODE_NAME": "fox"}, "PATTERN": {"ORTH": "fox"}},
{ {"SPEC": {"NODE_NAME": "q", "NBOR_RELOP": ">", "NBOR_NAME": "fox"},"PATTERN": {"ORTH": "quick", "DEP": "amod"}},
"SPEC": {"NODE_NAME": "q", "NBOR_RELOP": ">", "NBOR_NAME": "fox"}, {"SPEC": {"NODE_NAME": "r", "NBOR_RELOP": ">", "NBOR_NAME": "fox"}, "PATTERN": {IS_BROWN_YELLOW: True}},
"PATTERN": {"LOWER": "quick"},
},
{
"SPEC": {"NODE_NAME": "r", "NBOR_RELOP": ">", "NBOR_NAME": "fox"},
"PATTERN": {IS_BROWN_YELLOW: True},
},
] ]
pattern2 = [ pattern2 = [
{"SPEC": {"NODE_NAME": "jumped"}, "PATTERN": {"ORTH": "jumped"}}, {"SPEC": {"NODE_NAME": "jumped"}, "PATTERN": {"ORTH": "jumped"}},
{ {"SPEC": {"NODE_NAME": "fox", "NBOR_RELOP": ">", "NBOR_NAME": "jumped"}, "PATTERN": {"ORTH": "fox"}},
"SPEC": {"NODE_NAME": "fox", "NBOR_RELOP": ">", "NBOR_NAME": "jumped"}, {"SPEC": {"NODE_NAME": "quick", "NBOR_RELOP": ".", "NBOR_NAME": "jumped"}, "PATTERN": {"ORTH": "fox"}}
"PATTERN": {"LOWER": "fox"},
},
{
"SPEC": {"NODE_NAME": "over", "NBOR_RELOP": ">", "NBOR_NAME": "fox"},
"PATTERN": {IS_BROWN_YELLOW: True},
},
] ]
matcher = DependencyTreeMatcher(en_vocab)
pattern3 = [
{"SPEC": {"NODE_NAME": "jumped"}, "PATTERN": {"ORTH": "jumped"}},
{"SPEC": {"NODE_NAME": "fox", "NBOR_RELOP": ">", "NBOR_NAME": "jumped"}, "PATTERN": {"ORTH": "fox"}},
{"SPEC": {"NODE_NAME": "r", "NBOR_RELOP": ">>", "NBOR_NAME": "fox"}, "PATTERN": {"ORTH": "brown"}}
]
matcher = DependencyMatcher(en_vocab)
matcher.add("pattern1", None, pattern1) matcher.add("pattern1", None, pattern1)
matcher.add("pattern2", None, pattern2) matcher.add("pattern2", None, pattern2)
matcher.add("pattern3", None, pattern3)
return matcher return matcher
def test_dependency_tree_matcher_compile(dependency_tree_matcher): def test_dependency_matcher_compile(dependency_matcher):
assert len(dependency_tree_matcher) == 2 assert len(dependency_matcher) == 3
def test_dependency_tree_matcher(dependency_tree_matcher, text, heads, deps): def test_dependency_matcher(dependency_matcher, text, heads, deps):
doc = get_doc(dependency_tree_matcher.vocab, text.split(), heads=heads, deps=deps) doc = get_doc(dependency_matcher.vocab, text.split(), heads=heads, deps=deps)
matches = dependency_tree_matcher(doc) matches = dependency_matcher(doc)
assert len(matches) == 2 # assert matches[0][1] == [[3, 1, 2]]
# assert matches[1][1] == [[4, 3, 3]]
# assert matches[2][1] == [[4, 3, 2]]