From ce512e1d47dd1adbacf0b7ee1d2e0cb0a5446cb1 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 15 Aug 2018 16:19:08 +0200 Subject: [PATCH] Fix #2671: Incorrect match ID on some patterns --- spacy/matcher.pyx | 14 ++++++++++++-- spacy/tests/regression/test_issue2671.py | 1 - 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/spacy/matcher.pyx b/spacy/matcher.pyx index a5ab350f0..970cb8743 100644 --- a/spacy/matcher.pyx +++ b/spacy/matcher.pyx @@ -90,6 +90,16 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc): for i in range(matches.size())] +cdef attr_t get_ent_id(const TokenPatternC* pattern) nogil: + # The code was originally designed to always have pattern[1].attrs.value + # be the ent_id when we get to the end of a pattern. However, Issue #2671 + # showed this wasn't the case when we had a reject-and-continue before a + # match. I still don't really understand what's going on here, but this + # workaround does resolve the issue. + while pattern.attrs.attr != ID and pattern.nr_attr > 0: + pattern += 1 + return pattern.attrs.value + cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& matches, const TokenC* token, const attr_t* extra_attrs) except *: @@ -115,7 +125,7 @@ cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& match states[q].length += 1 q += 1 else: - ent_id = state.pattern[1].attrs.value + ent_id = get_ent_id(&state.pattern[1]) if action == MATCH: matches.push_back( MatchC(pattern_id=ent_id, start=state.start, @@ -143,7 +153,7 @@ cdef void finish_states(vector[MatchC]& matches, vector[PatternStateC]& states) while get_quantifier(state) in (ZERO_PLUS, ZERO_ONE): is_final = get_is_final(state) if is_final: - ent_id = state.pattern[1].attrs.value + ent_id = get_ent_id(state.pattern) matches.push_back( MatchC(pattern_id=ent_id, start=state.start, length=state.length)) break diff --git a/spacy/tests/regression/test_issue2671.py b/spacy/tests/regression/test_issue2671.py index d5c62940d..1b7e04c7c 100644 --- a/spacy/tests/regression/test_issue2671.py +++ b/spacy/tests/regression/test_issue2671.py @@ -12,7 +12,6 @@ def get_rule_id(nlp, matcher, doc): return rule_id -@pytest.mark.xfail def test_issue2671(): nlp = English() matcher = Matcher(nlp.vocab)