mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 09:14:32 +03:00
Fix Issue #1450: Off-by-1 in * and ? matches
Patterns that end in a variable-length operator (e.g. * or ?) now end on the correct token. Previously, the match end was off by one: the next token was pulled into the match even when it was the token on which the pattern failed.
This commit is contained in:
parent
391d5ef0d1
commit
4bea65a1a8
|
@ -69,6 +69,7 @@ cdef enum action_t:
|
||||||
REPEAT
|
REPEAT
|
||||||
ACCEPT
|
ACCEPT
|
||||||
ADVANCE_ZERO
|
ADVANCE_ZERO
|
||||||
|
ACCEPT_PREV
|
||||||
PANIC
|
PANIC
|
||||||
|
|
||||||
# A "match expression" consists of one or more token patterns
|
# A "match expression" consists of one or more token patterns
|
||||||
|
@ -120,24 +121,27 @@ cdef attr_t get_pattern_key(const TokenPatternC* pattern) except 0:
|
||||||
|
|
||||||
|
|
||||||
cdef int get_action(const TokenPatternC* pattern, const TokenC* token) nogil:
|
cdef int get_action(const TokenPatternC* pattern, const TokenC* token) nogil:
|
||||||
|
lookahead = &pattern[1]
|
||||||
for attr in pattern.attrs[:pattern.nr_attr]:
|
for attr in pattern.attrs[:pattern.nr_attr]:
|
||||||
if get_token_attr(token, attr.attr) != attr.value:
|
if get_token_attr(token, attr.attr) != attr.value:
|
||||||
if pattern.quantifier == ONE:
|
if pattern.quantifier == ONE:
|
||||||
return REJECT
|
return REJECT
|
||||||
elif pattern.quantifier == ZERO:
|
elif pattern.quantifier == ZERO:
|
||||||
return ACCEPT if (pattern+1).nr_attr == 0 else ADVANCE
|
return ACCEPT if lookahead.nr_attr == 0 else ADVANCE
|
||||||
elif pattern.quantifier in (ZERO_ONE, ZERO_PLUS):
|
elif pattern.quantifier in (ZERO_ONE, ZERO_PLUS):
|
||||||
return ACCEPT if (pattern+1).nr_attr == 0 else ADVANCE_ZERO
|
return ACCEPT_PREV if lookahead.nr_attr == 0 else ADVANCE_ZERO
|
||||||
else:
|
else:
|
||||||
return PANIC
|
return PANIC
|
||||||
if pattern.quantifier == ZERO:
|
if pattern.quantifier == ZERO:
|
||||||
return REJECT
|
return REJECT
|
||||||
|
elif lookahead.nr_attr == 0:
|
||||||
|
return ACCEPT
|
||||||
elif pattern.quantifier in (ONE, ZERO_ONE):
|
elif pattern.quantifier in (ONE, ZERO_ONE):
|
||||||
return ACCEPT if (pattern+1).nr_attr == 0 else ADVANCE
|
return ADVANCE
|
||||||
elif pattern.quantifier == ZERO_PLUS:
|
elif pattern.quantifier == ZERO_PLUS:
|
||||||
# This is a bandaid over the 'shadowing' problem described here:
|
# This is a bandaid over the 'shadowing' problem described here:
|
||||||
# https://github.com/explosion/spaCy/issues/864
|
# https://github.com/explosion/spaCy/issues/864
|
||||||
next_action = get_action(pattern+1, token)
|
next_action = get_action(lookahead, token)
|
||||||
if next_action is REJECT:
|
if next_action is REJECT:
|
||||||
return REPEAT
|
return REPEAT
|
||||||
else:
|
else:
|
||||||
|
@ -345,6 +349,9 @@ cdef class Matcher:
|
||||||
while action == ADVANCE_ZERO:
|
while action == ADVANCE_ZERO:
|
||||||
state.second += 1
|
state.second += 1
|
||||||
action = get_action(state.second, token)
|
action = get_action(state.second, token)
|
||||||
|
if action == PANIC:
|
||||||
|
raise Exception("Error selecting action in matcher")
|
||||||
|
|
||||||
if action == REPEAT:
|
if action == REPEAT:
|
||||||
# Leave the state in the queue, and advance to next slot
|
# Leave the state in the queue, and advance to next slot
|
||||||
# (i.e. we don't overwrite -- we want to greedily match more
|
# (i.e. we don't overwrite -- we want to greedily match more
|
||||||
|
@ -356,14 +363,15 @@ cdef class Matcher:
|
||||||
partials[q] = state
|
partials[q] = state
|
||||||
partials[q].second += 1
|
partials[q].second += 1
|
||||||
q += 1
|
q += 1
|
||||||
elif action == ACCEPT:
|
elif action in (ACCEPT, ACCEPT_PREV):
|
||||||
# TODO: What to do about patterns starting with ZERO? Need to
|
# TODO: What to do about patterns starting with ZERO? Need to
|
||||||
# adjust the start position.
|
# adjust the start position.
|
||||||
start = state.first
|
start = state.first
|
||||||
end = token_i+1
|
end = token_i+1 if action == ACCEPT else token_i
|
||||||
ent_id = state.second[1].attrs[0].value
|
ent_id = state.second[1].attrs[0].value
|
||||||
label = state.second[1].attrs[1].value
|
label = state.second[1].attrs[1].value
|
||||||
matches.append((ent_id, start, end))
|
matches.append((ent_id, start, end))
|
||||||
|
|
||||||
partials.resize(q)
|
partials.resize(q)
|
||||||
# Check whether we open any new patterns on this token
|
# Check whether we open any new patterns on this token
|
||||||
for pattern in self.patterns:
|
for pattern in self.patterns:
|
||||||
|
@ -383,9 +391,9 @@ cdef class Matcher:
|
||||||
state.first = token_i
|
state.first = token_i
|
||||||
state.second = pattern + 1
|
state.second = pattern + 1
|
||||||
partials.push_back(state)
|
partials.push_back(state)
|
||||||
elif action == ACCEPT:
|
elif action in (ACCEPT, ACCEPT_PREV):
|
||||||
start = token_i
|
start = token_i
|
||||||
end = token_i+1
|
end = token_i+1 if action == ACCEPT else token_i
|
||||||
ent_id = pattern[1].attrs[0].value
|
ent_id = pattern[1].attrs[0].value
|
||||||
label = pattern[1].attrs[1].value
|
label = pattern[1].attrs[1].value
|
||||||
matches.append((ent_id, start, end))
|
matches.append((ent_id, start, end))
|
||||||
|
|
58
spacy/tests/regression/test_issue1450.py
Normal file
58
spacy/tests/regression/test_issue1450.py
Normal file
|
@ -0,0 +1,58 @@
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from ...matcher import Matcher
|
||||||
|
from ...tokens import Doc
|
||||||
|
from ...vocab import Vocab
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
'string,start,end',
|
||||||
|
[
|
||||||
|
('a', 0, 1),
|
||||||
|
('a b', 0, 2),
|
||||||
|
('a c', 0, 1),
|
||||||
|
('a b c', 0, 2),
|
||||||
|
('a b b c', 0, 2),
|
||||||
|
('a b b', 0, 2),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
def test_issue1450_matcher_end_zero_plus(string, start, end):
|
||||||
|
'''Test matcher works when patterns end with * operator.
|
||||||
|
|
||||||
|
Original example (rewritten to avoid model usage)
|
||||||
|
|
||||||
|
nlp = spacy.load('en_core_web_sm')
|
||||||
|
matcher = Matcher(nlp.vocab)
|
||||||
|
matcher.add(
|
||||||
|
"TSTEND",
|
||||||
|
on_match_1,
|
||||||
|
[
|
||||||
|
{TAG: "JJ", LOWER: "new"},
|
||||||
|
{TAG: "NN", 'OP': "*"}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
doc = nlp(u'Could you create a new ticket for me?')
|
||||||
|
print([(w.tag_, w.text, w.lower_) for w in doc])
|
||||||
|
matches = matcher(doc)
|
||||||
|
print(matches)
|
||||||
|
assert len(matches) == 1
|
||||||
|
assert matches[0][1] == 4
|
||||||
|
assert matches[0][2] == 5
|
||||||
|
'''
|
||||||
|
matcher = Matcher(Vocab())
|
||||||
|
matcher.add(
|
||||||
|
"TSTEND",
|
||||||
|
None,
|
||||||
|
[
|
||||||
|
{'ORTH': "a"},
|
||||||
|
{'ORTH': "b", 'OP': "*"}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
doc = Doc(Vocab(), words=string.split())
|
||||||
|
matches = matcher(doc)
|
||||||
|
if start is None or end is None:
|
||||||
|
assert matches == []
|
||||||
|
|
||||||
|
assert matches[0][1] == start
|
||||||
|
assert matches[0][2] == end
|
|
@ -3,6 +3,7 @@ from __future__ import unicode_literals
|
||||||
|
|
||||||
from ..matcher import Matcher, PhraseMatcher
|
from ..matcher import Matcher, PhraseMatcher
|
||||||
from .util import get_doc
|
from .util import get_doc
|
||||||
|
from ..tokens import Doc
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
@ -212,3 +213,24 @@ def test_operator_combos(matcher):
|
||||||
assert matches, (string, pattern_str)
|
assert matches, (string, pattern_str)
|
||||||
else:
|
else:
|
||||||
assert not matches, (string, pattern_str)
|
assert not matches, (string, pattern_str)
|
||||||
|
|
||||||
|
|
||||||
|
def test_matcher_end_zero_plus(matcher):
|
||||||
|
'''Test matcher works when patterns end with * operator. (issue 1450)'''
|
||||||
|
matcher = Matcher(matcher.vocab)
|
||||||
|
matcher.add(
|
||||||
|
"TSTEND",
|
||||||
|
None,
|
||||||
|
[
|
||||||
|
{'ORTH': "a"},
|
||||||
|
{'ORTH': "b", 'OP': "*"}
|
||||||
|
]
|
||||||
|
)
|
||||||
|
nlp = lambda string: Doc(matcher.vocab, words=string.split())
|
||||||
|
assert len(matcher(nlp(u'a'))) == 1
|
||||||
|
assert len(matcher(nlp(u'a b'))) == 1
|
||||||
|
assert len(matcher(nlp(u'a b'))) == 1
|
||||||
|
assert len(matcher(nlp(u'a c'))) == 1
|
||||||
|
assert len(matcher(nlp(u'a b c'))) == 1
|
||||||
|
assert len(matcher(nlp(u'a b b c'))) == 1
|
||||||
|
assert len(matcher(nlp(u'a b b'))) == 1
|
||||||
|
|
Loading…
Reference in New Issue
Block a user