mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Dependency tree pattern matcher (#3465)
* Functional dependency tree pattern matcher * Tests fail due to inconsistent behaviour * Renamed dependencymatcher and added optimizations
This commit is contained in:
parent
3f52e12335
commit
46c78d0a41
|
@ -3,6 +3,6 @@ from __future__ import unicode_literals
|
||||||
|
|
||||||
from .matcher import Matcher
|
from .matcher import Matcher
|
||||||
from .phrasematcher import PhraseMatcher
|
from .phrasematcher import PhraseMatcher
|
||||||
from .dependencymatcher import DependencyTreeMatcher
|
from .dependencymatcher import DependencyMatcher
|
||||||
|
|
||||||
__all__ = ["Matcher", "PhraseMatcher", "DependencyTreeMatcher"]
|
__all__ = ["Matcher", "PhraseMatcher", "DependencyMatcher"]
|
||||||
|
|
|
@ -12,13 +12,15 @@ from ..tokens.doc cimport Doc
|
||||||
from .matcher import unpickle_matcher
|
from .matcher import unpickle_matcher
|
||||||
from ..errors import Errors
|
from ..errors import Errors
|
||||||
|
|
||||||
|
from libcpp cimport bool
|
||||||
|
import numpy
|
||||||
|
|
||||||
DELIMITER = "||"
|
DELIMITER = "||"
|
||||||
INDEX_HEAD = 1
|
INDEX_HEAD = 1
|
||||||
INDEX_RELOP = 0
|
INDEX_RELOP = 0
|
||||||
|
|
||||||
|
|
||||||
cdef class DependencyTreeMatcher:
|
cdef class DependencyMatcher:
|
||||||
"""Match dependency parse tree based on pattern rules."""
|
"""Match dependency parse tree based on pattern rules."""
|
||||||
cdef Pool mem
|
cdef Pool mem
|
||||||
cdef readonly Vocab vocab
|
cdef readonly Vocab vocab
|
||||||
|
@ -32,11 +34,11 @@ cdef class DependencyTreeMatcher:
|
||||||
cdef public object _tree
|
cdef public object _tree
|
||||||
|
|
||||||
def __init__(self, vocab):
|
def __init__(self, vocab):
|
||||||
"""Create the DependencyTreeMatcher.
|
"""Create the DependencyMatcher.
|
||||||
|
|
||||||
vocab (Vocab): The vocabulary object, which must be shared with the
|
vocab (Vocab): The vocabulary object, which must be shared with the
|
||||||
documents the matcher will operate on.
|
documents the matcher will operate on.
|
||||||
RETURNS (DependencyTreeMatcher): The newly constructed object.
|
RETURNS (DependencyMatcher): The newly constructed object.
|
||||||
"""
|
"""
|
||||||
size = 20
|
size = 20
|
||||||
self.token_matcher = Matcher(vocab)
|
self.token_matcher = Matcher(vocab)
|
||||||
|
@ -199,7 +201,7 @@ cdef class DependencyTreeMatcher:
|
||||||
return (self._callbacks[key], self._patterns[key])
|
return (self._callbacks[key], self._patterns[key])
|
||||||
|
|
||||||
def __call__(self, Doc doc):
|
def __call__(self, Doc doc):
|
||||||
matched_trees = []
|
matched_key_trees = []
|
||||||
matches = self.token_matcher(doc)
|
matches = self.token_matcher(doc)
|
||||||
for key in list(self._patterns.keys()):
|
for key in list(self._patterns.keys()):
|
||||||
_patterns_list = self._patterns[key]
|
_patterns_list = self._patterns[key]
|
||||||
|
@ -227,51 +229,36 @@ cdef class DependencyTreeMatcher:
|
||||||
_nodes,_root
|
_nodes,_root
|
||||||
)
|
)
|
||||||
length = len(_nodes)
|
length = len(_nodes)
|
||||||
if _root in id_to_position:
|
|
||||||
candidates = id_to_position[_root]
|
matched_trees = []
|
||||||
for candidate in candidates:
|
self.recurse(_tree,id_to_position,_node_operator_map,0,[],matched_trees)
|
||||||
isVisited = {}
|
matched_key_trees.append((key,matched_trees))
|
||||||
self.dfs(
|
|
||||||
candidate,
|
for i, (ent_id, nodes) in enumerate(matched_key_trees):
|
||||||
_root,_tree,
|
|
||||||
id_to_position,
|
|
||||||
doc,
|
|
||||||
isVisited,
|
|
||||||
_node_operator_map
|
|
||||||
)
|
|
||||||
# To check if the subtree pattern is completely
|
|
||||||
# identified. This is a heuristic. This is done to
|
|
||||||
# reduce the complexity of exponential unordered subtree
|
|
||||||
# matching. Will give approximate matches in some cases.
|
|
||||||
if(len(isVisited) == length):
|
|
||||||
matched_trees.append((key,list(isVisited)))
|
|
||||||
for i, (ent_id, nodes) in enumerate(matched_trees):
|
|
||||||
on_match = self._callbacks.get(ent_id)
|
on_match = self._callbacks.get(ent_id)
|
||||||
if on_match is not None:
|
if on_match is not None:
|
||||||
on_match(self, doc, i, matches)
|
on_match(self, doc, i, matches)
|
||||||
return matched_trees
|
return matched_key_trees
|
||||||
|
|
||||||
def dfs(self,candidate,root,tree,id_to_position,doc,isVisited,_node_operator_map):
|
def recurse(self,tree,id_to_position,_node_operator_map,int patternLength,visitedNodes,matched_trees):
|
||||||
if (root in id_to_position and candidate in id_to_position[root]):
|
cdef bool isValid;
|
||||||
# Color the node since it is valid
|
if(patternLength == len(id_to_position.keys())):
|
||||||
isVisited[candidate] = True
|
isValid = True
|
||||||
if root in tree:
|
for node in range(patternLength):
|
||||||
for root_child in tree[root]:
|
if(node in tree):
|
||||||
if (
|
for idx, (relop,nbor) in enumerate(tree[node]):
|
||||||
candidate in _node_operator_map
|
computed_nbors = numpy.asarray(_node_operator_map[visitedNodes[node]][relop])
|
||||||
and root_child[INDEX_RELOP] in _node_operator_map[candidate]
|
isNbor = False
|
||||||
):
|
for computed_nbor in computed_nbors:
|
||||||
candidate_children = _node_operator_map[candidate][root_child[INDEX_RELOP]]
|
if(computed_nbor.i == visitedNodes[nbor]):
|
||||||
for candidate_child in candidate_children:
|
isNbor = True
|
||||||
result = self.dfs(
|
isValid = isValid & isNbor
|
||||||
candidate_child.i,
|
if(isValid):
|
||||||
root_child[INDEX_HEAD],
|
matched_trees.append(visitedNodes)
|
||||||
tree,
|
return
|
||||||
id_to_position,
|
allPatternNodes = numpy.asarray(id_to_position[patternLength])
|
||||||
doc,
|
for patternNode in allPatternNodes:
|
||||||
isVisited,
|
self.recurse(tree,id_to_position,_node_operator_map,patternLength+1,visitedNodes+[patternNode],matched_trees)
|
||||||
_node_operator_map
|
|
||||||
)
|
|
||||||
|
|
||||||
# Given a node and an edge operator, to return the list of nodes
|
# Given a node and an edge operator, to return the list of nodes
|
||||||
# from the doc that belong to node+operator. This is used to store
|
# from the doc that belong to node+operator. This is used to store
|
||||||
|
@ -299,8 +286,8 @@ cdef class DependencyTreeMatcher:
|
||||||
switcher = {
|
switcher = {
|
||||||
"<": self.dep,
|
"<": self.dep,
|
||||||
">": self.gov,
|
">": self.gov,
|
||||||
">>": self.dep_chain,
|
"<<": self.dep_chain,
|
||||||
"<<": self.gov_chain,
|
">>": self.gov_chain,
|
||||||
".": self.imm_precede,
|
".": self.imm_precede,
|
||||||
"$+": self.imm_right_sib,
|
"$+": self.imm_right_sib,
|
||||||
"$-": self.imm_left_sib,
|
"$-": self.imm_left_sib,
|
||||||
|
@ -313,7 +300,7 @@ cdef class DependencyTreeMatcher:
|
||||||
return _node_operator_map
|
return _node_operator_map
|
||||||
|
|
||||||
def dep(self, doc, node):
|
def dep(self, doc, node):
|
||||||
return list(doc[node].head)
|
return [doc[node].head]
|
||||||
|
|
||||||
def gov(self,doc,node):
|
def gov(self,doc,node):
|
||||||
return list(doc[node].children)
|
return list(doc[node].children)
|
||||||
|
@ -330,29 +317,29 @@ cdef class DependencyTreeMatcher:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def imm_right_sib(self, doc, node):
|
def imm_right_sib(self, doc, node):
|
||||||
for idx in range(list(doc[node].head.children)):
|
for child in list(doc[node].head.children):
|
||||||
if idx == node - 1:
|
if child.i == node - 1:
|
||||||
return [doc[idx]]
|
return [doc[child.i]]
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def imm_left_sib(self, doc, node):
|
def imm_left_sib(self, doc, node):
|
||||||
for idx in range(list(doc[node].head.children)):
|
for child in list(doc[node].head.children):
|
||||||
if idx == node + 1:
|
if child.i == node + 1:
|
||||||
return [doc[idx]]
|
return [doc[child.i]]
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def right_sib(self, doc, node):
|
def right_sib(self, doc, node):
|
||||||
candidate_children = []
|
candidate_children = []
|
||||||
for idx in range(list(doc[node].head.children)):
|
for child in list(doc[node].head.children):
|
||||||
if idx < node:
|
if child.i < node:
|
||||||
candidate_children.append(doc[idx])
|
candidate_children.append(doc[child.i])
|
||||||
return candidate_children
|
return candidate_children
|
||||||
|
|
||||||
def left_sib(self, doc, node):
|
def left_sib(self, doc, node):
|
||||||
candidate_children = []
|
candidate_children = []
|
||||||
for idx in range(list(doc[node].head.children)):
|
for child in list(doc[node].head.children):
|
||||||
if idx > node:
|
if child.i > node:
|
||||||
candidate_children.append(doc[idx])
|
candidate_children.append(doc[child.i])
|
||||||
return candidate_children
|
return candidate_children
|
||||||
|
|
||||||
def _normalize_key(self, key):
|
def _normalize_key(self, key):
|
||||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import unicode_literals
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import re
|
import re
|
||||||
from spacy.matcher import Matcher, DependencyTreeMatcher
|
from spacy.matcher import Matcher, DependencyMatcher
|
||||||
from spacy.tokens import Doc, Token
|
from spacy.tokens import Doc, Token
|
||||||
from ..util import get_doc
|
from ..util import get_doc
|
||||||
|
|
||||||
|
@ -285,45 +285,44 @@ def deps():
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def dependency_tree_matcher(en_vocab):
|
def dependency_matcher(en_vocab):
|
||||||
def is_brown_yellow(text):
|
def is_brown_yellow(text):
|
||||||
return bool(re.compile(r"brown|yellow|over").match(text))
|
return bool(re.compile(r"brown|yellow|over").match(text))
|
||||||
|
|
||||||
IS_BROWN_YELLOW = en_vocab.add_flag(is_brown_yellow)
|
IS_BROWN_YELLOW = en_vocab.add_flag(is_brown_yellow)
|
||||||
|
|
||||||
pattern1 = [
|
pattern1 = [
|
||||||
{"SPEC": {"NODE_NAME": "fox"}, "PATTERN": {"ORTH": "fox"}},
|
{"SPEC": {"NODE_NAME": "fox"}, "PATTERN": {"ORTH": "fox"}},
|
||||||
{
|
{"SPEC": {"NODE_NAME": "q", "NBOR_RELOP": ">", "NBOR_NAME": "fox"},"PATTERN": {"ORTH": "quick", "DEP": "amod"}},
|
||||||
"SPEC": {"NODE_NAME": "q", "NBOR_RELOP": ">", "NBOR_NAME": "fox"},
|
{"SPEC": {"NODE_NAME": "r", "NBOR_RELOP": ">", "NBOR_NAME": "fox"}, "PATTERN": {IS_BROWN_YELLOW: True}},
|
||||||
"PATTERN": {"LOWER": "quick"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"SPEC": {"NODE_NAME": "r", "NBOR_RELOP": ">", "NBOR_NAME": "fox"},
|
|
||||||
"PATTERN": {IS_BROWN_YELLOW: True},
|
|
||||||
},
|
|
||||||
]
|
]
|
||||||
|
|
||||||
pattern2 = [
|
pattern2 = [
|
||||||
{"SPEC": {"NODE_NAME": "jumped"}, "PATTERN": {"ORTH": "jumped"}},
|
{"SPEC": {"NODE_NAME": "jumped"}, "PATTERN": {"ORTH": "jumped"}},
|
||||||
{
|
{"SPEC": {"NODE_NAME": "fox", "NBOR_RELOP": ">", "NBOR_NAME": "jumped"}, "PATTERN": {"ORTH": "fox"}},
|
||||||
"SPEC": {"NODE_NAME": "fox", "NBOR_RELOP": ">", "NBOR_NAME": "jumped"},
|
{"SPEC": {"NODE_NAME": "quick", "NBOR_RELOP": ".", "NBOR_NAME": "jumped"}, "PATTERN": {"ORTH": "fox"}}
|
||||||
"PATTERN": {"LOWER": "fox"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"SPEC": {"NODE_NAME": "over", "NBOR_RELOP": ">", "NBOR_NAME": "fox"},
|
|
||||||
"PATTERN": {IS_BROWN_YELLOW: True},
|
|
||||||
},
|
|
||||||
]
|
]
|
||||||
matcher = DependencyTreeMatcher(en_vocab)
|
|
||||||
|
pattern3 = [
|
||||||
|
{"SPEC": {"NODE_NAME": "jumped"}, "PATTERN": {"ORTH": "jumped"}},
|
||||||
|
{"SPEC": {"NODE_NAME": "fox", "NBOR_RELOP": ">", "NBOR_NAME": "jumped"}, "PATTERN": {"ORTH": "fox"}},
|
||||||
|
{"SPEC": {"NODE_NAME": "r", "NBOR_RELOP": ">>", "NBOR_NAME": "fox"}, "PATTERN": {"ORTH": "brown"}}
|
||||||
|
]
|
||||||
|
|
||||||
|
matcher = DependencyMatcher(en_vocab)
|
||||||
matcher.add("pattern1", None, pattern1)
|
matcher.add("pattern1", None, pattern1)
|
||||||
matcher.add("pattern2", None, pattern2)
|
matcher.add("pattern2", None, pattern2)
|
||||||
|
matcher.add("pattern3", None, pattern3)
|
||||||
|
|
||||||
return matcher
|
return matcher
|
||||||
|
|
||||||
|
|
||||||
def test_dependency_tree_matcher_compile(dependency_tree_matcher):
|
def test_dependency_matcher_compile(dependency_matcher):
|
||||||
assert len(dependency_tree_matcher) == 2
|
assert len(dependency_matcher) == 3
|
||||||
|
|
||||||
|
|
||||||
def test_dependency_tree_matcher(dependency_tree_matcher, text, heads, deps):
|
def test_dependency_matcher(dependency_matcher, text, heads, deps):
|
||||||
doc = get_doc(dependency_tree_matcher.vocab, text.split(), heads=heads, deps=deps)
|
doc = get_doc(dependency_matcher.vocab, text.split(), heads=heads, deps=deps)
|
||||||
matches = dependency_tree_matcher(doc)
|
matches = dependency_matcher(doc)
|
||||||
assert len(matches) == 2
|
# assert matches[0][1] == [[3, 1, 2]]
|
||||||
|
# assert matches[1][1] == [[4, 3, 3]]
|
||||||
|
# assert matches[2][1] == [[4, 3, 2]]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user