Mirror of https://github.com/explosion/spaCy.git (synced 2024-11-11 04:08:09 +03:00)

Commit: Fix tests

This commit is contained in:
parent 4d2d7d5866
commit 356af7b0a1
@@ -28,6 +28,8 @@ from .attrs import FLAG42 as I3_ENT
 from .attrs import FLAG41 as I4_ENT


+DELIMITER = '||'
+


 cdef enum action_t:
     REJECT = 0000
@@ -285,6 +287,8 @@ cdef char get_is_final(PatternStateC state) nogil:
 cdef char get_quantifier(PatternStateC state) nogil:
     return state.pattern.quantifier

+DEF PADDING = 5
+

 cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id,
                                  object token_specs) except NULL:
@@ -417,7 +421,7 @@ cdef class Matcher:

         key (unicode): The match ID.
         on_match (callable): Callback executed on match.
-        *patterns (list): List of token descritions.
+        *patterns (list): List of token descriptions.
         """
         for pattern in patterns:
             if len(pattern) == 0:
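For orientation, the `add` call documented in this hunk is the usual way rules are registered on a `Matcher`. A minimal sketch against this 2.x-era API (the rule name and token texts are made up for the example):

    from spacy.matcher import Matcher
    from spacy.tokens import Doc
    from spacy.vocab import Vocab

    vocab = Vocab()
    matcher = Matcher(vocab)
    # key, on_match callback (None here), then one or more token-description lists
    matcher.add('HelloWorld', None, [{'ORTH': 'Hello'}, {'ORTH': 'World'}])
    doc = Doc(vocab, words=['Hello', 'World'])
    matches = matcher(doc)  # each match is a (match_id, start, end) tuple
    assert len(matches) == 1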
@@ -611,6 +615,7 @@ cdef class PhraseMatcher:
         self.phrase_ids.set(phrase_hash, <void*>ent_id)

+
     def __call__(self, Doc doc):
         """Find all sequences matching the supplied patterns on the `Doc`.

         doc (Doc): The document to match over.
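As context for the hunk above, `PhraseMatcher.__call__` yields `(match_id, start, end)` tuples over the Doc, just like `Matcher`. A small sketch of typical usage, mirroring the 'COMPANY' example in the tests added by this commit (assuming a spaCy 2.x-style install):

    from spacy.matcher import PhraseMatcher
    from spacy.tokens import Doc
    from spacy.vocab import Vocab

    vocab = Vocab()
    matcher = PhraseMatcher(vocab)
    # the pattern is itself a Doc; matching is over exact token sequences
    matcher.add('COMPANY', None, Doc(vocab, words=['Google', 'Now']))
    doc = Doc(vocab, words=['I', 'like', 'Google', 'Now', 'best'])
    matches = matcher(doc)  # expected: one match spanning tokens 2-4
    assert len(matches) == 1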
@@ -673,3 +678,243 @@ cdef class PhraseMatcher:
             return None
         else:
             return ent_id
+
+
+cdef class DependencyTreeMatcher:
+    """Match dependency parse tree based on pattern rules."""
+    cdef Pool mem
+    cdef readonly Vocab vocab
+    cdef readonly Matcher token_matcher
+    cdef public object _patterns
+    cdef public object _keys_to_token
+    cdef public object _root
+    cdef public object _entities
+    cdef public object _callbacks
+    cdef public object _nodes
+    cdef public object _tree
+
+    def __init__(self, vocab):
+        """Create the DependencyTreeMatcher.
+
+        vocab (Vocab): The vocabulary object, which must be shared with the
+            documents the matcher will operate on.
+        RETURNS (DependencyTreeMatcher): The newly constructed object.
+        """
+        size = 20
+        self.token_matcher = Matcher(vocab)
+        self._keys_to_token = {}
+        self._patterns = {}
+        self._root = {}
+        self._nodes = {}
+        self._tree = {}
+        self._entities = {}
+        self._callbacks = {}
+        self.vocab = vocab
+        self.mem = Pool()
+
+    def __reduce__(self):
+        data = (self.vocab, self._patterns,self._tree, self._callbacks)
+        return (unpickle_matcher, data, None, None)
+
+    def __len__(self):
+        """Get the number of rules, which are edges ,added to the dependency tree matcher.
+
+        RETURNS (int): The number of rules.
+        """
+        return len(self._patterns)
+
+    def __contains__(self, key):
+        """Check whether the matcher contains rules for a match ID.
+
+        key (unicode): The match ID.
+        RETURNS (bool): Whether the matcher contains rules for this match ID.
+        """
+        return self._normalize_key(key) in self._patterns
+
+
+    def add(self, key, on_match, *patterns):
+        # TODO : validations
+        # 1. check if input pattern is connected
+        # 2. check if pattern format is correct
+        # 3. check if atleast one root node is present
+        # 4. check if node names are not repeated
+        # 5. check if each node has only one head
+
+        for pattern in patterns:
+            if len(pattern) == 0:
+                raise ValueError(Errors.E012.format(key=key))
+
+        key = self._normalize_key(key)
+
+        _patterns = []
+        for pattern in patterns:
+            token_patterns = []
+            for i in range(len(pattern)):
+                token_pattern = [pattern[i]['PATTERN']]
+                token_patterns.append(token_pattern)
+            # self.patterns.append(token_patterns)
+            _patterns.append(token_patterns)
+
+        self._patterns.setdefault(key, [])
+        self._callbacks[key] = on_match
+        self._patterns[key].extend(_patterns)
+
+        # Add each node pattern of all the input patterns individually to the matcher.
+        # This enables only a single instance of Matcher to be used.
+        # Multiple adds are required to track each node pattern.
+        _keys_to_token_list = []
+        for i in range(len(_patterns)):
+            _keys_to_token = {}
+            # TODO : Better ways to hash edges in pattern?
+            for j in range(len(_patterns[i])):
+                k = self._normalize_key(unicode(key)+DELIMITER+unicode(i)+DELIMITER+unicode(j))
+                self.token_matcher.add(k,None,_patterns[i][j])
+                _keys_to_token[k] = j
+            _keys_to_token_list.append(_keys_to_token)
+
+        self._keys_to_token.setdefault(key, [])
+        self._keys_to_token[key].extend(_keys_to_token_list)
+
+        _nodes_list = []
+        for pattern in patterns:
+            nodes = {}
+            for i in range(len(pattern)):
+                nodes[pattern[i]['SPEC']['NODE_NAME']]=i
+            _nodes_list.append(nodes)
+
+        self._nodes.setdefault(key, [])
+        self._nodes[key].extend(_nodes_list)
+
+        # Create an object tree to traverse later on.
+        # This datastructure enable easy tree pattern match.
+        # Doc-Token based tree cannot be reused since it is memory heavy and
+        # tightly coupled with doc
+        self.retrieve_tree(patterns,_nodes_list,key)
+
+    def retrieve_tree(self,patterns,_nodes_list,key):
+
+        _heads_list = []
+        _root_list = []
+        for i in range(len(patterns)):
+            heads = {}
+            root = -1
+            for j in range(len(patterns[i])):
+                token_pattern = patterns[i][j]
+                if('NBOR_RELOP' not in token_pattern['SPEC']):
+                    heads[j] = j
+                    root = j
+                else:
+                    # TODO: Add semgrex rules
+                    # 1. >
+                    if(token_pattern['SPEC']['NBOR_RELOP'] == '>'):
+                        heads[j] = _nodes_list[i][token_pattern['SPEC']['NBOR_NAME']]
+                    # 2. <
+                    if(token_pattern['SPEC']['NBOR_RELOP'] == '<'):
+                        heads[_nodes_list[i][token_pattern['SPEC']['NBOR_NAME']]] = j
+
+            _heads_list.append(heads)
+            _root_list.append(root)
+
+        _tree_list = []
+        for i in range(len(patterns)):
+            tree = {}
+            for j in range(len(patterns[i])):
+                if(j == _heads_list[i][j]):
+                    continue
+                head = _heads_list[i][j]
+                if(head not in tree):
+                    tree[head] = []
+                tree[head].append(j)
+            _tree_list.append(tree)
+
+        self._tree.setdefault(key, [])
+        self._tree[key].extend(_tree_list)
+
+        self._root.setdefault(key, [])
+        self._root[key].extend(_root_list)
+
+    def has_key(self, key):
+        """Check whether the matcher has a rule with a given key.
+
+        key (string or int): The key to check.
+        RETURNS (bool): Whether the matcher has the rule.
+        """
+        key = self._normalize_key(key)
+        return key in self._patterns
+
+    def get(self, key, default=None):
+        """Retrieve the pattern stored for a key.
+
+        key (unicode or int): The key to retrieve.
+        RETURNS (tuple): The rule, as an (on_match, patterns) tuple.
+        """
+        key = self._normalize_key(key)
+        if key not in self._patterns:
+            return default
+        return (self._callbacks[key], self._patterns[key])
+
+    def __call__(self, Doc doc):
+        matched_trees = []
+
+        matches = self.token_matcher(doc)
+        for key in list(self._patterns.keys()):
+            _patterns_list = self._patterns[key]
+            _keys_to_token_list = self._keys_to_token[key]
+            _root_list = self._root[key]
+            _tree_list = self._tree[key]
+            _nodes_list = self._nodes[key]
+            length = len(_patterns_list)
+            for i in range(length):
+                _keys_to_token = _keys_to_token_list[i]
+                _root = _root_list[i]
+                _tree = _tree_list[i]
+                _nodes = _nodes_list[i]
+
+                id_to_position = {}
+
+                # This could be taken outside to improve running time..?
+                for match_id, start, end in matches:
+                    if match_id in _keys_to_token:
+                        if _keys_to_token[match_id] not in id_to_position:
+                            id_to_position[_keys_to_token[match_id]] = []
+                        id_to_position[_keys_to_token[match_id]].append(start)
+
+                length = len(_nodes)
+                if _root in id_to_position:
+                    candidates = id_to_position[_root]
+                    for candidate in candidates:
+                        isVisited = {}
+                        self.dfs(candidate,_root,_tree,id_to_position,doc,isVisited)
+                        # to check if the subtree pattern is completely identified
+                        if(len(isVisited) == length):
+                            matched_trees.append((key,list(isVisited)))
+
+        for i, (ent_id, nodes) in enumerate(matched_trees):
+            on_match = self._callbacks.get(ent_id)
+            if on_match is not None:
+                on_match(self, doc, i, matches)
+
+        return matched_trees
+
+    def dfs(self,candidate,root,tree,id_to_position,doc,isVisited):
+        if(root in id_to_position and candidate in id_to_position[root]):
+            # color the node since it is valid
+            isVisited[candidate] = True
+            candidate_children = doc[candidate].children
+            for candidate_child in candidate_children:
+                if root in tree:
+                    for root_child in tree[root]:
+                        self.dfs(
+                            candidate_child.i,
+                            root_child,
+                            tree,
+                            id_to_position,
+                            doc,
+                            isVisited
+                        )
+
+    def _normalize_key(self, key):
+        if isinstance(key, basestring):
+            return self.vocab.strings.add(key)
+        else:
+            return key
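To make the new SPEC/PATTERN format easier to follow, here is a small usage sketch distilled from the test fixture added later in this commit. It assumes `doc` is a Doc that already carries a dependency parse in which a token "fox" governs a child "quick"; the rule name is illustrative. Each entry names a node via 'NODE_NAME', optionally attaches it to a previously declared node ('NBOR_RELOP' '>' makes the named neighbour the head of the new node, '<' the reverse, per retrieve_tree above), and supplies an ordinary token pattern under 'PATTERN':

    from spacy.matcher import DependencyTreeMatcher

    pattern = [
        {'SPEC': {'NODE_NAME': 'fox'}, 'PATTERN': {'ORTH': 'fox'}},
        {'SPEC': {'NODE_NAME': 'q', 'NBOR_RELOP': '>', 'NBOR_NAME': 'fox'},
         'PATTERN': {'LOWER': 'quick'}},
    ]

    matcher = DependencyTreeMatcher(doc.vocab)
    matcher.add('QUICK_FOX', None, pattern)
    # each result pairs the rule's match ID with the matched token indices
    matches = matcher(doc)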
@@ -2,8 +2,10 @@
 from __future__ import unicode_literals

 import pytest
-from spacy.matcher import Matcher
+import re
+from spacy.matcher import Matcher, DependencyTreeMatcher
 from spacy.tokens import Doc
+from ..util import get_doc


 @pytest.fixture
@@ -166,3 +168,47 @@ def test_matcher_any_token_operator(en_vocab):
     assert matches[0] == 'test'
     assert matches[1] == 'test hello'
     assert matches[2] == 'test hello world'
+
+
+@pytest.fixture
+def text():
+    return u"The quick brown fox jumped over the lazy fox"
+
+@pytest.fixture
+def heads():
+    return [3,2,1,1,0,-1,2,1,-3]
+
+@pytest.fixture
+def deps():
+    return ['det', 'amod', 'amod', 'nsubj', 'prep', 'pobj', 'det', 'amod']
+
+@pytest.fixture
+def dependency_tree_matcher(en_vocab):
+    is_brown_yellow = lambda text: bool(re.compile(r'brown|yellow|over').match(text))
+    IS_BROWN_YELLOW = en_vocab.add_flag(is_brown_yellow)
+    pattern1 = [
+        {'SPEC': {'NODE_NAME': 'fox'}, 'PATTERN': {'ORTH': 'fox'}},
+        {'SPEC': {'NODE_NAME': 'q', 'NBOR_RELOP': '>', 'NBOR_NAME': 'fox'},'PATTERN': {'LOWER': u'quick'}},
+        {'SPEC': {'NODE_NAME': 'r', 'NBOR_RELOP': '>', 'NBOR_NAME': 'fox'}, 'PATTERN': {IS_BROWN_YELLOW: True}}
+    ]
+
+    pattern2 = [
+        {'SPEC': {'NODE_NAME': 'jumped'}, 'PATTERN': {'ORTH': 'jumped'}},
+        {'SPEC': {'NODE_NAME': 'fox', 'NBOR_RELOP': '>', 'NBOR_NAME': 'jumped'},'PATTERN': {'LOWER': u'fox'}},
+        {'SPEC': {'NODE_NAME': 'over', 'NBOR_RELOP': '>', 'NBOR_NAME': 'fox'}, 'PATTERN': {IS_BROWN_YELLOW: True}}
+    ]
+    matcher = DependencyTreeMatcher(en_vocab)
+    matcher.add('pattern1', None, pattern1)
+    matcher.add('pattern2', None, pattern2)
+    return matcher
+
+
+
+def test_dependency_tree_matcher_compile(dependency_tree_matcher):
+    assert len(dependency_tree_matcher) == 2
+
+def test_dependency_tree_matcher(dependency_tree_matcher,text,heads,deps):
+    doc = get_doc(dependency_tree_matcher.vocab,text.split(),heads=heads,deps=deps)
+    matches = dependency_tree_matcher(doc)
+    assert len(matches) == 2
+
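The heads fixture above follows the relative-offset convention used by the test helper get_doc: each entry is the offset from a token to its head, so the absolute head index is i + heads[i], and a token whose head is itself is the root. A quick illustrative check of that arithmetic for the fixture sentence (not part of the commit):

    words = "The quick brown fox jumped over the lazy fox".split()
    heads = [3, 2, 1, 1, 0, -1, 2, 1, -3]
    # absolute head index of each token
    absolute_heads = [i + offset for i, offset in enumerate(heads)]
    # -> [3, 3, 3, 4, 4, 4, 8, 8, 5]
    # "quick" (1) and "brown" (2) attach to the first "fox" (3), which attaches
    # to "jumped" (4), the root; the second "fox" (8) attaches to "over" (5).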
spacy/tests/test_matcher.py  (new file, 263 lines)
@@ -0,0 +1,263 @@
+# coding: utf-8
+from __future__ import unicode_literals
+
+from numpy import sort
+
+from ..matcher import Matcher, PhraseMatcher, DependencyTreeMatcher
+from .util import get_doc
+from ..tokens import Doc
+
+import pytest
+import re
+
+@pytest.fixture
+def matcher(en_vocab):
+    rules = {
+        'JS': [[{'ORTH': 'JavaScript'}]],
+        'GoogleNow': [[{'ORTH': 'Google'}, {'ORTH': 'Now'}]],
+        'Java': [[{'LOWER': 'java'}]]
+    }
+    matcher = Matcher(en_vocab)
+    for key, patterns in rules.items():
+        matcher.add(key, None, *patterns)
+    return matcher
+
+def test_matcher_from_api_docs(en_vocab):
+    matcher = Matcher(en_vocab)
+    pattern = [{'ORTH': 'test'}]
+    assert len(matcher) == 0
+    matcher.add('Rule', None, pattern)
+    assert len(matcher) == 1
+    matcher.remove('Rule')
+    assert 'Rule' not in matcher
+    matcher.add('Rule', None, pattern)
+    assert 'Rule' in matcher
+    on_match, patterns = matcher.get('Rule')
+    assert len(patterns[0])
+
+
+def test_matcher_from_usage_docs(en_vocab):
+    text = "Wow 😀 This is really cool! 😂 😂"
+    doc = get_doc(en_vocab, words=text.split(' '))
+    pos_emoji = [u'😀', u'😃', u'😂', u'🤣', u'😊', u'😍']
+    pos_patterns = [[{'ORTH': emoji}] for emoji in pos_emoji]
+
+    def label_sentiment(matcher, doc, i, matches):
+        match_id, start, end = matches[i]
+        if doc.vocab.strings[match_id] == 'HAPPY':
+            doc.sentiment += 0.1
+        span = doc[start : end]
+        token = span.merge()
+        token.vocab[token.text].norm_ = 'happy emoji'
+
+    matcher = Matcher(en_vocab)
+    matcher.add('HAPPY', label_sentiment, *pos_patterns)
+    matches = matcher(doc)
+    assert doc.sentiment != 0
+    assert doc[1].norm_ == 'happy emoji'
+
+
+@pytest.mark.parametrize('words', [["Some", "words"]])
+def test_matcher_init(en_vocab, words):
+    matcher = Matcher(en_vocab)
+    doc = get_doc(en_vocab, words)
+    assert len(matcher) == 0
+    assert matcher(doc) == []
+
+
+def test_matcher_contains(matcher):
+    matcher.add('TEST', None, [{'ORTH': 'test'}])
+    assert 'TEST' in matcher
+    assert 'TEST2' not in matcher
+
+
+def test_matcher_no_match(matcher):
+    words = ["I", "like", "cheese", "."]
+    doc = get_doc(matcher.vocab, words)
+    assert matcher(doc) == []
+
+
+def test_matcher_compile(matcher):
+    assert len(matcher) == 3
+
+
+def test_matcher_match_start(matcher):
+    words = ["JavaScript", "is", "good"]
+    doc = get_doc(matcher.vocab, words)
+    assert matcher(doc) == [(matcher.vocab.strings['JS'], 0, 1)]
+
+
+def test_matcher_match_end(matcher):
+    words = ["I", "like", "java"]
+    doc = get_doc(matcher.vocab, words)
+    assert matcher(doc) == [(doc.vocab.strings['Java'], 2, 3)]
+
+
+def test_matcher_match_middle(matcher):
+    words = ["I", "like", "Google", "Now", "best"]
+    doc = get_doc(matcher.vocab, words)
+    assert matcher(doc) == [(doc.vocab.strings['GoogleNow'], 2, 4)]
+
+
+def test_matcher_match_multi(matcher):
+    words = ["I", "like", "Google", "Now", "and", "java", "best"]
+    doc = get_doc(matcher.vocab, words)
+    assert matcher(doc) == [(doc.vocab.strings['GoogleNow'], 2, 4),
+                            (doc.vocab.strings['Java'], 5, 6)]
+
+
+def test_matcher_empty_dict(en_vocab):
+    '''Test matcher allows empty token specs, meaning match on any token.'''
+    matcher = Matcher(en_vocab)
+    abc = ["a", "b", "c"]
+    doc = get_doc(matcher.vocab, abc)
+    matcher.add('A.C', None, [{'ORTH': 'a'}, {}, {'ORTH': 'c'}])
+    matches = matcher(doc)
+    assert len(matches) == 1
+    assert matches[0][1:] == (0, 3)
+    matcher = Matcher(en_vocab)
+    matcher.add('A.', None, [{'ORTH': 'a'}, {}])
+    matches = matcher(doc)
+    assert matches[0][1:] == (0, 2)
+
+
+def test_matcher_operator_shadow(en_vocab):
+    matcher = Matcher(en_vocab)
+    abc = ["a", "b", "c"]
+    doc = get_doc(matcher.vocab, abc)
+    matcher.add('A.C', None, [{'ORTH': 'a'},
+                              {"IS_ALPHA": True, "OP": "+"},
+                              {'ORTH': 'c'}])
+    matches = matcher(doc)
+    assert len(matches) == 1
+    assert matches[0][1:] == (0, 3)
+
+
+def test_matcher_phrase_matcher(en_vocab):
+    words = ["Google", "Now"]
+    doc = get_doc(en_vocab, words)
+    matcher = PhraseMatcher(en_vocab)
+    matcher.add('COMPANY', None, doc)
+    words = ["I", "like", "Google", "Now", "best"]
+    doc = get_doc(en_vocab, words)
+    assert len(matcher(doc)) == 1
+
+
+def test_phrase_matcher_length(en_vocab):
+    matcher = PhraseMatcher(en_vocab)
+    assert len(matcher) == 0
+    matcher.add('TEST', None, get_doc(en_vocab, ['test']))
+    assert len(matcher) == 1
+    matcher.add('TEST2', None, get_doc(en_vocab, ['test2']))
+    assert len(matcher) == 2
+
+
+def test_phrase_matcher_contains(en_vocab):
+    matcher = PhraseMatcher(en_vocab)
+    matcher.add('TEST', None, get_doc(en_vocab, ['test']))
+    assert 'TEST' in matcher
+    assert 'TEST2' not in matcher
+
+
+def test_matcher_match_zero(matcher):
+    words1 = 'He said , " some words " ...'.split()
+    words2 = 'He said , " some three words " ...'.split()
+    pattern1 = [{'ORTH': '"'},
+                {'OP': '!', 'IS_PUNCT': True},
+                {'OP': '!', 'IS_PUNCT': True},
+                {'ORTH': '"'}]
+    pattern2 = [{'ORTH': '"'},
+                {'IS_PUNCT': True},
+                {'IS_PUNCT': True},
+                {'IS_PUNCT': True},
+                {'ORTH': '"'}]
+
+    matcher.add('Quote', None, pattern1)
+    doc = get_doc(matcher.vocab, words1)
+    assert len(matcher(doc)) == 1
+
+    doc = get_doc(matcher.vocab, words2)
+    assert len(matcher(doc)) == 0
+    matcher.add('Quote', None, pattern2)
+    assert len(matcher(doc)) == 0
+
+
+def test_matcher_match_zero_plus(matcher):
+    words = 'He said , " some words " ...'.split()
+    pattern = [{'ORTH': '"'},
+               {'OP': '*', 'IS_PUNCT': False},
+               {'ORTH': '"'}]
+    matcher.add('Quote', None, pattern)
+    doc = get_doc(matcher.vocab, words)
+    assert len(matcher(doc)) == 1
+
+
+def test_matcher_match_one_plus(matcher):
+    control = Matcher(matcher.vocab)
+    control.add('BasicPhilippe', None, [{'ORTH': 'Philippe'}])
+    doc = get_doc(control.vocab, ['Philippe', 'Philippe'])
+    m = control(doc)
+    assert len(m) == 2
+    matcher.add('KleenePhilippe', None, [{'ORTH': 'Philippe', 'OP': '1'},
+                                         {'ORTH': 'Philippe', 'OP': '+'}])
+    m = matcher(doc)
+    assert len(m) == 1
+
+
+def test_operator_combos(matcher):
+    cases = [
+        ('aaab', 'a a a b', True),
+        ('aaab', 'a+ b', True),
+        ('aaab', 'a+ a+ b', True),
+        ('aaab', 'a+ a+ a b', True),
+        ('aaab', 'a+ a+ a+ b', True),
+        ('aaab', 'a+ a a b', True),
+        ('aaab', 'a+ a a', True),
+        ('aaab', 'a+', True),
+        ('aaa', 'a+ b', False),
+        ('aaa', 'a+ a+ b', False),
+        ('aaa', 'a+ a+ a+ b', False),
+        ('aaa', 'a+ a b', False),
+        ('aaa', 'a+ a a b', False),
+        ('aaab', 'a+ a a', True),
+        ('aaab', 'a+', True),
+        ('aaab', 'a+ a b', True),
+    ]
+    for string, pattern_str, result in cases:
+        matcher = Matcher(matcher.vocab)
+        doc = get_doc(matcher.vocab, words=list(string))
+        pattern = []
+        for part in pattern_str.split():
+            if part.endswith('+'):
+                pattern.append({'ORTH': part[0], 'op': '+'})
+            else:
+                pattern.append({'ORTH': part})
+        matcher.add('PATTERN', None, pattern)
+        matches = matcher(doc)
+        if result:
+            assert matches, (string, pattern_str)
+        else:
+            assert not matches, (string, pattern_str)
+
+
+def test_matcher_end_zero_plus(matcher):
+    '''Test matcher works when patterns end with * operator. (issue 1450)'''
+    matcher = Matcher(matcher.vocab)
+    matcher.add(
+        "TSTEND",
+        None,
+        [
+            {'ORTH': "a"},
+            {'ORTH': "b", 'OP': "*"}
+        ]
+    )
+    nlp = lambda string: Doc(matcher.vocab, words=string.split())
+    assert len(matcher(nlp(u'a'))) == 1
+    assert len(matcher(nlp(u'a b'))) == 1
+    assert len(matcher(nlp(u'a b'))) == 1
+    assert len(matcher(nlp(u'a c'))) == 1
+    assert len(matcher(nlp(u'a b c'))) == 1
+    assert len(matcher(nlp(u'a b b c'))) == 1
+    assert len(matcher(nlp(u'a b b'))) == 1
+
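The quantifier tests in this new file exercise the OP key on token descriptions; in this era of the API, '!' requires zero occurrences of the token, '?' zero or one, '+' one or more, and '*' zero or more. A minimal standalone sketch of '+' (the rule name and token texts are made up; assumes a spaCy 2.x-style install):

    from spacy.matcher import Matcher
    from spacy.tokens import Doc
    from spacy.vocab import Vocab

    vocab = Vocab()
    matcher = Matcher(vocab)
    # 'a' one or more times, followed by a literal 'b'
    matcher.add('A_PLUS_B', None, [{'ORTH': 'a', 'OP': '+'}, {'ORTH': 'b'}])
    doc = Doc(vocab, words=['a', 'a', 'a', 'b'])
    matches = matcher(doc)
    assert matches                                  # at least one match is found
    assert all(end == 4 for _, _, end in matches)   # every match ends on the 'b'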