💫 Fix bugs in matcher extensions. Closes #1971 (#3301)

* Fix matching on extension attrs and predicates

* Fix detection of match_id when using extension attributes. The match
ID is stored as the last entry in the pattern. We were checking for this
with nr_attr == 0, which didn't account for extension attributes.

* Fix handling of predicates. The wrong count was being passed through,
so even patterns that didn't have a predicate were being checked.

* Fix regex pattern

* Fix matcher set value test
This commit is contained in:
Matthew Honnibal 2019-02-20 21:30:39 +01:00 committed by Ines Montani
parent f73d01aa32
commit 0d1ca15b13
3 changed files with 42 additions and 34 deletions

View File

@ -44,7 +44,7 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None,
cdef Pool mem = Pool() cdef Pool mem = Pool()
predicate_cache = <char*>mem.alloc(doc.length * len(predicates), sizeof(char)) predicate_cache = <char*>mem.alloc(doc.length * len(predicates), sizeof(char))
if extensions is not None and len(extensions) >= 1: if extensions is not None and len(extensions) >= 1:
nr_extra_attr = max(extensions.values()) nr_extra_attr = max(extensions.values()) + 1
extra_attr_values = <attr_t*>mem.alloc(doc.length * nr_extra_attr, sizeof(attr_t)) extra_attr_values = <attr_t*>mem.alloc(doc.length * nr_extra_attr, sizeof(attr_t))
else: else:
nr_extra_attr = 0 nr_extra_attr = 0
@ -60,9 +60,8 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None,
for i in range(doc.length): for i in range(doc.length):
for j in range(n): for j in range(n):
states.push_back(PatternStateC(patterns[j], i, 0)) states.push_back(PatternStateC(patterns[j], i, 0))
transition_states(states, matches, predicate_cache, transition_states(states, matches, &predicate_cache[i],
doc[i], extra_attr_values, predicates) doc[i], extra_attr_values, predicates)
predicate_cache += nr_predicate
extra_attr_values += nr_extra_attr extra_attr_values += nr_extra_attr
# Handle matches that end in 0-width patterns # Handle matches that end in 0-width patterns
finish_states(matches, states) finish_states(matches, states)
@ -74,6 +73,7 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None,
matches[i].start, matches[i].start,
matches[i].start+matches[i].length matches[i].start+matches[i].length
) )
# We need to deduplicate, because we could otherwise arrive at the same # We need to deduplicate, because we could otherwise arrive at the same
# match through two paths, e.g. .?.? matching 'a'. Are we matching the # match through two paths, e.g. .?.? matching 'a'. Are we matching the
# first .?, or the second .? -- it doesn't matter, it's just one match. # first .?, or the second .? -- it doesn't matter, it's just one match.
@ -89,7 +89,8 @@ cdef attr_t get_ent_id(const TokenPatternC* pattern) nogil:
# showed this wasn't the case when we had a reject-and-continue before a # showed this wasn't the case when we had a reject-and-continue before a
# match. I still don't really understand what's going on here, but this # match. I still don't really understand what's going on here, but this
# workaround does resolve the issue. # workaround does resolve the issue.
while pattern.attrs.attr != ID and pattern.nr_attr > 0: while pattern.attrs.attr != ID and \
(pattern.nr_attr > 0 or pattern.nr_extra_attr > 0 or pattern.nr_py > 0):
pattern += 1 pattern += 1
return pattern.attrs.value return pattern.attrs.value
@ -101,13 +102,17 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
cdef vector[PatternStateC] new_states cdef vector[PatternStateC] new_states
cdef int nr_predicate = len(py_predicates) cdef int nr_predicate = len(py_predicates)
for i in range(states.size()): for i in range(states.size()):
if states[i].pattern.nr_py != 0: if states[i].pattern.nr_py >= 1:
update_predicate_cache(cached_py_predicates, update_predicate_cache(cached_py_predicates,
states[i].pattern, token, py_predicates) states[i].pattern, token, py_predicates)
for i in range(states.size()):
action = get_action(states[i], token.c, extra_attrs, action = get_action(states[i], token.c, extra_attrs,
cached_py_predicates, nr_predicate) cached_py_predicates)
if action == REJECT: if action == REJECT:
continue continue
# Keep only a subset of states (the active ones). Index q is the
# states which are still alive. If we reject a state, we overwrite
# it in the states list, because q doesn't advance.
state = states[i] state = states[i]
states[q] = state states[q] = state
while action in (RETRY, RETRY_ADVANCE, RETRY_EXTEND): while action in (RETRY, RETRY_ADVANCE, RETRY_EXTEND):
@ -126,7 +131,7 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
update_predicate_cache(cached_py_predicates, update_predicate_cache(cached_py_predicates,
states[q].pattern, token, py_predicates) states[q].pattern, token, py_predicates)
action = get_action(states[q], token.c, extra_attrs, action = get_action(states[q], token.c, extra_attrs,
cached_py_predicates, nr_predicate) cached_py_predicates)
if action == REJECT: if action == REJECT:
pass pass
elif action == ADVANCE: elif action == ADVANCE:
@ -154,8 +159,8 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
states.push_back(new_states[i]) states.push_back(new_states[i])
cdef void update_predicate_cache(char* cache, cdef int update_predicate_cache(char* cache,
const TokenPatternC* pattern, Token token, predicates): const TokenPatternC* pattern, Token token, predicates) except -1:
# If the state references any extra predicates, check whether they match. # If the state references any extra predicates, check whether they match.
# These are cached, so that we don't call these potentially expensive # These are cached, so that we don't call these potentially expensive
# Python functions more than we need to. # Python functions more than we need to.
@ -192,7 +197,7 @@ cdef void finish_states(vector[MatchC]& matches, vector[PatternStateC]& states)
cdef action_t get_action(PatternStateC state, cdef action_t get_action(PatternStateC state,
const TokenC* token, const attr_t* extra_attrs, const TokenC* token, const attr_t* extra_attrs,
const char* predicate_matches, int nr_predicate) nogil: const char* predicate_matches) nogil:
'''We need to consider: '''We need to consider:
a) Does the token match the specification? [Yes, No] a) Does the token match the specification? [Yes, No]
@ -252,7 +257,7 @@ cdef action_t get_action(PatternStateC state,
Problem: If a quantifier is matching, we're adding a lot of open partials Problem: If a quantifier is matching, we're adding a lot of open partials
''' '''
cdef char is_match cdef char is_match
is_match = get_is_match(state, token, extra_attrs, predicate_matches, nr_predicate) is_match = get_is_match(state, token, extra_attrs, predicate_matches)
quantifier = get_quantifier(state) quantifier = get_quantifier(state)
is_final = get_is_final(state) is_final = get_is_final(state)
if quantifier == ZERO: if quantifier == ZERO:
@ -303,9 +308,9 @@ cdef action_t get_action(PatternStateC state,
cdef char get_is_match(PatternStateC state, cdef char get_is_match(PatternStateC state,
const TokenC* token, const attr_t* extra_attrs, const TokenC* token, const attr_t* extra_attrs,
const char* predicate_matches, int nr_predicate) nogil: const char* predicate_matches) nogil:
for i in range(nr_predicate): for i in range(state.pattern.nr_py):
if predicate_matches[i] == -1: if predicate_matches[state.pattern.py_predicates[i]] == -1:
return 0 return 0
spec = state.pattern spec = state.pattern
for attr in spec.attrs[:spec.nr_attr]: for attr in spec.attrs[:spec.nr_attr]:
@ -333,7 +338,7 @@ DEF PADDING = 5
cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs) except NULL: cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs) except NULL:
pattern = <TokenPatternC*>mem.alloc(len(token_specs) + 1, sizeof(TokenPatternC)) pattern = <TokenPatternC*>mem.alloc(len(token_specs) + 1, sizeof(TokenPatternC))
cdef int i cdef int i, index
for i, (quantifier, spec, extensions, predicates) in enumerate(token_specs): for i, (quantifier, spec, extensions, predicates) in enumerate(token_specs):
pattern[i].quantifier = quantifier pattern[i].quantifier = quantifier
pattern[i].attrs = <AttrValueC*>mem.alloc(len(spec), sizeof(AttrValueC)) pattern[i].attrs = <AttrValueC*>mem.alloc(len(spec), sizeof(AttrValueC))
@ -356,11 +361,13 @@ cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs)
pattern[i].attrs[0].attr = ID pattern[i].attrs[0].attr = ID
pattern[i].attrs[0].value = entity_id pattern[i].attrs[0].value = entity_id
pattern[i].nr_attr = 0 pattern[i].nr_attr = 0
pattern[i].nr_extra_attr = 0
pattern[i].nr_py = 0
return pattern return pattern
cdef attr_t get_pattern_key(const TokenPatternC* pattern) nogil: cdef attr_t get_pattern_key(const TokenPatternC* pattern) nogil:
while pattern.nr_attr != 0: while pattern.nr_attr != 0 or pattern.nr_extra_attr != 0 or pattern.nr_py != 0:
pattern += 1 pattern += 1
id_attr = pattern[0].attrs[0] id_attr = pattern[0].attrs[0]
if id_attr.attr != ID: if id_attr.attr != ID:
@ -384,7 +391,6 @@ def _preprocess_pattern(token_specs, string_store, extensions_table, extra_predi
extra_predicates. extra_predicates.
""" """
tokens = [] tokens = []
seen_predicates = {}
for spec in token_specs: for spec in token_specs:
if not spec: if not spec:
# Signifier for 'any token' # Signifier for 'any token'
@ -393,7 +399,7 @@ def _preprocess_pattern(token_specs, string_store, extensions_table, extra_predi
ops = _get_operators(spec) ops = _get_operators(spec)
attr_values = _get_attr_values(spec, string_store) attr_values = _get_attr_values(spec, string_store)
extensions = _get_extensions(spec, string_store, extensions_table) extensions = _get_extensions(spec, string_store, extensions_table)
predicates = _get_extra_predicates(spec, extra_predicates, seen_predicates) predicates = _get_extra_predicates(spec, extra_predicates)
for op in ops: for op in ops:
tokens.append((op, list(attr_values), list(extensions), list(predicates))) tokens.append((op, list(attr_values), list(extensions), list(predicates)))
return tokens return tokens
@ -430,6 +436,7 @@ class _RegexPredicate(object):
self.value = re.compile(value) self.value = re.compile(value)
self.predicate = predicate self.predicate = predicate
self.is_extension = is_extension self.is_extension = is_extension
self.key = (attr, self.predicate, srsly.json_dumps(value, sort_keys=True))
assert self.predicate == 'REGEX' assert self.predicate == 'REGEX'
def __call__(self, Token token): def __call__(self, Token token):
@ -447,6 +454,7 @@ class _SetMemberPredicate(object):
self.value = set(get_string_id(v) for v in value) self.value = set(get_string_id(v) for v in value)
self.predicate = predicate self.predicate = predicate
self.is_extension = is_extension self.is_extension = is_extension
self.key = (attr, self.predicate, srsly.json_dumps(value, sort_keys=True))
assert self.predicate in ('IN', 'NOT_IN') assert self.predicate in ('IN', 'NOT_IN')
def __call__(self, Token token): def __call__(self, Token token):
@ -459,6 +467,9 @@ class _SetMemberPredicate(object):
else: else:
return value not in self.value return value not in self.value
def __repr__(self):
return repr(('SetMemberPredicate', self.i, self.attr, self.value, self.predicate))
class _ComparisonPredicate(object): class _ComparisonPredicate(object):
def __init__(self, i, attr, value, predicate, is_extension=False): def __init__(self, i, attr, value, predicate, is_extension=False):
@ -467,6 +478,7 @@ class _ComparisonPredicate(object):
self.value = value self.value = value
self.predicate = predicate self.predicate = predicate
self.is_extension = is_extension self.is_extension = is_extension
self.key = (attr, self.predicate, srsly.json_dumps(value, sort_keys=True))
assert self.predicate in ('==', '!=', '>=', '<=', '>', '<') assert self.predicate in ('==', '!=', '>=', '<=', '>', '<')
def __call__(self, Token token): def __call__(self, Token token):
@ -488,7 +500,7 @@ class _ComparisonPredicate(object):
return value < self.value return value < self.value
def _get_extra_predicates(spec, extra_predicates, seen_predicates): def _get_extra_predicates(spec, extra_predicates):
predicate_types = { predicate_types = {
'REGEX': _RegexPredicate, 'REGEX': _RegexPredicate,
'IN': _SetMemberPredicate, 'IN': _SetMemberPredicate,
@ -499,6 +511,7 @@ def _get_extra_predicates(spec, extra_predicates, seen_predicates):
'>': _ComparisonPredicate, '>': _ComparisonPredicate,
'<': _ComparisonPredicate, '<': _ComparisonPredicate,
} }
seen_predicates = {pred.key: pred.i for pred in extra_predicates}
output = [] output = []
for attr, value in spec.items(): for attr, value in spec.items():
if isinstance(attr, basestring): if isinstance(attr, basestring):
@ -516,16 +529,15 @@ def _get_extra_predicates(spec, extra_predicates, seen_predicates):
if isinstance(value, dict): if isinstance(value, dict):
for type_, cls in predicate_types.items(): for type_, cls in predicate_types.items():
if type_ in value: if type_ in value:
key = (attr, type_, srsly.json_dumps(value[type_], sort_keys=True)) predicate = cls(len(extra_predicates), attr, value[type_], type_)
# Don't create a redundant predicates. # Don't create a redundant predicates.
# This helps with efficiency, as we're caching the results. # This helps with efficiency, as we're caching the results.
if key in seen_predicates: if predicate.key in seen_predicates:
output.append(seen_predicates[key]) output.append(seen_predicates[predicate.key])
else: else:
predicate = cls(len(extra_predicates), attr, value[type_], type_)
extra_predicates.append(predicate) extra_predicates.append(predicate)
output.append(predicate.i) output.append(predicate.i)
seen_predicates[key] = predicate.i seen_predicates[predicate.key] = predicate.i
return output return output

View File

@ -207,14 +207,13 @@ def test_matcher_set_value(en_vocab):
assert len(matches) == 0 assert len(matches) == 0
@pytest.mark.xfail
def test_matcher_set_value_operator(en_vocab): def test_matcher_set_value_operator(en_vocab):
matcher = Matcher(en_vocab) matcher = Matcher(en_vocab)
pattern = [{"ORTH": {"IN": ["a", "the"]}, "OP": "?"}, {"ORTH": "house"}] pattern = [{"ORTH": {"IN": ["a", "the"]}, "OP": "?"}, {"ORTH": "house"}]
matcher.add("DET_HOUSE", None, pattern) matcher.add("DET_HOUSE", None, pattern)
doc = Doc(en_vocab, words=["In", "a", "house"]) doc = Doc(en_vocab, words=["In", "a", "house"])
matches = matcher(doc) matches = matcher(doc)
assert len(matches) == 1 assert len(matches) == 2
doc = Doc(en_vocab, words=["my", "house"]) doc = Doc(en_vocab, words=["my", "house"])
matches = matcher(doc) matches = matcher(doc)
assert len(matches) == 1 assert len(matches) == 1

View File

@ -6,7 +6,6 @@ from spacy.matcher import Matcher
from spacy.tokens import Token, Doc from spacy.tokens import Token, Doc
@pytest.mark.xfail
def test_issue1971(en_vocab): def test_issue1971(en_vocab):
# Possibly related to #2675 and #2671? # Possibly related to #2675 and #2671?
matcher = Matcher(en_vocab) matcher = Matcher(en_vocab)
@ -22,21 +21,20 @@ def test_issue1971(en_vocab):
# We could also assert length 1 here, but this is more conclusive, because # We could also assert length 1 here, but this is more conclusive, because
# the real problem here is that it returns a duplicate match for a match_id # the real problem here is that it returns a duplicate match for a match_id
# that's not actually in the vocab! # that's not actually in the vocab!
assert all(match_id in en_vocab.strings for match_id, start, end in matcher(doc)) matches = matcher(doc)
assert all([match_id in en_vocab.strings for match_id, start, end in matches])
@pytest.mark.xfail
def test_issue_1971_2(en_vocab): def test_issue_1971_2(en_vocab):
matcher = Matcher(en_vocab) matcher = Matcher(en_vocab)
pattern1 = [{"LOWER": {"IN": ["eur"]}}, {"LIKE_NUM": True}] pattern1 = [{"ORTH": "EUR", "LOWER": {"IN": ["eur"]}}, {"LIKE_NUM": True}]
pattern2 = list(reversed(pattern1)) pattern2 = [{"LIKE_NUM": True}, {"ORTH": "EUR"}] #{"IN": ["EUR"]}}]
doc = Doc(en_vocab, words=["EUR", "10", "is", "10", "EUR"]) doc = Doc(en_vocab, words=["EUR", "10", "is", "10", "EUR"])
matcher.add("TEST", None, pattern1, pattern2) matcher.add("TEST1", None, pattern1, pattern2)
matches = matcher(doc) matches = matcher(doc)
assert len(matches) == 2 assert len(matches) == 2
@pytest.mark.xfail
def test_issue_1971_3(en_vocab): def test_issue_1971_3(en_vocab):
"""Test that pattern matches correctly for multiple extension attributes.""" """Test that pattern matches correctly for multiple extension attributes."""
Token.set_extension("a", default=1) Token.set_extension("a", default=1)
@ -50,7 +48,6 @@ def test_issue_1971_3(en_vocab):
assert matches == sorted([("A", 0, 1), ("A", 1, 2), ("B", 0, 1), ("B", 1, 2)]) assert matches == sorted([("A", 0, 1), ("A", 1, 2), ("B", 0, 1), ("B", 1, 2)])
# @pytest.mark.xfail
def test_issue_1971_4(en_vocab): def test_issue_1971_4(en_vocab):
"""Test that pattern matches correctly with multiple extension attribute """Test that pattern matches correctly with multiple extension attribute
values on a single token. values on a single token.