Fix #2671: Incorrect match ID on some patterns

This commit is contained in:
Matthew Honnibal 2018-08-15 16:19:08 +02:00
parent f12b9190f6
commit ce512e1d47
2 changed files with 12 additions and 3 deletions

View File

@ -90,6 +90,16 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc):
for i in range(matches.size())] for i in range(matches.size())]
cdef attr_t get_ent_id(const TokenPatternC* pattern) nogil:
# The code was originally designed to always have pattern[1].attrs.value
# be the ent_id when we get to the end of a pattern. However, Issue #2671
# showed this wasn't the case when we had a reject-and-continue before a
# match. I still don't really understand what's going on here, but this
# workaround does resolve the issue.
while pattern.attrs.attr != ID and pattern.nr_attr > 0:
pattern += 1
return pattern.attrs.value
cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& matches, cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& matches,
const TokenC* token, const attr_t* extra_attrs) except *: const TokenC* token, const attr_t* extra_attrs) except *:
@ -115,7 +125,7 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match
states[q].length += 1 states[q].length += 1
q += 1 q += 1
else: else:
ent_id = state.pattern[1].attrs.value ent_id = get_ent_id(&state.pattern[1])
if action == MATCH: if action == MATCH:
matches.push_back( matches.push_back(
MatchC(pattern_id=ent_id, start=state.start, MatchC(pattern_id=ent_id, start=state.start,
@ -143,7 +153,7 @@ cdef void finish_states(vector[MatchC]& matches, vector[PatternStateC]& states)
while get_quantifier(state) in (ZERO_PLUS, ZERO_ONE): while get_quantifier(state) in (ZERO_PLUS, ZERO_ONE):
is_final = get_is_final(state) is_final = get_is_final(state)
if is_final: if is_final:
ent_id = state.pattern[1].attrs.value ent_id = get_ent_id(state.pattern)
matches.push_back( matches.push_back(
MatchC(pattern_id=ent_id, start=state.start, length=state.length)) MatchC(pattern_id=ent_id, start=state.start, length=state.length))
break break

View File

@ -12,7 +12,6 @@ def get_rule_id(nlp, matcher, doc):
return rule_id return rule_id
@pytest.mark.xfail
def test_issue2671(): def test_issue2671():
nlp = English() nlp = English()
matcher = Matcher(nlp.vocab) matcher = Matcher(nlp.vocab)