mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 01:04:34 +03:00
* Fix matching on extension attrs and predicates * Fix detection of match_id when using extension attributes. The match ID is stored as the last entry in the pattern. We were checking for this with nr_attr == 0, which didn't account for extension attributes. * Fix handling of predicates. The wrong count was being passed through, so even patterns that didn't have a predicate were being checked. * Fix regex pattern * Fix matcher set value test
This commit is contained in:
parent
f73d01aa32
commit
0d1ca15b13
|
@ -44,7 +44,7 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None,
|
||||||
cdef Pool mem = Pool()
|
cdef Pool mem = Pool()
|
||||||
predicate_cache = <char*>mem.alloc(doc.length * len(predicates), sizeof(char))
|
predicate_cache = <char*>mem.alloc(doc.length * len(predicates), sizeof(char))
|
||||||
if extensions is not None and len(extensions) >= 1:
|
if extensions is not None and len(extensions) >= 1:
|
||||||
nr_extra_attr = max(extensions.values())
|
nr_extra_attr = max(extensions.values()) + 1
|
||||||
extra_attr_values = <attr_t*>mem.alloc(doc.length * nr_extra_attr, sizeof(attr_t))
|
extra_attr_values = <attr_t*>mem.alloc(doc.length * nr_extra_attr, sizeof(attr_t))
|
||||||
else:
|
else:
|
||||||
nr_extra_attr = 0
|
nr_extra_attr = 0
|
||||||
|
@ -60,9 +60,8 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None,
|
||||||
for i in range(doc.length):
|
for i in range(doc.length):
|
||||||
for j in range(n):
|
for j in range(n):
|
||||||
states.push_back(PatternStateC(patterns[j], i, 0))
|
states.push_back(PatternStateC(patterns[j], i, 0))
|
||||||
transition_states(states, matches, predicate_cache,
|
transition_states(states, matches, &predicate_cache[i],
|
||||||
doc[i], extra_attr_values, predicates)
|
doc[i], extra_attr_values, predicates)
|
||||||
predicate_cache += nr_predicate
|
|
||||||
extra_attr_values += nr_extra_attr
|
extra_attr_values += nr_extra_attr
|
||||||
# Handle matches that end in 0-width patterns
|
# Handle matches that end in 0-width patterns
|
||||||
finish_states(matches, states)
|
finish_states(matches, states)
|
||||||
|
@ -74,6 +73,7 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None,
|
||||||
matches[i].start,
|
matches[i].start,
|
||||||
matches[i].start+matches[i].length
|
matches[i].start+matches[i].length
|
||||||
)
|
)
|
||||||
|
|
||||||
# We need to deduplicate, because we could otherwise arrive at the same
|
# We need to deduplicate, because we could otherwise arrive at the same
|
||||||
# match through two paths, e.g. .?.? matching 'a'. Are we matching the
|
# match through two paths, e.g. .?.? matching 'a'. Are we matching the
|
||||||
# first .?, or the second .? -- it doesn't matter, it's just one match.
|
# first .?, or the second .? -- it doesn't matter, it's just one match.
|
||||||
|
@ -89,7 +89,8 @@ cdef attr_t get_ent_id(const TokenPatternC* pattern) nogil:
|
||||||
# showed this wasn't the case when we had a reject-and-continue before a
|
# showed this wasn't the case when we had a reject-and-continue before a
|
||||||
# match. I still don't really understand what's going on here, but this
|
# match. I still don't really understand what's going on here, but this
|
||||||
# workaround does resolve the issue.
|
# workaround does resolve the issue.
|
||||||
while pattern.attrs.attr != ID and pattern.nr_attr > 0:
|
while pattern.attrs.attr != ID and \
|
||||||
|
(pattern.nr_attr > 0 or pattern.nr_extra_attr > 0 or pattern.nr_py > 0):
|
||||||
pattern += 1
|
pattern += 1
|
||||||
return pattern.attrs.value
|
return pattern.attrs.value
|
||||||
|
|
||||||
|
@ -101,13 +102,17 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
|
||||||
cdef vector[PatternStateC] new_states
|
cdef vector[PatternStateC] new_states
|
||||||
cdef int nr_predicate = len(py_predicates)
|
cdef int nr_predicate = len(py_predicates)
|
||||||
for i in range(states.size()):
|
for i in range(states.size()):
|
||||||
if states[i].pattern.nr_py != 0:
|
if states[i].pattern.nr_py >= 1:
|
||||||
update_predicate_cache(cached_py_predicates,
|
update_predicate_cache(cached_py_predicates,
|
||||||
states[i].pattern, token, py_predicates)
|
states[i].pattern, token, py_predicates)
|
||||||
|
for i in range(states.size()):
|
||||||
action = get_action(states[i], token.c, extra_attrs,
|
action = get_action(states[i], token.c, extra_attrs,
|
||||||
cached_py_predicates, nr_predicate)
|
cached_py_predicates)
|
||||||
if action == REJECT:
|
if action == REJECT:
|
||||||
continue
|
continue
|
||||||
|
# Keep only a subset of states (the active ones). Index q is the
|
||||||
|
# states which are still alive. If we reject a state, we overwrite
|
||||||
|
# it in the states list, because q doesn't advance.
|
||||||
state = states[i]
|
state = states[i]
|
||||||
states[q] = state
|
states[q] = state
|
||||||
while action in (RETRY, RETRY_ADVANCE, RETRY_EXTEND):
|
while action in (RETRY, RETRY_ADVANCE, RETRY_EXTEND):
|
||||||
|
@ -126,7 +131,7 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
|
||||||
update_predicate_cache(cached_py_predicates,
|
update_predicate_cache(cached_py_predicates,
|
||||||
states[q].pattern, token, py_predicates)
|
states[q].pattern, token, py_predicates)
|
||||||
action = get_action(states[q], token.c, extra_attrs,
|
action = get_action(states[q], token.c, extra_attrs,
|
||||||
cached_py_predicates, nr_predicate)
|
cached_py_predicates)
|
||||||
if action == REJECT:
|
if action == REJECT:
|
||||||
pass
|
pass
|
||||||
elif action == ADVANCE:
|
elif action == ADVANCE:
|
||||||
|
@ -154,8 +159,8 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
|
||||||
states.push_back(new_states[i])
|
states.push_back(new_states[i])
|
||||||
|
|
||||||
|
|
||||||
cdef void update_predicate_cache(char* cache,
|
cdef int update_predicate_cache(char* cache,
|
||||||
const TokenPatternC* pattern, Token token, predicates):
|
const TokenPatternC* pattern, Token token, predicates) except -1:
|
||||||
# If the state references any extra predicates, check whether they match.
|
# If the state references any extra predicates, check whether they match.
|
||||||
# These are cached, so that we don't call these potentially expensive
|
# These are cached, so that we don't call these potentially expensive
|
||||||
# Python functions more than we need to.
|
# Python functions more than we need to.
|
||||||
|
@ -192,7 +197,7 @@ cdef void finish_states(vector[MatchC]& matches, vector[PatternStateC]& states)
|
||||||
|
|
||||||
cdef action_t get_action(PatternStateC state,
|
cdef action_t get_action(PatternStateC state,
|
||||||
const TokenC* token, const attr_t* extra_attrs,
|
const TokenC* token, const attr_t* extra_attrs,
|
||||||
const char* predicate_matches, int nr_predicate) nogil:
|
const char* predicate_matches) nogil:
|
||||||
'''We need to consider:
|
'''We need to consider:
|
||||||
|
|
||||||
a) Does the token match the specification? [Yes, No]
|
a) Does the token match the specification? [Yes, No]
|
||||||
|
@ -252,7 +257,7 @@ cdef action_t get_action(PatternStateC state,
|
||||||
Problem: If a quantifier is matching, we're adding a lot of open partials
|
Problem: If a quantifier is matching, we're adding a lot of open partials
|
||||||
'''
|
'''
|
||||||
cdef char is_match
|
cdef char is_match
|
||||||
is_match = get_is_match(state, token, extra_attrs, predicate_matches, nr_predicate)
|
is_match = get_is_match(state, token, extra_attrs, predicate_matches)
|
||||||
quantifier = get_quantifier(state)
|
quantifier = get_quantifier(state)
|
||||||
is_final = get_is_final(state)
|
is_final = get_is_final(state)
|
||||||
if quantifier == ZERO:
|
if quantifier == ZERO:
|
||||||
|
@ -303,9 +308,9 @@ cdef action_t get_action(PatternStateC state,
|
||||||
|
|
||||||
cdef char get_is_match(PatternStateC state,
|
cdef char get_is_match(PatternStateC state,
|
||||||
const TokenC* token, const attr_t* extra_attrs,
|
const TokenC* token, const attr_t* extra_attrs,
|
||||||
const char* predicate_matches, int nr_predicate) nogil:
|
const char* predicate_matches) nogil:
|
||||||
for i in range(nr_predicate):
|
for i in range(state.pattern.nr_py):
|
||||||
if predicate_matches[i] == -1:
|
if predicate_matches[state.pattern.py_predicates[i]] == -1:
|
||||||
return 0
|
return 0
|
||||||
spec = state.pattern
|
spec = state.pattern
|
||||||
for attr in spec.attrs[:spec.nr_attr]:
|
for attr in spec.attrs[:spec.nr_attr]:
|
||||||
|
@ -333,7 +338,7 @@ DEF PADDING = 5
|
||||||
|
|
||||||
cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs) except NULL:
|
cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs) except NULL:
|
||||||
pattern = <TokenPatternC*>mem.alloc(len(token_specs) + 1, sizeof(TokenPatternC))
|
pattern = <TokenPatternC*>mem.alloc(len(token_specs) + 1, sizeof(TokenPatternC))
|
||||||
cdef int i
|
cdef int i, index
|
||||||
for i, (quantifier, spec, extensions, predicates) in enumerate(token_specs):
|
for i, (quantifier, spec, extensions, predicates) in enumerate(token_specs):
|
||||||
pattern[i].quantifier = quantifier
|
pattern[i].quantifier = quantifier
|
||||||
pattern[i].attrs = <AttrValueC*>mem.alloc(len(spec), sizeof(AttrValueC))
|
pattern[i].attrs = <AttrValueC*>mem.alloc(len(spec), sizeof(AttrValueC))
|
||||||
|
@ -356,11 +361,13 @@ cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs)
|
||||||
pattern[i].attrs[0].attr = ID
|
pattern[i].attrs[0].attr = ID
|
||||||
pattern[i].attrs[0].value = entity_id
|
pattern[i].attrs[0].value = entity_id
|
||||||
pattern[i].nr_attr = 0
|
pattern[i].nr_attr = 0
|
||||||
|
pattern[i].nr_extra_attr = 0
|
||||||
|
pattern[i].nr_py = 0
|
||||||
return pattern
|
return pattern
|
||||||
|
|
||||||
|
|
||||||
cdef attr_t get_pattern_key(const TokenPatternC* pattern) nogil:
|
cdef attr_t get_pattern_key(const TokenPatternC* pattern) nogil:
|
||||||
while pattern.nr_attr != 0:
|
while pattern.nr_attr != 0 or pattern.nr_extra_attr != 0 or pattern.nr_py != 0:
|
||||||
pattern += 1
|
pattern += 1
|
||||||
id_attr = pattern[0].attrs[0]
|
id_attr = pattern[0].attrs[0]
|
||||||
if id_attr.attr != ID:
|
if id_attr.attr != ID:
|
||||||
|
@ -384,7 +391,6 @@ def _preprocess_pattern(token_specs, string_store, extensions_table, extra_predi
|
||||||
extra_predicates.
|
extra_predicates.
|
||||||
"""
|
"""
|
||||||
tokens = []
|
tokens = []
|
||||||
seen_predicates = {}
|
|
||||||
for spec in token_specs:
|
for spec in token_specs:
|
||||||
if not spec:
|
if not spec:
|
||||||
# Signifier for 'any token'
|
# Signifier for 'any token'
|
||||||
|
@ -393,7 +399,7 @@ def _preprocess_pattern(token_specs, string_store, extensions_table, extra_predi
|
||||||
ops = _get_operators(spec)
|
ops = _get_operators(spec)
|
||||||
attr_values = _get_attr_values(spec, string_store)
|
attr_values = _get_attr_values(spec, string_store)
|
||||||
extensions = _get_extensions(spec, string_store, extensions_table)
|
extensions = _get_extensions(spec, string_store, extensions_table)
|
||||||
predicates = _get_extra_predicates(spec, extra_predicates, seen_predicates)
|
predicates = _get_extra_predicates(spec, extra_predicates)
|
||||||
for op in ops:
|
for op in ops:
|
||||||
tokens.append((op, list(attr_values), list(extensions), list(predicates)))
|
tokens.append((op, list(attr_values), list(extensions), list(predicates)))
|
||||||
return tokens
|
return tokens
|
||||||
|
@ -430,6 +436,7 @@ class _RegexPredicate(object):
|
||||||
self.value = re.compile(value)
|
self.value = re.compile(value)
|
||||||
self.predicate = predicate
|
self.predicate = predicate
|
||||||
self.is_extension = is_extension
|
self.is_extension = is_extension
|
||||||
|
self.key = (attr, self.predicate, srsly.json_dumps(value, sort_keys=True))
|
||||||
assert self.predicate == 'REGEX'
|
assert self.predicate == 'REGEX'
|
||||||
|
|
||||||
def __call__(self, Token token):
|
def __call__(self, Token token):
|
||||||
|
@ -447,6 +454,7 @@ class _SetMemberPredicate(object):
|
||||||
self.value = set(get_string_id(v) for v in value)
|
self.value = set(get_string_id(v) for v in value)
|
||||||
self.predicate = predicate
|
self.predicate = predicate
|
||||||
self.is_extension = is_extension
|
self.is_extension = is_extension
|
||||||
|
self.key = (attr, self.predicate, srsly.json_dumps(value, sort_keys=True))
|
||||||
assert self.predicate in ('IN', 'NOT_IN')
|
assert self.predicate in ('IN', 'NOT_IN')
|
||||||
|
|
||||||
def __call__(self, Token token):
|
def __call__(self, Token token):
|
||||||
|
@ -459,6 +467,9 @@ class _SetMemberPredicate(object):
|
||||||
else:
|
else:
|
||||||
return value not in self.value
|
return value not in self.value
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return repr(('SetMemberPredicate', self.i, self.attr, self.value, self.predicate))
|
||||||
|
|
||||||
|
|
||||||
class _ComparisonPredicate(object):
|
class _ComparisonPredicate(object):
|
||||||
def __init__(self, i, attr, value, predicate, is_extension=False):
|
def __init__(self, i, attr, value, predicate, is_extension=False):
|
||||||
|
@ -467,6 +478,7 @@ class _ComparisonPredicate(object):
|
||||||
self.value = value
|
self.value = value
|
||||||
self.predicate = predicate
|
self.predicate = predicate
|
||||||
self.is_extension = is_extension
|
self.is_extension = is_extension
|
||||||
|
self.key = (attr, self.predicate, srsly.json_dumps(value, sort_keys=True))
|
||||||
assert self.predicate in ('==', '!=', '>=', '<=', '>', '<')
|
assert self.predicate in ('==', '!=', '>=', '<=', '>', '<')
|
||||||
|
|
||||||
def __call__(self, Token token):
|
def __call__(self, Token token):
|
||||||
|
@ -488,7 +500,7 @@ class _ComparisonPredicate(object):
|
||||||
return value < self.value
|
return value < self.value
|
||||||
|
|
||||||
|
|
||||||
def _get_extra_predicates(spec, extra_predicates, seen_predicates):
|
def _get_extra_predicates(spec, extra_predicates):
|
||||||
predicate_types = {
|
predicate_types = {
|
||||||
'REGEX': _RegexPredicate,
|
'REGEX': _RegexPredicate,
|
||||||
'IN': _SetMemberPredicate,
|
'IN': _SetMemberPredicate,
|
||||||
|
@ -499,6 +511,7 @@ def _get_extra_predicates(spec, extra_predicates, seen_predicates):
|
||||||
'>': _ComparisonPredicate,
|
'>': _ComparisonPredicate,
|
||||||
'<': _ComparisonPredicate,
|
'<': _ComparisonPredicate,
|
||||||
}
|
}
|
||||||
|
seen_predicates = {pred.key: pred.i for pred in extra_predicates}
|
||||||
output = []
|
output = []
|
||||||
for attr, value in spec.items():
|
for attr, value in spec.items():
|
||||||
if isinstance(attr, basestring):
|
if isinstance(attr, basestring):
|
||||||
|
@ -516,16 +529,15 @@ def _get_extra_predicates(spec, extra_predicates, seen_predicates):
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
for type_, cls in predicate_types.items():
|
for type_, cls in predicate_types.items():
|
||||||
if type_ in value:
|
if type_ in value:
|
||||||
key = (attr, type_, srsly.json_dumps(value[type_], sort_keys=True))
|
predicate = cls(len(extra_predicates), attr, value[type_], type_)
|
||||||
# Don't create a redundant predicates.
|
# Don't create a redundant predicates.
|
||||||
# This helps with efficiency, as we're caching the results.
|
# This helps with efficiency, as we're caching the results.
|
||||||
if key in seen_predicates:
|
if predicate.key in seen_predicates:
|
||||||
output.append(seen_predicates[key])
|
output.append(seen_predicates[predicate.key])
|
||||||
else:
|
else:
|
||||||
predicate = cls(len(extra_predicates), attr, value[type_], type_)
|
|
||||||
extra_predicates.append(predicate)
|
extra_predicates.append(predicate)
|
||||||
output.append(predicate.i)
|
output.append(predicate.i)
|
||||||
seen_predicates[key] = predicate.i
|
seen_predicates[predicate.key] = predicate.i
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -207,14 +207,13 @@ def test_matcher_set_value(en_vocab):
|
||||||
assert len(matches) == 0
|
assert len(matches) == 0
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.xfail
|
|
||||||
def test_matcher_set_value_operator(en_vocab):
|
def test_matcher_set_value_operator(en_vocab):
|
||||||
matcher = Matcher(en_vocab)
|
matcher = Matcher(en_vocab)
|
||||||
pattern = [{"ORTH": {"IN": ["a", "the"]}, "OP": "?"}, {"ORTH": "house"}]
|
pattern = [{"ORTH": {"IN": ["a", "the"]}, "OP": "?"}, {"ORTH": "house"}]
|
||||||
matcher.add("DET_HOUSE", None, pattern)
|
matcher.add("DET_HOUSE", None, pattern)
|
||||||
doc = Doc(en_vocab, words=["In", "a", "house"])
|
doc = Doc(en_vocab, words=["In", "a", "house"])
|
||||||
matches = matcher(doc)
|
matches = matcher(doc)
|
||||||
assert len(matches) == 1
|
assert len(matches) == 2
|
||||||
doc = Doc(en_vocab, words=["my", "house"])
|
doc = Doc(en_vocab, words=["my", "house"])
|
||||||
matches = matcher(doc)
|
matches = matcher(doc)
|
||||||
assert len(matches) == 1
|
assert len(matches) == 1
|
||||||
|
|
|
@ -6,7 +6,6 @@ from spacy.matcher import Matcher
|
||||||
from spacy.tokens import Token, Doc
|
from spacy.tokens import Token, Doc
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.xfail
|
|
||||||
def test_issue1971(en_vocab):
|
def test_issue1971(en_vocab):
|
||||||
# Possibly related to #2675 and #2671?
|
# Possibly related to #2675 and #2671?
|
||||||
matcher = Matcher(en_vocab)
|
matcher = Matcher(en_vocab)
|
||||||
|
@ -22,21 +21,20 @@ def test_issue1971(en_vocab):
|
||||||
# We could also assert length 1 here, but this is more conclusive, because
|
# We could also assert length 1 here, but this is more conclusive, because
|
||||||
# the real problem here is that it returns a duplicate match for a match_id
|
# the real problem here is that it returns a duplicate match for a match_id
|
||||||
# that's not actually in the vocab!
|
# that's not actually in the vocab!
|
||||||
assert all(match_id in en_vocab.strings for match_id, start, end in matcher(doc))
|
matches = matcher(doc)
|
||||||
|
assert all([match_id in en_vocab.strings for match_id, start, end in matches])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.xfail
|
|
||||||
def test_issue_1971_2(en_vocab):
|
def test_issue_1971_2(en_vocab):
|
||||||
matcher = Matcher(en_vocab)
|
matcher = Matcher(en_vocab)
|
||||||
pattern1 = [{"LOWER": {"IN": ["eur"]}}, {"LIKE_NUM": True}]
|
pattern1 = [{"ORTH": "EUR", "LOWER": {"IN": ["eur"]}}, {"LIKE_NUM": True}]
|
||||||
pattern2 = list(reversed(pattern1))
|
pattern2 = [{"LIKE_NUM": True}, {"ORTH": "EUR"}] #{"IN": ["EUR"]}}]
|
||||||
doc = Doc(en_vocab, words=["EUR", "10", "is", "10", "EUR"])
|
doc = Doc(en_vocab, words=["EUR", "10", "is", "10", "EUR"])
|
||||||
matcher.add("TEST", None, pattern1, pattern2)
|
matcher.add("TEST1", None, pattern1, pattern2)
|
||||||
matches = matcher(doc)
|
matches = matcher(doc)
|
||||||
assert len(matches) == 2
|
assert len(matches) == 2
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.xfail
|
|
||||||
def test_issue_1971_3(en_vocab):
|
def test_issue_1971_3(en_vocab):
|
||||||
"""Test that pattern matches correctly for multiple extension attributes."""
|
"""Test that pattern matches correctly for multiple extension attributes."""
|
||||||
Token.set_extension("a", default=1)
|
Token.set_extension("a", default=1)
|
||||||
|
@ -50,7 +48,6 @@ def test_issue_1971_3(en_vocab):
|
||||||
assert matches == sorted([("A", 0, 1), ("A", 1, 2), ("B", 0, 1), ("B", 1, 2)])
|
assert matches == sorted([("A", 0, 1), ("A", 1, 2), ("B", 0, 1), ("B", 1, 2)])
|
||||||
|
|
||||||
|
|
||||||
# @pytest.mark.xfail
|
|
||||||
def test_issue_1971_4(en_vocab):
|
def test_issue_1971_4(en_vocab):
|
||||||
"""Test that pattern matches correctly with multiple extension attribute
|
"""Test that pattern matches correctly with multiple extension attribute
|
||||||
values on a single token.
|
values on a single token.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user