💫 Update matcher engine for regex and extensions (#3173)

* Update matcher engine for regex and extensions

Add support for matching over arbitrary Python predicate functions and
arbitrary Python attribute getters. This will allow matching against regex
patterns and will let us support custom extension attributes.
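
For illustration, a rough sketch of the pattern syntax this enables,
mirroring the tests added in this commit ('is_fruit' here is just an example
extension name):

    from spacy.matcher import Matcher
    from spacy.tokens import Doc
    from spacy.vocab import Vocab

    vocab = Vocab()
    matcher = Matcher(vocab)
    # Regex predicate over a token attribute
    matcher.add('A_OR_AN', None, [{'ORTH': {'REGEX': r'^an?$'}}])
    doc = Doc(vocab, words=['an', 'apple'])
    assert len(matcher(doc)) == 1

    # Other forms the engine now understands (each covered by the new tests):
    #   [{'ORTH': {'IN': ['an', 'a']}}]     set membership
    #   [{'LENGTH': {'>=': 2}}]             numeric comparison
    #   [{'_': {'is_fruit': True}}]         custom extension attribute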

The results of the Python predicate functions are cached, so that we don't
call the same predicate function twice for the same token. The extension
attributes are fetched into an array for each token in the doc. This
should minimise the performance impact of the new features.
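
Concretely, the layout is roughly as in this simplified Python sketch; the
real find_matches allocates flat C arrays with cymem, and the names mirror
the Cython code:

    def prefetch_extension_values(doc, extensions):
        # "extensions" maps extension attribute name -> column index
        nr_extra_attr = max(extensions.values()) + 1 if extensions else 0
        table = [0] * (len(doc) * nr_extra_attr)
        for i, token in enumerate(doc):
            for name, index in extensions.items():
                value = token._.get(name)
                if isinstance(value, str):  # the Cython code checks basestring
                    value = token.vocab.strings[value]
                table[i * nr_extra_attr + index] = value
        return table, nr_extra_attr

    # Predicate results are cached per (token, predicate), so each Python
    # predicate runs at most once per token:
    #   cache[token_i * nr_predicates + pred_i] is 0 (unseen), 1 (True) or -1 (False)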

We still need to wire up these features to the patterns, and test it
all.

* Work on wiring up extra attributes in matcher

* Work on tests for extra matcher attrs

* Add support for extension attrs to matcher

* Test extension attribute matching

* Work on implementing predicate-based match patterns

* Get predicates working for set membership

* Add test for set membership

* Make extensions+predicates work

* Test matcher extensions

* Cache predicate results better in Matcher

* Remove print statement in matcher test

* Use srsly to get key for predicates
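
The srsly-derived key mentioned in the last point gives each predicate spec a
stable, hashable identity, so equivalent specs can share one predicate object
and one per-token cache slot. Roughly:

    import srsly
    # In the real code the first element is the numeric attribute ID rather
    # than the name; patterns using the same spec produce the same key.
    key = ('ORTH', 'IN', srsly.json_dumps(['an', 'a'], sort_keys=True))
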
Matthew Honnibal 2019-01-21 13:23:15 +01:00 committed by GitHub
parent f407954b27
commit 77ddcf7381
2 changed files with 388 additions and 45 deletions

View File

@@ -1,6 +1,8 @@
# cython: infer_types=True
# cython: profile=True
from __future__ import unicode_literals
import re
import srsly
from libcpp.vector cimport vector
from libc.stdint cimport int32_t, uint64_t, uint16_t
from preshed.maps cimport PreshMap
@@ -11,9 +13,11 @@ from .structs cimport TokenC
from .lexeme cimport attr_id_t
from .vocab cimport Vocab
from .tokens.doc cimport Doc
from .tokens.token cimport Token
from .tokens.doc cimport get_token_attr
from .attrs cimport ID, attr_id_t, NULL_ATTR, ORTH
from .errors import Errors, TempErrors, Warnings, deprecation_warning
from .strings import get_string_id
from .attrs import IDS
from .attrs import FLAG61 as U_ENT
@@ -56,9 +60,17 @@ cdef struct AttrValueC:
attr_id_t attr
attr_t value
cdef struct IndexValueC:
int32_t index
attr_t value
cdef struct TokenPatternC:
AttrValueC* attrs
int32_t* py_predicates
IndexValueC* extra_attrs
int32_t nr_attr
int32_t nr_extra_attr
int32_t nr_py
quantifier_t quantifier
hash_t key
@@ -75,19 +87,46 @@ cdef struct MatchC:
int32_t length
cdef find_matches(TokenPatternC** patterns, int n, Doc doc):
cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None,
predicates=tuple()):
'''Find matches in a doc, with a compiled array of patterns. Matches are
returned as a list of (id, start, end) tuples.
To augment the compiled patterns, we optionally also take two Python arguments.
The "predicates" list contains functions that take a Token and return a
boolean (or None). It's mostly used for regular expressions.
The "extensions" dict maps custom extension attribute names to columns in a
table of values prefetched for each token in the doc.
'''
cdef vector[PatternStateC] states
cdef vector[MatchC] matches
cdef PatternStateC state
cdef int i, j, nr_extra_attr
cdef Pool mem = Pool()
# TODO: Prefill this with the extra attribute values.
extra_attrs = <attr_t**>mem.alloc(len(doc), sizeof(attr_t*))
predicate_cache = <char*>mem.alloc(doc.length * len(predicates), sizeof(char))
if extensions is not None and len(extensions) >= 1:
nr_extra_attr = max(extensions.values()) + 1
extra_attr_values = <attr_t*>mem.alloc(doc.length * nr_extra_attr, sizeof(attr_t))
else:
nr_extra_attr = 0
extra_attr_values = <attr_t*>mem.alloc(doc.length, sizeof(attr_t))
for i, token in enumerate(doc):
for name, index in extensions.items():
value = token._.get(name)
if isinstance(value, basestring):
value = token.vocab.strings[value]
extra_attr_values[i * nr_extra_attr + index] = value
# Main loop
cdef int i, j
cdef int nr_predicate = len(predicates)
for i in range(doc.length):
for j in range(n):
states.push_back(PatternStateC(patterns[j], i, 0))
transition_states(states, matches, &doc.c[i], extra_attrs[i])
transition_states(states, matches, predicate_cache,
doc[i], extra_attr_values, predicates)
predicate_cache += nr_predicate
extra_attr_values += nr_extra_attr
# Handle matches that end in 0-width patterns
finish_states(matches, states)
output = []
@@ -119,11 +158,17 @@ cdef attr_t get_ent_id(const TokenPatternC* pattern) nogil:
cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& matches,
const TokenC* token, const attr_t* extra_attrs) except *:
char* cached_py_predicates,
Token token, const attr_t* extra_attrs, py_predicates) except *:
cdef int q = 0
cdef vector[PatternStateC] new_states
cdef int nr_predicate = len(py_predicates)
for i in range(states.size()):
action = get_action(states[i], token, extra_attrs)
if states[i].pattern.nr_py != 0:
update_predicate_cache(cached_py_predicates,
states[i].pattern, token, py_predicates)
action = get_action(states[i], token.c, extra_attrs,
cached_py_predicates, nr_predicate)
if action == REJECT:
continue
state = states[i]
@@ -140,7 +185,11 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
PatternStateC(pattern=state.pattern+1, start=state.start,
length=state.length+1))
states[q].pattern += 1
action = get_action(states[q], token, extra_attrs)
if states[q].pattern.nr_py != 0:
update_predicate_cache(cached_py_predicates,
states[q].pattern, token, py_predicates)
action = get_action(states[q], token.c, extra_attrs,
cached_py_predicates, nr_predicate)
if action == REJECT:
pass
elif action == ADVANCE:
@@ -168,6 +217,26 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
states.push_back(new_states[i])
cdef void update_predicate_cache(char* cache,
const TokenPatternC* pattern, Token token, predicates):
# If the state references any extra predicates, check whether they match.
# These are cached, so that we don't call these potentially expensive
# Python functions more than we need to.
for i in range(pattern.nr_py):
index = pattern.py_predicates[i]
if cache[index] == 0:
predicate = predicates[index]
result = predicate(token)
if result is True:
cache[index] = 1
elif result is False:
cache[index] = -1
elif result is None:
pass
else:
raise ValueError("Unexpected value: %s" % result)
cdef void finish_states(vector[MatchC]& matches, vector[PatternStateC]& states) except *:
'''Handle states that end in zero-width patterns.'''
cdef PatternStateC state
@@ -184,7 +253,9 @@ cdef void finish_states(vector[MatchC]& matches, vector[PatternStateC]& states)
state.pattern += 1
cdef action_t get_action(PatternStateC state, const TokenC* token, const attr_t* extra_attrs) nogil:
cdef action_t get_action(PatternStateC state,
const TokenC* token, const attr_t* extra_attrs,
const char* predicate_matches, int nr_predicate) nogil:
'''We need to consider:
a) Does the token match the specification? [Yes, No]
@@ -244,7 +315,7 @@ cdef action_t get_action(PatternStateC state, const TokenC* token, const attr_t*
Problem: If a quantifier is matching, we're adding a lot of open partials
'''
cdef char is_match
is_match = get_is_match(state, token, extra_attrs)
is_match = get_is_match(state, token, extra_attrs, predicate_matches, nr_predicate)
quantifier = get_quantifier(state)
is_final = get_is_final(state)
if quantifier == ZERO:
@@ -293,13 +364,20 @@ cdef action_t get_action(PatternStateC state, const TokenC* token, const attr_t*
return RETRY
cdef char get_is_match(PatternStateC state, const TokenC* token, const attr_t* extra_attrs) nogil:
cdef char get_is_match(PatternStateC state,
const TokenC* token, const attr_t* extra_attrs,
const char* predicate_matches, int nr_predicate) nogil:
for i in range(nr_predicate):
if predicate_matches[i] == -1:
return 0
spec = state.pattern
for attr in spec.attrs[:spec.nr_attr]:
if get_token_attr(token, attr.attr) != attr.value:
return 0
else:
return 1
for i in range(spec.nr_extra_attr):
if spec.extra_attrs[i].value != extra_attrs[spec.extra_attrs[i].index]:
return 0
return True
cdef char get_is_final(PatternStateC state) nogil:
@@ -316,17 +394,25 @@ cdef char get_quantifier(PatternStateC state) nogil:
DEF PADDING = 5
cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id,
object token_specs) except NULL:
cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs) except NULL:
pattern = <TokenPatternC*>mem.alloc(len(token_specs) + 1, sizeof(TokenPatternC))
cdef int i
for i, (quantifier, spec) in enumerate(token_specs):
for i, (quantifier, spec, extensions, predicates) in enumerate(token_specs):
pattern[i].quantifier = quantifier
pattern[i].attrs = <AttrValueC*>mem.alloc(len(spec), sizeof(AttrValueC))
pattern[i].nr_attr = len(spec)
for j, (attr, value) in enumerate(spec):
pattern[i].attrs[j].attr = attr
pattern[i].attrs[j].value = value
pattern[i].extra_attrs = <IndexValueC*>mem.alloc(len(extensions), sizeof(IndexValueC))
for j, (index, value) in enumerate(extensions):
pattern[i].extra_attrs[j].index = index
pattern[i].extra_attrs[j].value = value
pattern[i].nr_extra_attr = len(extensions)
pattern[i].py_predicates = <int32_t*>mem.alloc(len(predicates), sizeof(int32_t))
for j, index in enumerate(predicates):
pattern[i].py_predicates[j] = index
pattern[i].nr_py = len(predicates)
pattern[i].key = hash64(pattern[i].attrs, pattern[i].nr_attr * sizeof(AttrValueC), 0)
i = len(token_specs)
pattern[i].attrs = <AttrValueC*>mem.alloc(2, sizeof(AttrValueC))
@@ -345,39 +431,214 @@ cdef attr_t get_pattern_key(const TokenPatternC* pattern) nogil:
raise ValueError(Errors.E074.format(attr=ID, bad_attr=id_attr.attr))
return id_attr.value
def _convert_strings(token_specs, string_store):
# Support 'syntactic sugar' operator '+', as combination of ONE, ZERO_PLUS
operators = {'*': (ZERO_PLUS,), '+': (ONE, ZERO_PLUS),
'?': (ZERO_ONE,), '1': (ONE,), '!': (ZERO,)}
def _preprocess_pattern(token_specs, string_store, extensions_table, extra_predicates):
"""This function interprets the pattern, converting the various bits of
syntactic sugar before we compile it into a struct with init_pattern.
We need to split the pattern up into three parts:
* Normal attribute/value pairs, which are stored on either the token or lexeme,
can be handled directly.
* Extension attributes are handled specially, as we need to prefetch the
values from Python for the doc before we begin matching.
* Extra predicates also call Python functions, so we have to create the
functions and store them. So we store these specially as well.
* Extension attributes that have extra predicates are stored within the
extra_predicates.
"""
tokens = []
op = ONE
seen_predicates = {}
for spec in token_specs:
if not spec:
# Signifier for 'any token'
tokens.append((ONE, [(NULL_ATTR, 0)]))
tokens.append((ONE, [(NULL_ATTR, 0)], [], []))
continue
token = []
ops = (ONE,)
ops = _get_operators(spec)
attr_values = _get_attr_values(spec, string_store)
extensions = _get_extensions(spec, string_store, extensions_table)
predicates = _get_extra_predicates(spec, extra_predicates, seen_predicates)
for op in ops:
tokens.append((op, list(attr_values), list(extensions), list(predicates)))
return tokens
def _get_attr_values(spec, string_store):
attr_values = []
for attr, value in spec.items():
if isinstance(attr, basestring) and attr.upper() == 'OP':
if value in operators:
ops = operators[value]
else:
keys = ', '.join(operators.keys())
raise KeyError(Errors.E011.format(op=value, opts=keys))
if isinstance(attr, basestring):
if attr == '_':
continue
elif attr.upper() == 'OP':
continue
if attr.upper() == 'TEXT':
attr = 'ORTH'
attr = IDS.get(attr.upper())
if isinstance(value, basestring):
value = string_store.add(value)
if isinstance(value, bool):
elif isinstance(value, bool):
value = int(value)
elif isinstance(value, dict):
continue
if attr is not None:
token.append((attr, value))
for op in ops:
tokens.append((op, token))
return tokens
attr_values.append((attr, value))
return attr_values
# These predicate helper classes are used to match the REGEX, IN, >= etc
# extensions to the matcher introduced in #3173.
class _RegexPredicate(object):
def __init__(self, i, attr, value, predicate, is_extension=False):
self.i = i
self.attr = attr
self.value = re.compile(value)
self.predicate = predicate
self.is_extension = is_extension
assert self.predicate == 'REGEX'
def __call__(self, Token token):
if self.is_extension:
value = token._.get(self.attr)
else:
value = token.vocab.strings[get_token_attr(token.c, self.attr)]
return bool(self.value.search(value))
class _SetMemberPredicate(object):
def __init__(self, i, attr, value, predicate, is_extension=False):
self.i = i
self.attr = attr
self.value = set(get_string_id(v) for v in value)
self.predicate = predicate
self.is_extension = is_extension
assert self.predicate in ('IN', 'NOT_IN')
def __call__(self, Token token):
if self.is_extension:
value = get_string_id(token._.get(self.attr))
else:
value = get_token_attr(token.c, self.attr)
if self.predicate == 'IN':
return value in self.value
else:
return value not in self.value
class _ComparisonPredicate(object):
def __init__(self, i, attr, value, predicate, is_extension=False):
self.i = i
self.attr = attr
self.value = value
self.predicate = predicate
self.is_extension = is_extension
assert self.predicate in ('==', '!=', '>=', '<=', '>', '<')
def __call__(self, Token token):
if self.is_extension:
value = token._.get(self.attr)
else:
value = get_token_attr(token.c, self.attr)
if self.predicate == '==':
return value == self.value
if self.predicate == '!=':
return value != self.value
elif self.predicate == '>=':
return value >= self.value
elif self.predicate == '<=':
return value <= self.value
elif self.predicate == '>':
return value > self.value
elif self.predicate == '<':
return value < self.value
def _get_extra_predicates(spec, extra_predicates, seen_predicates):
predicate_types = {
'REGEX': _RegexPredicate,
'IN': _SetMemberPredicate,
'NOT_IN': _SetMemberPredicate,
'==': _ComparisonPredicate,
'!=': _ComparisonPredicate,
'>=': _ComparisonPredicate,
'<=': _ComparisonPredicate,
'>': _ComparisonPredicate,
'<': _ComparisonPredicate,
}
output = []
for attr, value in spec.items():
if isinstance(attr, basestring):
if attr == '_':
output.extend(
_get_extension_extra_predicates(
value, extra_predicates, predicate_types,
seen_predicates))
continue
elif attr.upper() == 'OP':
continue
if attr.upper() == 'TEXT':
attr = 'ORTH'
attr = IDS.get(attr.upper())
if isinstance(value, dict):
for type_, cls in predicate_types.items():
if type_ in value:
key = (attr, type_, srsly.json_dumps(value[type_], sort_keys=True))
# Don't create redundant predicates.
# This helps with efficiency, as we're caching the results.
if key in seen_predicates:
output.append(seen_predicates[key])
else:
predicate = cls(len(extra_predicates), attr, value[type_], type_)
extra_predicates.append(predicate)
output.append(predicate.i)
seen_predicates[key] = predicate.i
return output
def _get_extension_extra_predicates(spec, extra_predicates, predicate_types,
seen_predicates):
output = []
for attr, value in spec.items():
if isinstance(value, dict):
for type_, cls in predicate_types.items():
if type_ in value:
key = (attr, type_, srsly.json_dumps(value[type_], sort_keys=True))
if key in seen_predicates:
output.append(seen_predicates[key])
else:
predicate = cls(len(extra_predicates), attr, value[type_], type_,
is_extension=True)
extra_predicates.append(predicate)
output.append(predicate.i)
seen_predicates[key] = predicate.i
return output
def _get_operators(spec):
# Support 'syntactic sugar' operator '+', as combination of ONE, ZERO_PLUS
lookup = {'*': (ZERO_PLUS,), '+': (ONE, ZERO_PLUS),
'?': (ZERO_ONE,), '1': (ONE,), '!': (ZERO,)}
# Fix casing
spec = {key.upper(): values for key, values in spec.items()
if isinstance(key, basestring)}
if 'OP' not in spec:
return (ONE,)
elif spec['OP'] in lookup:
return lookup[spec['OP']]
else:
keys = ', '.join(lookup.keys())
raise KeyError(Errors.E011.format(op=spec['OP'], opts=keys))
def _get_extensions(spec, string_store, name2index):
attr_values = []
for name, value in spec.get('_', {}).items():
if isinstance(value, dict):
# Handle predicates (e.g. "IN") in the extra_predicates, not here.
continue
if isinstance(value, basestring):
value = string_store.add(value)
if name not in name2index:
name2index[name] = len(name2index)
attr_values.append((name2index[name], value))
return attr_values
cdef class Matcher:
@@ -388,6 +649,8 @@ cdef class Matcher:
cdef public object _patterns
cdef public object _entities
cdef public object _callbacks
cdef public object _extensions
cdef public object _extra_predicates
def __init__(self, vocab):
"""Create the Matcher.
@@ -396,9 +659,12 @@
documents the matcher will operate on.
RETURNS (Matcher): The newly constructed object.
"""
self._extra_predicates = []
self._patterns = {}
self._entities = {}
self._callbacks = {}
self._extensions = {}
self._extra_predicates = []
self.vocab = vocab
self.mem = Pool()
@@ -456,7 +722,8 @@ cdef class Matcher:
raise ValueError(Errors.E012.format(key=key))
key = self._normalize_key(key)
for pattern in patterns:
specs = _convert_strings(pattern, self.vocab.strings)
specs = _preprocess_pattern(pattern, self.vocab.strings,
self._extensions, self._extra_predicates)
self.patterns.push_back(init_pattern(self.mem, key, specs))
self._patterns.setdefault(key, [])
self._callbacks[key] = on_match
@@ -520,7 +787,9 @@ cdef class Matcher:
describing the matches. A match tuple describes a span
`doc[start:end]`. The `label_id` and `key` are both integers.
"""
matches = find_matches(&self.patterns[0], self.patterns.size(), doc)
matches = find_matches(&self.patterns[0], self.patterns.size(), doc,
extensions=self._extensions,
predicates=self._extra_predicates)
for i, (key, start, end) in enumerate(matches):
on_match = self._callbacks.get(key, None)
if on_match is not None:
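
For a rough picture of how a single token spec is split up by
_preprocess_pattern above (a sketch, not actual output: the concrete attr IDs
and string hashes depend on the vocab, and 'is_fruit' is a hypothetical custom
extension):

    spec = {'LOWER': 'fruit', 'LENGTH': {'>=': 5}, '_': {'is_fruit': True}, 'OP': '+'}
    # _get_operators        -> (ONE, ZERO_PLUS)            # '+' sugar
    # _get_attr_values      -> [(LOWER, <hash of 'fruit'>)]
    # _get_extensions       -> [(0, True)]                  # column 0 of the prefetched '_' table
    # _get_extra_predicates -> [0]                          # index of a _ComparisonPredicate for {'>=': 5}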

View File

@@ -4,7 +4,7 @@ from __future__ import unicode_literals
import pytest
import re
from spacy.matcher import Matcher, DependencyTreeMatcher
from spacy.tokens import Doc
from spacy.tokens import Doc, Token
from ..util import get_doc
@@ -179,6 +179,80 @@ def test_matcher_any_token_operator(en_vocab):
assert matches[2] == "test hello world"
def test_matcher_extension_attribute(en_vocab):
matcher = Matcher(en_vocab)
Token.set_extension('is_fruit',
getter=lambda token: token.text in ('apple', 'banana'), force=True)
pattern = [{'ORTH': 'an'}, {'_': {'is_fruit': True}}]
matcher.add('HAVING_FRUIT', None, pattern)
doc = Doc(en_vocab, words=['an', 'apple'])
matches = matcher(doc)
assert len(matches) == 1
doc = Doc(en_vocab, words=['an', 'aardvark'])
matches = matcher(doc)
assert len(matches) == 0
def test_matcher_set_value(en_vocab):
matcher = Matcher(en_vocab)
pattern = [{'ORTH': {'IN': ['an', 'a']}}]
matcher.add('A_OR_AN', None, pattern)
doc = Doc(en_vocab, words=['an', 'a', 'apple'])
matches = matcher(doc)
assert len(matches) == 2
doc = Doc(en_vocab, words=['aardvark'])
matches = matcher(doc)
assert len(matches) == 0
def test_matcher_regex(en_vocab):
matcher = Matcher(en_vocab)
pattern = [{'ORTH': {'REGEX': r'(?:a|an)'}}]
matcher.add('A_OR_AN', None, pattern)
doc = Doc(en_vocab, words=['an', 'a', 'hi'])
matches = matcher(doc)
assert len(matches) == 2
doc = Doc(en_vocab, words=['bye'])
matches = matcher(doc)
assert len(matches) == 0
def test_matcher_regex_shape(en_vocab):
matcher = Matcher(en_vocab)
pattern = [{'SHAPE': {'REGEX': r'^[^x]+$'}}]
matcher.add('NON_ALPHA', None, pattern)
doc = Doc(en_vocab, words=['99', 'problems', '!'])
matches = matcher(doc)
assert len(matches) == 2
doc = Doc(en_vocab, words=['bye'])
matches = matcher(doc)
assert len(matches) == 0
def test_matcher_compare_length(en_vocab):
matcher = Matcher(en_vocab)
pattern = [{'LENGTH': {'>=': 2}}]
matcher.add('LENGTH_COMPARE', None, pattern)
doc = Doc(en_vocab, words=['a', 'aa', 'aaa'])
matches = matcher(doc)
assert len(matches) == 2
doc = Doc(en_vocab, words=['a'])
matches = matcher(doc)
assert len(matches) == 0
def test_matcher_extension_set_membership(en_vocab):
matcher = Matcher(en_vocab)
Token.set_extension('reversed',
getter=lambda token: ''.join(reversed(token.text)), force=True)
pattern = [{'_': {'reversed': {"IN": ["eyb", "ih"]}}}]
matcher.add('REVERSED', None, pattern)
doc = Doc(en_vocab, words=['hi', 'bye', 'hello'])
matches = matcher(doc)
assert len(matches) == 2
doc = Doc(en_vocab, words=['aardvark'])
matches = matcher(doc)
assert len(matches) == 0
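# A sketch of a further test one could add (not included in this commit) to
# exercise a REGEX predicate on an extension attribute, i.e. the
# _get_extension_extra_predicates / is_extension=True path:
def test_matcher_extension_regex_sketch(en_vocab):
    matcher = Matcher(en_vocab)
    Token.set_extension('shouted',
                        getter=lambda token: token.text.upper(), force=True)
    pattern = [{'_': {'shouted': {'REGEX': r'^HI+$'}}}]
    matcher.add('SHOUTED_HI', None, pattern)
    doc = Doc(en_vocab, words=['hi', 'hiii', 'bye'])
    matches = matcher(doc)
    assert len(matches) == 2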
@pytest.fixture
def text():
return "The quick brown fox jumped over the lazy fox"