From 356af7b0a18fdef2a8761c4b99983fa9445ebe0c Mon Sep 17 00:00:00 2001
From: Suraj Krishnan Rajan
Date: Wed, 5 Sep 2018 09:23:21 +0530
Subject: [PATCH] Fix tests

---
 spacy/matcher.pyx                       | 247 +++++++++++++++++++++-
 spacy/tests/matcher/test_matcher_api.py |  48 ++++-
 spacy/tests/test_matcher.py             | 263 ++++++++++++++++++++++++
 3 files changed, 556 insertions(+), 2 deletions(-)
 create mode 100644 spacy/tests/test_matcher.py

diff --git a/spacy/matcher.pyx b/spacy/matcher.pyx
index 970cb8743..e8d567428 100644
--- a/spacy/matcher.pyx
+++ b/spacy/matcher.pyx
@@ -28,6 +28,8 @@ from .attrs import FLAG42 as I3_ENT
 from .attrs import FLAG41 as I4_ENT
 
+DELIMITER = '||'
+
 
 cdef enum action_t:
     REJECT = 0000
@@ -285,6 +287,8 @@ cdef char get_is_final(PatternStateC state) nogil:
 cdef char get_quantifier(PatternStateC state) nogil:
     return state.pattern.quantifier
 
+DEF PADDING = 5
+
 
 cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id,
                                  object token_specs) except NULL:
@@ -417,7 +421,7 @@ cdef class Matcher:
 
         key (unicode): The match ID.
         on_match (callable): Callback executed on match.
-        *patterns (list): List of token descritions.
+        *patterns (list): List of token descriptions.
         """
         for pattern in patterns:
             if len(pattern) == 0:
@@ -611,6 +615,7 @@ cdef class PhraseMatcher:
         self.phrase_ids.set(phrase_hash, ent_id)
 
     def __call__(self, Doc doc):
+
         """Find all sequences matching the supplied patterns on the `Doc`.
 
         doc (Doc): The document to match over.
@@ -673,3 +678,243 @@ cdef class PhraseMatcher:
             return None
         else:
             return ent_id
+
+
+cdef class DependencyTreeMatcher:
+    """Match dependency parse tree based on pattern rules."""
+    cdef Pool mem
+    cdef readonly Vocab vocab
+    cdef readonly Matcher token_matcher
+    cdef public object _patterns
+    cdef public object _keys_to_token
+    cdef public object _root
+    cdef public object _entities
+    cdef public object _callbacks
+    cdef public object _nodes
+    cdef public object _tree
+
+    def __init__(self, vocab):
+        """Create the DependencyTreeMatcher.
+
+        vocab (Vocab): The vocabulary object, which must be shared with the
+            documents the matcher will operate on.
+        RETURNS (DependencyTreeMatcher): The newly constructed object.
+        """
+        size = 20
+        self.token_matcher = Matcher(vocab)
+        self._keys_to_token = {}
+        self._patterns = {}
+        self._root = {}
+        self._nodes = {}
+        self._tree = {}
+        self._entities = {}
+        self._callbacks = {}
+        self.vocab = vocab
+        self.mem = Pool()
+
+    def __reduce__(self):
+        data = (self.vocab, self._patterns, self._tree, self._callbacks)
+        return (unpickle_matcher, data, None, None)
+
+    def __len__(self):
+        """Get the number of rules, which are edges, added to the
+        dependency tree matcher.
+
+        RETURNS (int): The number of rules.
+        """
+        return len(self._patterns)
+
+    def __contains__(self, key):
+        """Check whether the matcher contains rules for a match ID.
+
+        key (unicode): The match ID.
+        RETURNS (bool): Whether the matcher contains rules for this match ID.
+        """
+        return self._normalize_key(key) in self._patterns
+
+
+    def add(self, key, on_match, *patterns):
+        # TODO: validations
+        # 1. check if the input pattern is connected
+        # 2. check if the pattern format is correct
+        # 3. check if at least one root node is present
+        # 4. check that node names are not repeated
+        # 5. check that each node has only one head
+
+        for pattern in patterns:
+            if len(pattern) == 0:
+                raise ValueError(Errors.E012.format(key=key))
+
+        key = self._normalize_key(key)
+
+        _patterns = []
+        for pattern in patterns:
+            token_patterns = []
+            for i in range(len(pattern)):
+                token_pattern = [pattern[i]['PATTERN']]
+                token_patterns.append(token_pattern)
+            # self.patterns.append(token_patterns)
+            _patterns.append(token_patterns)
+
+        self._patterns.setdefault(key, [])
+        self._callbacks[key] = on_match
+        self._patterns[key].extend(_patterns)
+
+        # Add each node pattern of all the input patterns individually to the
+        # matcher. This enables only a single instance of Matcher to be used.
+        # Multiple adds are required to track each node pattern.
+        _keys_to_token_list = []
+        for i in range(len(_patterns)):
+            _keys_to_token = {}
+            # TODO: better way to hash edges in a pattern?
+            for j in range(len(_patterns[i])):
+                k = self._normalize_key(unicode(key)+DELIMITER+unicode(i)+DELIMITER+unicode(j))
+                self.token_matcher.add(k, None, _patterns[i][j])
+                _keys_to_token[k] = j
+            _keys_to_token_list.append(_keys_to_token)
+
+        self._keys_to_token.setdefault(key, [])
+        self._keys_to_token[key].extend(_keys_to_token_list)
+
+        _nodes_list = []
+        for pattern in patterns:
+            nodes = {}
+            for i in range(len(pattern)):
+                nodes[pattern[i]['SPEC']['NODE_NAME']] = i
+            _nodes_list.append(nodes)
+
+        self._nodes.setdefault(key, [])
+        self._nodes[key].extend(_nodes_list)
+
+        # Create an object tree to traverse later on. This data structure
+        # enables easy tree pattern matching. A Doc/Token-based tree cannot be
+        # reused, since it is memory-heavy and tightly coupled to the doc.
+        self.retrieve_tree(patterns, _nodes_list, key)
+
+    def retrieve_tree(self, patterns, _nodes_list, key):
+
+        _heads_list = []
+        _root_list = []
+        for i in range(len(patterns)):
+            heads = {}
+            root = -1
+            for j in range(len(patterns[i])):
+                token_pattern = patterns[i][j]
+                if('NBOR_RELOP' not in token_pattern['SPEC']):
+                    heads[j] = j
+                    root = j
+                else:
+                    # TODO: add the remaining semgrex relation operators
+                    # 1. '>' : the NBOR_NAME node is the head of this node
+                    if(token_pattern['SPEC']['NBOR_RELOP'] == '>'):
+                        heads[j] = _nodes_list[i][token_pattern['SPEC']['NBOR_NAME']]
+                    # 2. '<' : this node is the head of the NBOR_NAME node
+                    if(token_pattern['SPEC']['NBOR_RELOP'] == '<'):
+                        heads[_nodes_list[i][token_pattern['SPEC']['NBOR_NAME']]] = j
+
+            _heads_list.append(heads)
+            _root_list.append(root)
+
+        _tree_list = []
+        for i in range(len(patterns)):
+            tree = {}
+            for j in range(len(patterns[i])):
+                if(j == _heads_list[i][j]):
+                    continue
+                head = _heads_list[i][j]
+                if(head not in tree):
+                    tree[head] = []
+                tree[head].append(j)
+            _tree_list.append(tree)
+
+        self._tree.setdefault(key, [])
+        self._tree[key].extend(_tree_list)
+
+        self._root.setdefault(key, [])
+        self._root[key].extend(_root_list)
+
+    def has_key(self, key):
+        """Check whether the matcher has a rule with a given key.
+
+        key (unicode or int): The key to check.
+        RETURNS (bool): Whether the matcher has the rule.
+        """
+        key = self._normalize_key(key)
+        return key in self._patterns
+
+    def get(self, key, default=None):
+        """Retrieve the pattern stored for a key.
+
+        key (unicode or int): The key to retrieve.
+        RETURNS (tuple): The rule, as an (on_match, patterns) tuple.
+        """
+        key = self._normalize_key(key)
+        if key not in self._patterns:
+            return default
+        return (self._callbacks[key], self._patterns[key])
+
+    def __call__(self, Doc doc):
+        matched_trees = []
+
+        matches = self.token_matcher(doc)
+        for key in list(self._patterns.keys()):
+            _patterns_list = self._patterns[key]
+            _keys_to_token_list = self._keys_to_token[key]
+            _root_list = self._root[key]
+            _tree_list = self._tree[key]
+            _nodes_list = self._nodes[key]
+            length = len(_patterns_list)
+            for i in range(length):
+                _keys_to_token = _keys_to_token_list[i]
+                _root = _root_list[i]
+                _tree = _tree_list[i]
+                _nodes = _nodes_list[i]
+
+                id_to_position = {}
+
+                # TODO: this could be moved outside the loop to improve running time
+                for match_id, start, end in matches:
+                    if match_id in _keys_to_token:
+                        if _keys_to_token[match_id] not in id_to_position:
+                            id_to_position[_keys_to_token[match_id]] = []
+                        id_to_position[_keys_to_token[match_id]].append(start)
+
+                length = len(_nodes)
+                if _root in id_to_position:
+                    candidates = id_to_position[_root]
+                    for candidate in candidates:
+                        isVisited = {}
+                        self.dfs(candidate, _root, _tree, id_to_position, doc, isVisited)
+                        # check whether the subtree pattern was matched completely
+                        if(len(isVisited) == length):
+                            matched_trees.append((key, list(isVisited)))
+
+        for i, (ent_id, nodes) in enumerate(matched_trees):
+            on_match = self._callbacks.get(ent_id)
+            if on_match is not None:
+                on_match(self, doc, i, matches)
+
+        return matched_trees
+
+    def dfs(self, candidate, root, tree, id_to_position, doc, isVisited):
+        if(root in id_to_position and candidate in id_to_position[root]):
+            # mark the node as visited, since it is valid
+            isVisited[candidate] = True
+            candidate_children = doc[candidate].children
+            for candidate_child in candidate_children:
+                if root in tree:
+                    for root_child in tree[root]:
+                        self.dfs(
+                            candidate_child.i,
+                            root_child,
+                            tree,
+                            id_to_position,
+                            doc,
+                            isVisited
+                        )
+
+    def _normalize_key(self, key):
+        if isinstance(key, basestring):
+            return self.vocab.strings.add(key)
+        else:
+            return key
diff --git a/spacy/tests/matcher/test_matcher_api.py b/spacy/tests/matcher/test_matcher_api.py
index 18779c32c..6f4919ac5 100644
--- a/spacy/tests/matcher/test_matcher_api.py
+++ b/spacy/tests/matcher/test_matcher_api.py
@@ -2,8 +2,10 @@
 from __future__ import unicode_literals
 
 import pytest
-from spacy.matcher import Matcher
+import re
+from spacy.matcher import Matcher, DependencyTreeMatcher
 from spacy.tokens import Doc
+from ..util import get_doc
 
 
 @pytest.fixture
@@ -166,3 +168,47 @@ def test_matcher_any_token_operator(en_vocab):
     assert matches[0] == 'test'
     assert matches[1] == 'test hello'
     assert matches[2] == 'test hello world'
+
+
+@pytest.fixture
+def text():
+    return u"The quick brown fox jumped over the lazy fox"
+
+@pytest.fixture
+def heads():
+    return [3, 2, 1, 1, 0, -1, 2, 1, -3]
+
+@pytest.fixture
+def deps():
+    return ['det', 'amod', 'amod', 'nsubj', 'ROOT', 'prep', 'det', 'amod', 'pobj']
+
+@pytest.fixture
+def dependency_tree_matcher(en_vocab):
+    is_brown_yellow = lambda text: bool(re.compile(r'brown|yellow|over').match(text))
+    IS_BROWN_YELLOW = en_vocab.add_flag(is_brown_yellow)
+    pattern1 = [
+        {'SPEC': {'NODE_NAME': 'fox'}, 'PATTERN': {'ORTH': 'fox'}},
+        {'SPEC': {'NODE_NAME': 'q', 'NBOR_RELOP': '>', 'NBOR_NAME': 'fox'}, 'PATTERN': {'LOWER': u'quick'}},
+        {'SPEC': {'NODE_NAME': 'r', 'NBOR_RELOP': '>', 'NBOR_NAME': 'fox'}, 'PATTERN': {IS_BROWN_YELLOW: True}}
+    ]
+
+    pattern2 = [
+        {'SPEC': {'NODE_NAME': 'jumped'}, 'PATTERN': {'ORTH': 'jumped'}},
+        {'SPEC': {'NODE_NAME': 'fox', 'NBOR_RELOP': '>', 'NBOR_NAME': 'jumped'}, 'PATTERN': {'LOWER': u'fox'}},
+        {'SPEC': {'NODE_NAME': 'over', 'NBOR_RELOP': '>', 'NBOR_NAME': 'fox'}, 'PATTERN': {IS_BROWN_YELLOW: True}}
+    ]
+    matcher = DependencyTreeMatcher(en_vocab)
+    matcher.add('pattern1', None, pattern1)
+    matcher.add('pattern2', None, pattern2)
+    return matcher
+
+
+def test_dependency_tree_matcher_compile(dependency_tree_matcher):
+    assert len(dependency_tree_matcher) == 2
+
+def test_dependency_tree_matcher(dependency_tree_matcher, text, heads, deps):
+    doc = get_doc(dependency_tree_matcher.vocab, text.split(), heads=heads, deps=deps)
+    matches = dependency_tree_matcher(doc)
+    assert len(matches) == 2
+
diff --git a/spacy/tests/test_matcher.py b/spacy/tests/test_matcher.py
new file mode 100644
index 000000000..af2bf1ad0
--- /dev/null
+++ b/spacy/tests/test_matcher.py
@@ -0,0 +1,263 @@
+# coding: utf-8
+from __future__ import unicode_literals
+
+from numpy import sort
+
+from ..matcher import Matcher, PhraseMatcher, DependencyTreeMatcher
+from .util import get_doc
+from ..tokens import Doc
+
+import pytest
+import re
+
+@pytest.fixture
+def matcher(en_vocab):
+    rules = {
+        'JS': [[{'ORTH': 'JavaScript'}]],
+        'GoogleNow': [[{'ORTH': 'Google'}, {'ORTH': 'Now'}]],
+        'Java': [[{'LOWER': 'java'}]]
+    }
+    matcher = Matcher(en_vocab)
+    for key, patterns in rules.items():
+        matcher.add(key, None, *patterns)
+    return matcher
+
+def test_matcher_from_api_docs(en_vocab):
+    matcher = Matcher(en_vocab)
+    pattern = [{'ORTH': 'test'}]
+    assert len(matcher) == 0
+    matcher.add('Rule', None, pattern)
+    assert len(matcher) == 1
+    matcher.remove('Rule')
+    assert 'Rule' not in matcher
+    matcher.add('Rule', None, pattern)
+    assert 'Rule' in matcher
+    on_match, patterns = matcher.get('Rule')
+    assert len(patterns[0])
+
+
+def test_matcher_from_usage_docs(en_vocab):
+    text = "Wow 😀 This is really cool! 😂 😂"
+    doc = get_doc(en_vocab, words=text.split(' '))
+    pos_emoji = [u'😀', u'😃', u'😂', u'🤣', u'😊', u'😍']
+    pos_patterns = [[{'ORTH': emoji}] for emoji in pos_emoji]
+
+    def label_sentiment(matcher, doc, i, matches):
+        match_id, start, end = matches[i]
+        if doc.vocab.strings[match_id] == 'HAPPY':
+            doc.sentiment += 0.1
+        span = doc[start : end]
+        token = span.merge()
+        token.vocab[token.text].norm_ = 'happy emoji'
+
+    matcher = Matcher(en_vocab)
+    matcher.add('HAPPY', label_sentiment, *pos_patterns)
+    matches = matcher(doc)
+    assert doc.sentiment != 0
+    assert doc[1].norm_ == 'happy emoji'
+
+
+@pytest.mark.parametrize('words', [["Some", "words"]])
+def test_matcher_init(en_vocab, words):
+    matcher = Matcher(en_vocab)
+    doc = get_doc(en_vocab, words)
+    assert len(matcher) == 0
+    assert matcher(doc) == []
+
+
+def test_matcher_contains(matcher):
+    matcher.add('TEST', None, [{'ORTH': 'test'}])
+    assert 'TEST' in matcher
+    assert 'TEST2' not in matcher
+
+
+def test_matcher_no_match(matcher):
+    words = ["I", "like", "cheese", "."]
+    doc = get_doc(matcher.vocab, words)
+    assert matcher(doc) == []
+
+
+def test_matcher_compile(matcher):
+    assert len(matcher) == 3
+
+
+def test_matcher_match_start(matcher):
+    words = ["JavaScript", "is", "good"]
+    doc = get_doc(matcher.vocab, words)
+    assert matcher(doc) == [(matcher.vocab.strings['JS'], 0, 1)]
+
+
+def test_matcher_match_end(matcher):
+    words = ["I", "like", "java"]
+    doc = get_doc(matcher.vocab, words)
+    assert matcher(doc) == [(doc.vocab.strings['Java'], 2, 3)]
+
+
+def test_matcher_match_middle(matcher):
+    words = ["I", "like", "Google", "Now", "best"]
+    doc = get_doc(matcher.vocab, words)
+    assert matcher(doc) == [(doc.vocab.strings['GoogleNow'], 2, 4)]
+
+
+def test_matcher_match_multi(matcher):
+    words = ["I", "like", "Google", "Now", "and", "java", "best"]
+    doc = get_doc(matcher.vocab, words)
+    assert matcher(doc) == [(doc.vocab.strings['GoogleNow'], 2, 4),
+                            (doc.vocab.strings['Java'], 5, 6)]
+
+
+def test_matcher_empty_dict(en_vocab):
+    '''Test matcher allows empty token specs, meaning match on any token.'''
+    matcher = Matcher(en_vocab)
+    abc = ["a", "b", "c"]
+    doc = get_doc(matcher.vocab, abc)
+    matcher.add('A.C', None, [{'ORTH': 'a'}, {}, {'ORTH': 'c'}])
+    matches = matcher(doc)
+    assert len(matches) == 1
+    assert matches[0][1:] == (0, 3)
+    matcher = Matcher(en_vocab)
+    matcher.add('A.', None, [{'ORTH': 'a'}, {}])
+    matches = matcher(doc)
+    assert matches[0][1:] == (0, 2)
+
+
+def test_matcher_operator_shadow(en_vocab):
+    matcher = Matcher(en_vocab)
+    abc = ["a", "b", "c"]
+    doc = get_doc(matcher.vocab, abc)
+    matcher.add('A.C', None, [{'ORTH': 'a'},
+                              {"IS_ALPHA": True, "OP": "+"},
+                              {'ORTH': 'c'}])
+    matches = matcher(doc)
+    assert len(matches) == 1
+    assert matches[0][1:] == (0, 3)
+
+
+def test_matcher_phrase_matcher(en_vocab):
+    words = ["Google", "Now"]
+    doc = get_doc(en_vocab, words)
+    matcher = PhraseMatcher(en_vocab)
+    matcher.add('COMPANY', None, doc)
+    words = ["I", "like", "Google", "Now", "best"]
+    doc = get_doc(en_vocab, words)
+    assert len(matcher(doc)) == 1
+
+
+def test_phrase_matcher_length(en_vocab):
+    matcher = PhraseMatcher(en_vocab)
+    assert len(matcher) == 0
+    matcher.add('TEST', None, get_doc(en_vocab, ['test']))
+    assert len(matcher) == 1
+    matcher.add('TEST2', None, get_doc(en_vocab, ['test2']))
+    assert len(matcher) == 2
+
+
+def test_phrase_matcher_contains(en_vocab):
+    matcher = PhraseMatcher(en_vocab)
+    matcher.add('TEST', None, get_doc(en_vocab, ['test']))
+    assert 'TEST' in matcher
+    assert 'TEST2' not in matcher
+
+
+def test_matcher_match_zero(matcher):
+    words1 = 'He said , " some words " ...'.split()
+    words2 = 'He said , " some three words " ...'.split()
+    pattern1 = [{'ORTH': '"'},
+                {'OP': '!', 'IS_PUNCT': True},
+                {'OP': '!', 'IS_PUNCT': True},
+                {'ORTH': '"'}]
+    pattern2 = [{'ORTH': '"'},
+                {'IS_PUNCT': True},
+                {'IS_PUNCT': True},
+                {'IS_PUNCT': True},
+                {'ORTH': '"'}]
+
+    matcher.add('Quote', None, pattern1)
+    doc = get_doc(matcher.vocab, words1)
+    assert len(matcher(doc)) == 1
+
+    doc = get_doc(matcher.vocab, words2)
+    assert len(matcher(doc)) == 0
+    matcher.add('Quote', None, pattern2)
+    assert len(matcher(doc)) == 0
+
+
+def test_matcher_match_zero_plus(matcher):
+    words = 'He said , " some words " ...'.split()
+    pattern = [{'ORTH': '"'},
+               {'OP': '*', 'IS_PUNCT': False},
+               {'ORTH': '"'}]
+    matcher.add('Quote', None, pattern)
+    doc = get_doc(matcher.vocab, words)
+    assert len(matcher(doc)) == 1
+
+
+def test_matcher_match_one_plus(matcher):
+    control = Matcher(matcher.vocab)
+    control.add('BasicPhilippe', None, [{'ORTH': 'Philippe'}])
+    doc = get_doc(control.vocab, ['Philippe', 'Philippe'])
+    m = control(doc)
+    assert len(m) == 2
+    matcher.add('KleenePhilippe', None, [{'ORTH': 'Philippe', 'OP': '1'},
+                                         {'ORTH': 'Philippe', 'OP': '+'}])
+    m = matcher(doc)
+    assert len(m) == 1
+
+
+def test_operator_combos(matcher):
+    cases = [
+        ('aaab', 'a a a b', True),
+        ('aaab', 'a+ b', True),
+        ('aaab', 'a+ a+ b', True),
+        ('aaab', 'a+ a+ a b', True),
+        ('aaab', 'a+ a+ a+ b', True),
+        ('aaab', 'a+ a a b', True),
+        ('aaab', 'a+ a a', True),
+        ('aaab', 'a+', True),
+        ('aaa', 'a+ b', False),
+        ('aaa', 'a+ a+ b', False),
+        ('aaa', 'a+ a+ a+ b', False),
+        ('aaa', 'a+ a b', False),
+        ('aaa', 'a+ a a b', False),
+        ('aaab', 'a+ a a', True),
+        ('aaab', 'a+', True),
+        ('aaab', 'a+ a b', True),
+    ]
+    for string, pattern_str, result in cases:
+        matcher = Matcher(matcher.vocab)
+        doc = get_doc(matcher.vocab, words=list(string))
+        pattern = []
+        for part in pattern_str.split():
+            if part.endswith('+'):
+                pattern.append({'ORTH': part[0], 'op': '+'})
+            else:
+                pattern.append({'ORTH': part})
+        matcher.add('PATTERN', None, pattern)
+        matches = matcher(doc)
+        if result:
+            assert matches, (string, pattern_str)
+        else:
+            assert not matches, (string, pattern_str)
+
+
+def test_matcher_end_zero_plus(matcher):
+    '''Test matcher works when patterns end with * operator. (issue 1450)'''
+    matcher = Matcher(matcher.vocab)
+    matcher.add(
+        "TSTEND",
+        None,
+        [
+            {'ORTH': "a"},
+            {'ORTH': "b", 'OP': "*"}
+        ]
+    )
+    nlp = lambda string: Doc(matcher.vocab, words=string.split())
+    assert len(matcher(nlp(u'a'))) == 1
+    assert len(matcher(nlp(u'a b'))) == 1
+    assert len(matcher(nlp(u'a b'))) == 1
+    assert len(matcher(nlp(u'a c'))) == 1
+    assert len(matcher(nlp(u'a b c'))) == 1
+    assert len(matcher(nlp(u'a b b c'))) == 1
+    assert len(matcher(nlp(u'a b b'))) == 1
+
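
Usage sketch (reviewer note, not part of the patch): the snippet below shows how the new
DependencyTreeMatcher could be driven end to end. It assumes a pipeline with a dependency
parser (e.g. 'en_core_web_sm'); the match key 'FOUNDED', the node names and the token
attributes are illustrative only and are not taken from the diff.

    import spacy
    from spacy.matcher import DependencyTreeMatcher

    nlp = spacy.load('en_core_web_sm')          # any parser-enabled pipeline works
    matcher = DependencyTreeMatcher(nlp.vocab)

    # One dict per node: 'SPEC' names the node and, except for the root, ties it to
    # another node in the pattern; 'PATTERN' holds ordinary token-matcher attributes.
    pattern = [
        {'SPEC': {'NODE_NAME': 'founded'},
         'PATTERN': {'ORTH': 'founded'}},
        {'SPEC': {'NODE_NAME': 'founder', 'NBOR_RELOP': '>', 'NBOR_NAME': 'founded'},
         'PATTERN': {'DEP': 'nsubj'}},
    ]
    matcher.add('FOUNDED', None, pattern)

    doc = nlp(u"Smith founded a healthcare company in 2005.")
    # __call__ returns (match_id, [token indices]) pairs rather than (id, start, end).
    for match_id, token_ids in matcher(doc):
        print(nlp.vocab.strings[match_id], [doc[i].text for i in token_ids])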
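
A note on the relation operators (also not part of the patch): retrieve_tree() currently wires
up only two Semgrex-style operators. As implemented, '>' declares that the NBOR_NAME node is
the head of the node being defined, '<' declares that the node being defined is the head of
the NBOR_NAME node, and the one node without an NBOR_RELOP acts as the root. A minimal sketch
with hypothetical node names:

    # '>' : 'verb' (NBOR_NAME) is the head of the 'subj' node declared here.
    {'SPEC': {'NODE_NAME': 'subj', 'NBOR_RELOP': '>', 'NBOR_NAME': 'verb'},
     'PATTERN': {'DEP': 'nsubj'}}

    # '<' : the 'extra' node declared here is the head of another node 'obj' in the pattern.
    {'SPEC': {'NODE_NAME': 'extra', 'NBOR_RELOP': '<', 'NBOR_NAME': 'obj'},
     'PATTERN': {'POS': 'VERB'}}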