diff --git a/spacy/matcher2.pyx b/spacy/matcher2.pyx index 98ac92b84..35f6eecf8 100644 --- a/spacy/matcher2.pyx +++ b/spacy/matcher2.pyx @@ -1,6 +1,7 @@ # cython: infer_types=True +# cython: profile=True from libcpp.vector cimport vector -from libc.stdint cimport int32_t, uint64_t +from libc.stdint cimport int32_t, uint64_t, uint16_t from preshed.maps cimport PreshMap from cymem.cymem cimport Pool from murmurhash.mrmr cimport hash64 @@ -41,6 +42,15 @@ from .attrs import FLAG36 as L9_ENT from .attrs import FLAG35 as L10_ENT +cdef enum action_t: + REJECT = 0000 + MATCH = 1000 + ADVANCE = 0100 + RETRY = 0010 + RETRY_EXTEND = 0011 + MATCH_EXTEND = 1001 + MATCH_REJECT = 2000 + cdef enum quantifier_t: ZERO @@ -82,39 +92,18 @@ cdef struct MatchC: cdef find_matches(TokenPatternC** patterns, int n, Doc doc): - cdef vector[PatternStateC] init_states - cdef ActionC null_action = ActionC(-1, -1, -1, -1) - for i in range(n): - init_states.push_back(PatternStateC(patterns[i], -1, 0)) - cdef vector[PatternStateC] curr_states - cdef vector[PatternStateC] nexts + cdef vector[PatternStateC] states cdef vector[MatchC] matches - cdef PreshMap cache cdef Pool mem = Pool() # TODO: Prefill this with the extra attribute values. extra_attrs = mem.alloc(len(doc), sizeof(attr_t*)) + # Main loop for i in range(doc.length): - nexts.clear() - cache = PreshMap() - for j in range(curr_states.size()): - transition(matches, nexts, - curr_states[j], i, &doc.c[i], extra_attrs[i], cache) - for j in range(init_states.size()): - transition(matches, nexts, - init_states[j], i, &doc.c[i], extra_attrs[i], cache) - nexts, curr_states = curr_states, nexts - # Handle patterns that end with zero-width - for j in range(curr_states.size()): - state = curr_states[j] - while get_quantifier(state) in (ZERO_PLUS, ZERO_ONE): - is_final = get_is_final(state) - if is_final: - ent_id = state.pattern[1].attrs.value - matches.push_back( - MatchC(pattern_id=ent_id, start=state.start, length=state.length)) - break - else: - state.pattern += 1 + for j in range(n): + states.push_back(PatternStateC(patterns[j], i, 0)) + transition_states(states, matches, &doc.c[i], extra_attrs[i]) + # Handle matches that end in 0-width patterns + finish_states(matches, states) # Filter out matches that have a longer equivalent. longest_matches = {} for i in range(matches.size()): @@ -126,37 +115,67 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc): for (pattern_id, start), length in longest_matches.items()] -cdef void transition(vector[MatchC]& matches, vector[PatternStateC]& nexts, - PatternStateC state, int i, const TokenC* token, const attr_t* extra_attrs, - PreshMap cache) except *: - action = get_action(state, token, extra_attrs, cache) - if state.start == -1: - state.start = i - if action.emit_match == 1: - ent_id = state.pattern[1].attrs.value - matches.push_back( - MatchC(pattern_id=ent_id, start=state.start, length=state.length+1)) - elif action.emit_match == 2: - ent_id = state.pattern[1].attrs.value - matches.push_back( - MatchC(pattern_id=ent_id, start=state.start, length=state.length)) - if action.next_state_next_token: - nexts.push_back(PatternStateC(start=state.start, - pattern=&state.pattern[1], length=state.length+1)) - if action.same_state_next_token: - nexts.push_back(PatternStateC(start=state.start, - pattern=state.pattern, length=state.length+1)) - cdef PatternStateC next_state - if action.next_state_same_token: - # 0+ and ? non-matches need to not consume a token, so we call transition - # with the same state - next_state = PatternStateC(start=state.start, pattern=&state.pattern[1], - length=state.length) - transition(matches, nexts, next_state, i, token, extra_attrs, cache) +cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& matches, + const TokenC* token, const attr_t* extra_attrs) except *: + cdef int q = 0 + cdef vector[PatternStateC] new_states + for i in range(states.size()): + action = get_action(states[i], token, extra_attrs) + if action == REJECT: + continue + state = states[i] + states[q] = state + while action in (RETRY, RETRY_EXTEND): + if action == RETRY_EXTEND: + new_states.push_back( + PatternStateC(pattern=state.pattern, start=state.start, + length=state.length+1)) + states[q].pattern += 1 + action = get_action(states[q], token, extra_attrs) + if action == REJECT: + pass + elif action == ADVANCE: + states[q].pattern += 1 + states[q].length += 1 + q += 1 + else: + ent_id = state.pattern[1].attrs.value + if action == MATCH: + matches.push_back( + MatchC(pattern_id=ent_id, start=state.start, + length=state.length+1)) + elif action == MATCH_REJECT: + matches.push_back( + MatchC(pattern_id=ent_id, start=state.start, + length=state.length)) + elif action == MATCH_EXTEND: + matches.push_back( + MatchC(pattern_id=ent_id, start=state.start, + length=state.length)) + states[q].length += 1 + q += 1 + states.resize(q) + for i in range(new_states.size()): + states.push_back(new_states[i]) -cdef ActionC get_action(PatternStateC state, const TokenC* token, const attr_t* extra_attrs, - PreshMap cache) except *: +cdef void finish_states(vector[MatchC]& matches, vector[PatternStateC]& states) except *: + '''Handle states that end in zero-width patterns.''' + cdef PatternStateC state + for i in range(states.size()): + state = states[i] + while get_quantifier(state) in (ZERO_PLUS, ZERO_ONE): + is_final = get_is_final(state) + if is_final: + ent_id = state.pattern[1].attrs.value + matches.push_back( + MatchC(pattern_id=ent_id, start=state.start, length=state.length)) + break + else: + state.pattern += 1 + + +cdef action_t get_action(PatternStateC state, const TokenC* token, const attr_t* extra_attrs) except *: '''We need to consider: a) Does the token match the specification? [Yes, No] @@ -201,18 +220,21 @@ cdef ActionC get_action(PatternStateC state, const TokenC* token, const attr_t* No, non-final: 0010 + Possible combinations: 1000, 0100, 0000, 1001, 0011, 0010, + + We'll name the bits "match", "advance", "retry", "extend" + REJECT = 0000 + MATCH = 1000 + ADVANCE = 0100 + RETRY = 0010 + MATCH_EXTEND = 1001 + RETRY_EXTEND = 0011 + MATCH_REJECT = 2000 # Match, but don't include last token + Problem: If a quantifier is matching, we're adding a lot of open partials ''' - #cached_match = cache.get(state.pattern.key) cdef char is_match - #if cached_match == 0: is_match = get_is_match(state, token, extra_attrs) - # cached_match = is_match + 1 - # cache.set(state.pattern.key, cached_match) - #elif cached_match == 1: - # is_match = 0 - #else: - # is_match = 1 quantifier = get_quantifier(state) is_final = get_is_final(state) if quantifier == ZERO: @@ -221,46 +243,41 @@ cdef ActionC get_action(PatternStateC state, const TokenC* token, const attr_t* if quantifier == ONE: if is_match and is_final: # Yes, final: 1000 - return ActionC(1, 0, 0, 0) + return MATCH elif is_match and not is_final: # Yes, non-final: 0100 - return ActionC(0, 1, 0, 0) + return ADVANCE elif not is_match and is_final: # No, final: 0000 - return ActionC(0, 0, 0, 0) + return REJECT else: - # No, non-final 0000 - return ActionC(0, 0, 0, 0) - + return REJECT elif quantifier == ZERO_PLUS: if is_match and is_final: # Yes, final: 1001 - return ActionC(1, 0, 0, 1) + return MATCH_EXTEND elif is_match and not is_final: # Yes, non-final: 0011 - return ActionC(0, 0, 1, 1) + return RETRY_EXTEND elif not is_match and is_final: - # No, final 1000 (note: Don't include last token!) - return ActionC(2, 0, 0, 0) + # No, final 2000 (note: Don't include last token!) + return MATCH_REJECT else: # No, non-final 0010 - return ActionC(0, 0, 1, 0) + return RETRY elif quantifier == ZERO_ONE: if is_match and is_final: # Yes, final: 1000 - return ActionC(1, 0, 0, 0) + return MATCH elif is_match and not is_final: # Yes, non-final: 0100 - return ActionC(0, 1, 0, 0) + return ADVANCE elif not is_match and is_final: - # No, final 1000 (note: Don't include last token!) - return ActionC(2, 0, 0, 0) + # No, final 2000 (note: Don't include last token!) + return MATCH_REJECT else: # No, non-final 0010 - return ActionC(0, 0, 1, 0) - else: - print(quantifier, is_match, is_final) - raise ValueError + return RETRY cdef char get_is_match(PatternStateC state, const TokenC* token, const attr_t* extra_attrs) nogil: