mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
💫 Update matcher engine for regex and extensions (#3173)
* Update matcher engine for regex and extensions Add support for matching over arbitrary Python predicate functions, and arbitrary Python attribute getters. This will allow matching over regex patterns, and allow supporting extension attributes. The results of the Python predicate functions are cached, so that we don't call the same predicate function twice for the same token. The extension attributes are fetched into an array for each token in the doc. This should minimise the performance impact of the new features. We still need to wire up these features to the patterns, and test it all. * Work on wiring up extra attributes in matcher * Work on tests for extra matcher attrs * Add support for extension attrs to matcher * Test extension attribute matching * Work on implementing predicate-based match patterns * Get predicates working for set membership * Add test for set membership * Make extensions+predicates work * Test matcher extensions * Cache predicate results better in Matcher * Remove print statement in matcher test * Use srsly to get key for predicates
This commit is contained in:
parent
f407954b27
commit
77ddcf7381
|
@ -1,6 +1,8 @@
|
|||
# cython: infer_types=True
|
||||
# cython: profile=True
|
||||
from __future__ import unicode_literals
|
||||
import re
|
||||
import srsly
|
||||
from libcpp.vector cimport vector
|
||||
from libc.stdint cimport int32_t, uint64_t, uint16_t
|
||||
from preshed.maps cimport PreshMap
|
||||
|
@ -11,9 +13,11 @@ from .structs cimport TokenC
|
|||
from .lexeme cimport attr_id_t
|
||||
from .vocab cimport Vocab
|
||||
from .tokens.doc cimport Doc
|
||||
from .tokens.token cimport Token
|
||||
from .tokens.doc cimport get_token_attr
|
||||
from .attrs cimport ID, attr_id_t, NULL_ATTR, ORTH
|
||||
from .errors import Errors, TempErrors, Warnings, deprecation_warning
|
||||
from .strings import get_string_id
|
||||
|
||||
from .attrs import IDS
|
||||
from .attrs import FLAG61 as U_ENT
|
||||
|
@ -56,9 +60,17 @@ cdef struct AttrValueC:
|
|||
attr_id_t attr
|
||||
attr_t value
|
||||
|
||||
cdef struct IndexValueC:
|
||||
int32_t index
|
||||
attr_t value
|
||||
|
||||
cdef struct TokenPatternC:
|
||||
AttrValueC* attrs
|
||||
int32_t* py_predicates
|
||||
IndexValueC* extra_attrs
|
||||
int32_t nr_attr
|
||||
int32_t nr_extra_attr
|
||||
int32_t nr_py
|
||||
quantifier_t quantifier
|
||||
hash_t key
|
||||
|
||||
|
@ -75,19 +87,46 @@ cdef struct MatchC:
|
|||
int32_t length
|
||||
|
||||
|
||||
cdef find_matches(TokenPatternC** patterns, int n, Doc doc):
|
||||
cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None,
|
||||
predicates=tuple()):
|
||||
'''Find matches in a doc, with a compiled array of patterns. Matches are
|
||||
returned as a list of (id, start, end) tuples.
|
||||
|
||||
To augment the compiled patterns, we optionally also take two Python lists.
|
||||
|
||||
The "predicates" list contains functions that take a Python list and return a
|
||||
boolean value. It's mostly used for regular expressions.
|
||||
|
||||
The "extra_getters" list contains functions that take a Python list and return
|
||||
an attr ID. It's mostly used for extension attributes.
|
||||
'''
|
||||
cdef vector[PatternStateC] states
|
||||
cdef vector[MatchC] matches
|
||||
cdef PatternStateC state
|
||||
cdef int i, j, nr_extra_attr
|
||||
cdef Pool mem = Pool()
|
||||
# TODO: Prefill this with the extra attribute values.
|
||||
extra_attrs = <attr_t**>mem.alloc(len(doc), sizeof(attr_t*))
|
||||
predicate_cache = <char*>mem.alloc(doc.length * len(predicates), sizeof(char))
|
||||
if extensions is not None and len(extensions) >= 1:
|
||||
nr_extra_attr = max(extensions.values())
|
||||
extra_attr_values = <attr_t*>mem.alloc(doc.length * nr_extra_attr, sizeof(attr_t))
|
||||
else:
|
||||
nr_extra_attr = 0
|
||||
extra_attr_values = <attr_t*>mem.alloc(doc.length, sizeof(attr_t))
|
||||
for i, token in enumerate(doc):
|
||||
for name, index in extensions.items():
|
||||
value = token._.get(name)
|
||||
if isinstance(value, basestring):
|
||||
value = token.vocab.strings[value]
|
||||
extra_attr_values[i * nr_extra_attr + index] = value
|
||||
# Main loop
|
||||
cdef int i, j
|
||||
cdef int nr_predicate = len(predicates)
|
||||
for i in range(doc.length):
|
||||
for j in range(n):
|
||||
states.push_back(PatternStateC(patterns[j], i, 0))
|
||||
transition_states(states, matches, &doc.c[i], extra_attrs[i])
|
||||
transition_states(states, matches, predicate_cache,
|
||||
doc[i], extra_attr_values, predicates)
|
||||
predicate_cache += nr_predicate
|
||||
extra_attr_values += nr_extra_attr
|
||||
# Handle matches that end in 0-width patterns
|
||||
finish_states(matches, states)
|
||||
output = []
|
||||
|
@ -119,11 +158,17 @@ cdef attr_t get_ent_id(const TokenPatternC* pattern) nogil:
|
|||
|
||||
|
||||
cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& matches,
|
||||
const TokenC* token, const attr_t* extra_attrs) except *:
|
||||
char* cached_py_predicates,
|
||||
Token token, const attr_t* extra_attrs, py_predicates) except *:
|
||||
cdef int q = 0
|
||||
cdef vector[PatternStateC] new_states
|
||||
cdef int nr_predicate = len(py_predicates)
|
||||
for i in range(states.size()):
|
||||
action = get_action(states[i], token, extra_attrs)
|
||||
if states[i].pattern.nr_py != 0:
|
||||
update_predicate_cache(cached_py_predicates,
|
||||
states[i].pattern, token, py_predicates)
|
||||
action = get_action(states[i], token.c, extra_attrs,
|
||||
cached_py_predicates, nr_predicate)
|
||||
if action == REJECT:
|
||||
continue
|
||||
state = states[i]
|
||||
|
@ -140,7 +185,11 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
|
|||
PatternStateC(pattern=state.pattern+1, start=state.start,
|
||||
length=state.length+1))
|
||||
states[q].pattern += 1
|
||||
action = get_action(states[q], token, extra_attrs)
|
||||
if states[q].pattern.nr_py != 0:
|
||||
update_predicate_cache(cached_py_predicates,
|
||||
states[q].pattern, token, py_predicates)
|
||||
action = get_action(states[q], token.c, extra_attrs,
|
||||
cached_py_predicates, nr_predicate)
|
||||
if action == REJECT:
|
||||
pass
|
||||
elif action == ADVANCE:
|
||||
|
@ -168,6 +217,26 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
|
|||
states.push_back(new_states[i])
|
||||
|
||||
|
||||
cdef void update_predicate_cache(char* cache,
|
||||
const TokenPatternC* pattern, Token token, predicates):
|
||||
# If the state references any extra predicates, check whether they match.
|
||||
# These are cached, so that we don't call these potentially expensive
|
||||
# Python functions more than we need to.
|
||||
for i in range(pattern.nr_py):
|
||||
index = pattern.py_predicates[i]
|
||||
if cache[index] == 0:
|
||||
predicate = predicates[index]
|
||||
result = predicate(token)
|
||||
if result is True:
|
||||
cache[index] = 1
|
||||
elif result is False:
|
||||
cache[index] = -1
|
||||
elif result is None:
|
||||
pass
|
||||
else:
|
||||
raise ValueError("Unexpected value: %s" % result)
|
||||
|
||||
|
||||
cdef void finish_states(vector[MatchC]& matches, vector[PatternStateC]& states) except *:
|
||||
'''Handle states that end in zero-width patterns.'''
|
||||
cdef PatternStateC state
|
||||
|
@ -184,7 +253,9 @@ cdef void finish_states(vector[MatchC]& matches, vector[PatternStateC]& states)
|
|||
state.pattern += 1
|
||||
|
||||
|
||||
cdef action_t get_action(PatternStateC state, const TokenC* token, const attr_t* extra_attrs) nogil:
|
||||
cdef action_t get_action(PatternStateC state,
|
||||
const TokenC* token, const attr_t* extra_attrs,
|
||||
const char* predicate_matches, int nr_predicate) nogil:
|
||||
'''We need to consider:
|
||||
|
||||
a) Does the token match the specification? [Yes, No]
|
||||
|
@ -244,7 +315,7 @@ cdef action_t get_action(PatternStateC state, const TokenC* token, const attr_t*
|
|||
Problem: If a quantifier is matching, we're adding a lot of open partials
|
||||
'''
|
||||
cdef char is_match
|
||||
is_match = get_is_match(state, token, extra_attrs)
|
||||
is_match = get_is_match(state, token, extra_attrs, predicate_matches, nr_predicate)
|
||||
quantifier = get_quantifier(state)
|
||||
is_final = get_is_final(state)
|
||||
if quantifier == ZERO:
|
||||
|
@ -293,13 +364,20 @@ cdef action_t get_action(PatternStateC state, const TokenC* token, const attr_t*
|
|||
return RETRY
|
||||
|
||||
|
||||
cdef char get_is_match(PatternStateC state, const TokenC* token, const attr_t* extra_attrs) nogil:
|
||||
cdef char get_is_match(PatternStateC state,
|
||||
const TokenC* token, const attr_t* extra_attrs,
|
||||
const char* predicate_matches, int nr_predicate) nogil:
|
||||
for i in range(nr_predicate):
|
||||
if predicate_matches[i] == -1:
|
||||
return 0
|
||||
spec = state.pattern
|
||||
for attr in spec.attrs[:spec.nr_attr]:
|
||||
if get_token_attr(token, attr.attr) != attr.value:
|
||||
return 0
|
||||
else:
|
||||
return 1
|
||||
for i in range(spec.nr_extra_attr):
|
||||
if spec.extra_attrs[i].value != extra_attrs[spec.extra_attrs[i].index]:
|
||||
return 0
|
||||
return True
|
||||
|
||||
|
||||
cdef char get_is_final(PatternStateC state) nogil:
|
||||
|
@ -316,17 +394,25 @@ cdef char get_quantifier(PatternStateC state) nogil:
|
|||
DEF PADDING = 5
|
||||
|
||||
|
||||
cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id,
|
||||
object token_specs) except NULL:
|
||||
cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs) except NULL:
|
||||
pattern = <TokenPatternC*>mem.alloc(len(token_specs) + 1, sizeof(TokenPatternC))
|
||||
cdef int i
|
||||
for i, (quantifier, spec) in enumerate(token_specs):
|
||||
for i, (quantifier, spec, extensions, predicates) in enumerate(token_specs):
|
||||
pattern[i].quantifier = quantifier
|
||||
pattern[i].attrs = <AttrValueC*>mem.alloc(len(spec), sizeof(AttrValueC))
|
||||
pattern[i].nr_attr = len(spec)
|
||||
for j, (attr, value) in enumerate(spec):
|
||||
pattern[i].attrs[j].attr = attr
|
||||
pattern[i].attrs[j].value = value
|
||||
pattern[i].extra_attrs = <IndexValueC*>mem.alloc(len(extensions), sizeof(IndexValueC))
|
||||
for j, (index, value) in enumerate(extensions):
|
||||
pattern[i].extra_attrs[j].index = index
|
||||
pattern[i].extra_attrs[j].value = value
|
||||
pattern[i].nr_extra_attr = len(extensions)
|
||||
pattern[i].py_predicates = <int32_t*>mem.alloc(len(predicates), sizeof(int32_t))
|
||||
for j, index in enumerate(predicates):
|
||||
pattern[i].py_predicates[j] = index
|
||||
pattern[i].nr_py = len(predicates)
|
||||
pattern[i].key = hash64(pattern[i].attrs, pattern[i].nr_attr * sizeof(AttrValueC), 0)
|
||||
i = len(token_specs)
|
||||
pattern[i].attrs = <AttrValueC*>mem.alloc(2, sizeof(AttrValueC))
|
||||
|
@ -345,39 +431,214 @@ cdef attr_t get_pattern_key(const TokenPatternC* pattern) nogil:
|
|||
raise ValueError(Errors.E074.format(attr=ID, bad_attr=id_attr.attr))
|
||||
return id_attr.value
|
||||
|
||||
def _convert_strings(token_specs, string_store):
|
||||
# Support 'syntactic sugar' operator '+', as combination of ONE, ZERO_PLUS
|
||||
operators = {'*': (ZERO_PLUS,), '+': (ONE, ZERO_PLUS),
|
||||
'?': (ZERO_ONE,), '1': (ONE,), '!': (ZERO,)}
|
||||
|
||||
def _preprocess_pattern(token_specs, string_store, extensions_table, extra_predicates):
|
||||
"""This function interprets the pattern, converting the various bits of
|
||||
syntactic sugar before we compile it into a struct with init_pattern.
|
||||
|
||||
We need to split the pattern up into three parts:
|
||||
* Normal attribute/value pairs, which are stored on either the token or lexeme,
|
||||
can be handled directly.
|
||||
* Extension attributes are handled specially, as we need to prefetch the
|
||||
values from Python for the doc before we begin matching.
|
||||
* Extra predicates also call Python functions, so we have to create the
|
||||
functions and store them. So we store these specially as well.
|
||||
* Extension attributes that have extra predicates are stored within the
|
||||
extra_predicates.
|
||||
"""
|
||||
tokens = []
|
||||
op = ONE
|
||||
seen_predicates = {}
|
||||
for spec in token_specs:
|
||||
if not spec:
|
||||
# Signifier for 'any token'
|
||||
tokens.append((ONE, [(NULL_ATTR, 0)]))
|
||||
tokens.append((ONE, [(NULL_ATTR, 0)], [], []))
|
||||
continue
|
||||
token = []
|
||||
ops = (ONE,)
|
||||
ops = _get_operators(spec)
|
||||
attr_values = _get_attr_values(spec, string_store)
|
||||
extensions = _get_extensions(spec, string_store, extensions_table)
|
||||
predicates = _get_extra_predicates(spec, extra_predicates, seen_predicates)
|
||||
for op in ops:
|
||||
tokens.append((op, list(attr_values), list(extensions), list(predicates)))
|
||||
return tokens
|
||||
|
||||
|
||||
def _get_attr_values(spec, string_store):
|
||||
attr_values = []
|
||||
for attr, value in spec.items():
|
||||
if isinstance(attr, basestring) and attr.upper() == 'OP':
|
||||
if value in operators:
|
||||
ops = operators[value]
|
||||
else:
|
||||
keys = ', '.join(operators.keys())
|
||||
raise KeyError(Errors.E011.format(op=value, opts=keys))
|
||||
if isinstance(attr, basestring):
|
||||
if attr == '_':
|
||||
continue
|
||||
elif attr.upper() == 'OP':
|
||||
continue
|
||||
if attr.upper() == 'TEXT':
|
||||
attr = 'ORTH'
|
||||
attr = IDS.get(attr.upper())
|
||||
if isinstance(value, basestring):
|
||||
value = string_store.add(value)
|
||||
if isinstance(value, bool):
|
||||
elif isinstance(value, bool):
|
||||
value = int(value)
|
||||
elif isinstance(value, dict):
|
||||
continue
|
||||
if attr is not None:
|
||||
token.append((attr, value))
|
||||
for op in ops:
|
||||
tokens.append((op, token))
|
||||
return tokens
|
||||
attr_values.append((attr, value))
|
||||
return attr_values
|
||||
|
||||
# These predicate helper classes are used to match the REGEX, IN, >= etc
|
||||
# extensions to the matcher introduced in #3173.
|
||||
|
||||
class _RegexPredicate(object):
|
||||
def __init__(self, i, attr, value, predicate, is_extension=False):
|
||||
self.i = i
|
||||
self.attr = attr
|
||||
self.value = re.compile(value)
|
||||
self.predicate = predicate
|
||||
self.is_extension = is_extension
|
||||
assert self.predicate == 'REGEX'
|
||||
|
||||
def __call__(self, Token token):
|
||||
if self.is_extension:
|
||||
value = token._.get(self.attr)
|
||||
else:
|
||||
value = token.vocab.strings[get_token_attr(token.c, self.attr)]
|
||||
return bool(self.value.search(value))
|
||||
|
||||
|
||||
class _SetMemberPredicate(object):
|
||||
def __init__(self, i, attr, value, predicate, is_extension=False):
|
||||
self.i = i
|
||||
self.attr = attr
|
||||
self.value = set(get_string_id(v) for v in value)
|
||||
self.predicate = predicate
|
||||
self.is_extension = is_extension
|
||||
assert self.predicate in ('IN', 'NOT_IN')
|
||||
|
||||
def __call__(self, Token token):
|
||||
if self.is_extension:
|
||||
value = get_string_id(token._.get(self.attr))
|
||||
else:
|
||||
value = get_token_attr(token.c, self.attr)
|
||||
if self.predicate == 'IN':
|
||||
return value in self.value
|
||||
else:
|
||||
return value not in self.value
|
||||
|
||||
|
||||
class _ComparisonPredicate(object):
|
||||
def __init__(self, i, attr, value, predicate, is_extension=False):
|
||||
self.i = i
|
||||
self.attr = attr
|
||||
self.value = value
|
||||
self.predicate = predicate
|
||||
self.is_extension = is_extension
|
||||
assert self.predicate in ('==', '!=', '>=', '<=', '>', '<')
|
||||
|
||||
def __call__(self, Token token):
|
||||
if self.is_extension:
|
||||
value = token._.get(self.attr)
|
||||
else:
|
||||
value = get_token_attr(token.c, self.attr)
|
||||
if self.predicate == '==':
|
||||
return value == self.value
|
||||
if self.predicate == '!=':
|
||||
return value != self.value
|
||||
elif self.predicate == '>=':
|
||||
return value >= self.value
|
||||
elif self.predicate == '<=':
|
||||
return value <= self.value
|
||||
elif self.predicate == '>':
|
||||
return value > self.value
|
||||
elif self.predicate == '<':
|
||||
return value < self.value
|
||||
|
||||
|
||||
def _get_extra_predicates(spec, extra_predicates, seen_predicates):
|
||||
predicate_types = {
|
||||
'REGEX': _RegexPredicate,
|
||||
'IN': _SetMemberPredicate,
|
||||
'NOT_IN': _SetMemberPredicate,
|
||||
'==': _ComparisonPredicate,
|
||||
'>=': _ComparisonPredicate,
|
||||
'<=': _ComparisonPredicate,
|
||||
'>': _ComparisonPredicate,
|
||||
'<': _ComparisonPredicate,
|
||||
}
|
||||
output = []
|
||||
for attr, value in spec.items():
|
||||
if isinstance(attr, basestring):
|
||||
if attr == '_':
|
||||
output.extend(
|
||||
_get_extension_extra_predicates(
|
||||
value, extra_predicates, predicate_types,
|
||||
seen_predicates))
|
||||
continue
|
||||
elif attr.upper() == 'OP':
|
||||
continue
|
||||
if attr.upper() == 'TEXT':
|
||||
attr = 'ORTH'
|
||||
attr = IDS.get(attr.upper())
|
||||
if isinstance(value, dict):
|
||||
for type_, cls in predicate_types.items():
|
||||
if type_ in value:
|
||||
key = (attr, type_, srsly.json_dumps(value[type_], sort_keys=True))
|
||||
# Don't create a redundant predicates.
|
||||
# This helps with efficiency, as we're caching the results.
|
||||
if key in seen_predicates:
|
||||
output.append(seen_predicates[key])
|
||||
else:
|
||||
predicate = cls(len(extra_predicates), attr, value[type_], type_)
|
||||
extra_predicates.append(predicate)
|
||||
output.append(predicate.i)
|
||||
seen_predicates[key] = predicate.i
|
||||
return output
|
||||
|
||||
|
||||
def _get_extension_extra_predicates(spec, extra_predicates, predicate_types,
|
||||
seen_predicates):
|
||||
output = []
|
||||
for attr, value in spec.items():
|
||||
if isinstance(value, dict):
|
||||
for type_, cls in predicate_types.items():
|
||||
if type_ in value:
|
||||
key = (attr, type_, srsly.json_dumps(value[type_], sort_keys=True))
|
||||
if key in seen_predicates:
|
||||
output.append(seen_predicates[key])
|
||||
else:
|
||||
predicate = cls(len(extra_predicates), attr, value[type_], type_,
|
||||
is_extension=True)
|
||||
extra_predicates.append(predicate)
|
||||
output.append(predicate.i)
|
||||
seen_predicates[key] = predicate.i
|
||||
return output
|
||||
|
||||
|
||||
def _get_operators(spec):
|
||||
# Support 'syntactic sugar' operator '+', as combination of ONE, ZERO_PLUS
|
||||
lookup = {'*': (ZERO_PLUS,), '+': (ONE, ZERO_PLUS),
|
||||
'?': (ZERO_ONE,), '1': (ONE,), '!': (ZERO,)}
|
||||
# Fix casing
|
||||
spec = {key.upper(): values for key, values in spec.items()
|
||||
if isinstance(key, basestring)}
|
||||
if 'OP' not in spec:
|
||||
return (ONE,)
|
||||
elif spec['OP'] in lookup:
|
||||
return lookup[spec['OP']]
|
||||
else:
|
||||
keys = ', '.join(lookup.keys())
|
||||
raise KeyError(Errors.E011.format(op=spec['OP'], opts=keys))
|
||||
|
||||
|
||||
def _get_extensions(spec, string_store, name2index):
|
||||
attr_values = []
|
||||
for name, value in spec.get('_', {}).items():
|
||||
if isinstance(value, dict):
|
||||
# Handle predicates (e.g. "IN", in the extra_predicates, not here.
|
||||
continue
|
||||
if isinstance(value, basestring):
|
||||
value = string_store.add(value)
|
||||
if name not in name2index:
|
||||
name2index[name] = len(name2index)
|
||||
attr_values.append((name2index[name], value))
|
||||
return attr_values
|
||||
|
||||
|
||||
cdef class Matcher:
|
||||
|
@ -388,6 +649,8 @@ cdef class Matcher:
|
|||
cdef public object _patterns
|
||||
cdef public object _entities
|
||||
cdef public object _callbacks
|
||||
cdef public object _extensions
|
||||
cdef public object _extra_predicates
|
||||
|
||||
def __init__(self, vocab):
|
||||
"""Create the Matcher.
|
||||
|
@ -396,9 +659,12 @@ cdef class Matcher:
|
|||
documents the matcher will operate on.
|
||||
RETURNS (Matcher): The newly constructed object.
|
||||
"""
|
||||
self._extra_predicates = []
|
||||
self._patterns = {}
|
||||
self._entities = {}
|
||||
self._callbacks = {}
|
||||
self._extensions = {}
|
||||
self._extra_predicates = []
|
||||
self.vocab = vocab
|
||||
self.mem = Pool()
|
||||
|
||||
|
@ -456,7 +722,8 @@ cdef class Matcher:
|
|||
raise ValueError(Errors.E012.format(key=key))
|
||||
key = self._normalize_key(key)
|
||||
for pattern in patterns:
|
||||
specs = _convert_strings(pattern, self.vocab.strings)
|
||||
specs = _preprocess_pattern(pattern, self.vocab.strings,
|
||||
self._extensions, self._extra_predicates)
|
||||
self.patterns.push_back(init_pattern(self.mem, key, specs))
|
||||
self._patterns.setdefault(key, [])
|
||||
self._callbacks[key] = on_match
|
||||
|
@ -520,7 +787,9 @@ cdef class Matcher:
|
|||
describing the matches. A match tuple describes a span
|
||||
`doc[start:end]`. The `label_id` and `key` are both integers.
|
||||
"""
|
||||
matches = find_matches(&self.patterns[0], self.patterns.size(), doc)
|
||||
matches = find_matches(&self.patterns[0], self.patterns.size(), doc,
|
||||
extensions=self._extensions,
|
||||
predicates=self._extra_predicates)
|
||||
for i, (key, start, end) in enumerate(matches):
|
||||
on_match = self._callbacks.get(key, None)
|
||||
if on_match is not None:
|
||||
|
|
|
@ -4,7 +4,7 @@ from __future__ import unicode_literals
|
|||
import pytest
|
||||
import re
|
||||
from spacy.matcher import Matcher, DependencyTreeMatcher
|
||||
from spacy.tokens import Doc
|
||||
from spacy.tokens import Doc, Token
|
||||
from ..util import get_doc
|
||||
|
||||
|
||||
|
@ -179,6 +179,80 @@ def test_matcher_any_token_operator(en_vocab):
|
|||
assert matches[2] == "test hello world"
|
||||
|
||||
|
||||
def test_matcher_extension_attribute(en_vocab):
|
||||
matcher = Matcher(en_vocab)
|
||||
Token.set_extension('is_fruit',
|
||||
getter=lambda token: token.text in ('apple', 'banana'), force=True)
|
||||
pattern = [{'ORTH': 'an'}, {'_': {'is_fruit': True}}]
|
||||
matcher.add('HAVING_FRUIT', None, pattern)
|
||||
doc = Doc(en_vocab, words=['an', 'apple'])
|
||||
matches = matcher(doc)
|
||||
assert len(matches) == 1
|
||||
doc = Doc(en_vocab, words=['an', 'aardvark'])
|
||||
matches = matcher(doc)
|
||||
assert len(matches) == 0
|
||||
|
||||
|
||||
def test_matcher_set_value(en_vocab):
|
||||
matcher = Matcher(en_vocab)
|
||||
pattern = [{'ORTH': {'IN': ['an', 'a']}}]
|
||||
matcher.add('A_OR_AN', None, pattern)
|
||||
doc = Doc(en_vocab, words=['an', 'a', 'apple'])
|
||||
matches = matcher(doc)
|
||||
assert len(matches) == 2
|
||||
doc = Doc(en_vocab, words=['aardvark'])
|
||||
matches = matcher(doc)
|
||||
assert len(matches) == 0
|
||||
|
||||
|
||||
def test_matcher_regex(en_vocab):
|
||||
matcher = Matcher(en_vocab)
|
||||
pattern = [{'ORTH': {'REGEX': r'(?:a|an)'}}]
|
||||
matcher.add('A_OR_AN', None, pattern)
|
||||
doc = Doc(en_vocab, words=['an', 'a', 'hi'])
|
||||
matches = matcher(doc)
|
||||
assert len(matches) == 2
|
||||
doc = Doc(en_vocab, words=['bye'])
|
||||
matches = matcher(doc)
|
||||
assert len(matches) == 0
|
||||
|
||||
def test_matcher_regex_shape(en_vocab):
|
||||
matcher = Matcher(en_vocab)
|
||||
pattern = [{'SHAPE': {'REGEX': r'^[^x]+$'}}]
|
||||
matcher.add('NON_ALPHA', None, pattern)
|
||||
doc = Doc(en_vocab, words=['99', 'problems', '!'])
|
||||
matches = matcher(doc)
|
||||
assert len(matches) == 2
|
||||
doc = Doc(en_vocab, words=['bye'])
|
||||
matches = matcher(doc)
|
||||
assert len(matches) == 0
|
||||
|
||||
def test_matcher_compare_length(en_vocab):
|
||||
matcher = Matcher(en_vocab)
|
||||
pattern = [{'LENGTH': {'>=': 2}}]
|
||||
matcher.add('LENGTH_COMPARE', None, pattern)
|
||||
doc = Doc(en_vocab, words=['a', 'aa', 'aaa'])
|
||||
matches = matcher(doc)
|
||||
assert len(matches) == 2
|
||||
doc = Doc(en_vocab, words=['a'])
|
||||
matches = matcher(doc)
|
||||
assert len(matches) == 0
|
||||
|
||||
|
||||
def test_matcher_extension_set_membership(en_vocab):
|
||||
matcher = Matcher(en_vocab)
|
||||
Token.set_extension('reversed',
|
||||
getter=lambda token: ''.join(reversed(token.text)), force=True)
|
||||
pattern = [{'_': {'reversed': {"IN": ["eyb", "ih"]}}}]
|
||||
matcher.add('REVERSED', None, pattern)
|
||||
doc = Doc(en_vocab, words=['hi', 'bye', 'hello'])
|
||||
matches = matcher(doc)
|
||||
assert len(matches) == 2
|
||||
doc = Doc(en_vocab, words=['aardvark'])
|
||||
matches = matcher(doc)
|
||||
assert len(matches) == 0
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def text():
|
||||
return "The quick brown fox jumped over the lazy fox"
|
||||
|
|
Loading…
Reference in New Issue
Block a user