diff --git a/spacy/matcher.pyx b/spacy/matcher.pyx index 738cd8f5d..dd8e0b55c 100644 --- a/spacy/matcher.pyx +++ b/spacy/matcher.pyx @@ -8,9 +8,13 @@ from cymem.cymem cimport Pool from preshed.maps cimport PreshMap from libcpp.vector cimport vector from libcpp.pair cimport pair +from libcpp.unordered_map cimport unordered_map as umap +from cython.operator cimport dereference as deref from murmurhash.mrmr cimport hash64 from libc.stdint cimport int32_t +from libc.stdio cimport printf + from .typedefs cimport attr_t from .typedefs cimport hash_t from .structs cimport TokenC @@ -85,6 +89,11 @@ cdef struct TokenPatternC: ctypedef TokenPatternC* TokenPatternC_ptr ctypedef pair[int, TokenPatternC_ptr] StateC +# Match Dictionary entry type +cdef struct MatchEntryC: + int32_t start + int32_t end + int32_t offset cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs) except NULL: @@ -336,8 +345,11 @@ cdef class Matcher: cdef int j = 0 cdef int k cdef bint add_match,overlap = False + cdef TokenPatternC_ptr final_state + cdef umap[TokenPatternC_ptr,MatchEntryC] matches_dict + cdef umap[TokenPatternC_ptr,MatchEntryC].iterator state_match + cdef MatchEntryC new_match matches = [] - matches_dict = {} for token_i in range(doc.length): token = &doc.c[token_i] q = 0 @@ -350,8 +362,18 @@ cdef class Matcher: action = get_action(state.second, token) j += 1 # Skip patterns that would overlap with an existing match - ent_id = get_pattern_key(state.second) - if ent_id in matches_dict and state.first>matches_dict[ent_id][0] and state.firstderef(state_match).second.start + and state.first= matches_dict[ent_id][1]: - matches_dict[ent_id] = (start,end,len(matches)) + elif start >= deref(state_match).second.end: + new_match.start = start + new_match.end = end + new_match.offset = len(matches) + matches_dict[final_state] = new_match matches.append((ent_id,start,end)) - elif start <= matches_dict[ent_id][0] and end>=matches_dict[ent_id][1]: - i = matches_dict[ent_id][2] + elif start <= deref(state_match).second.start and end>=deref(state_match).second.end: + i = deref(state_match).second.offset matches[i] = (ent_id,start,end) - matches_dict[ent_id] = (start,end,i) + new_match.start = start + new_match.end = end + new_match.offset = i + matches_dict[final_state] = new_match else: pass @@ -438,7 +471,13 @@ cdef class Matcher: for pattern in self.patterns: # Skip patterns that would overlap with an existing match ent_id = get_pattern_key(pattern) - if ent_id in matches_dict and token_i>matches_dict[ent_id][0] and token_ideref(state_match).second.start + and token_i= matches_dict[ent_id][1]: - matches_dict[ent_id] = (start,end,len(matches)) + elif start >= deref(state_match).second.end: + new_match.start = start + new_match.end = end + new_match.offset = len(matches) + matches_dict[final_state] = new_match matches.append((ent_id,start,end)) - elif start <= matches_dict[ent_id][0] and end>=matches_dict[ent_id][1]: - j = matches_dict[ent_id][2] + elif start <= deref(state_match).second.start and end>=deref(state_match).second.end: + j = deref(state_match).second.offset matches[j] = (ent_id,start,end) - matches_dict[ent_id] = (start,end,j) + new_match.start = start + new_match.end = end + new_match.offset = j + matches_dict[final_state] = new_match else: pass @@ -503,16 +554,27 @@ cdef class Matcher: end = len(doc) ent_id = state.second.attrs[0].value label = state.second.attrs[1].value - if ent_id not in matches_dict: - matches_dict[ent_id] = (start,end,len(matches)) + final_state = state.second + state_match = matches_dict.find(final_state) + if state_match == matches_dict.end(): + new_match.start = start + new_match.end = end + new_match.offset = len(matches) + matches_dict[final_state] = new_match matches.append((ent_id,start,end)) - elif start >= matches_dict[ent_id][1]: - matches_dict[ent_id] = (start,end,len(matches)) + elif start >= deref(state_match).second.end: + new_match.start = start + new_match.end = end + new_match.offset = len(matches) + matches_dict[final_state] = new_match matches.append((ent_id,start,end)) - elif start <= matches_dict[ent_id][0] and end>=matches_dict[ent_id][1]: - j = matches_dict[ent_id][2] + elif start <= deref(state_match).second.start and end>=deref(state_match).second.end: + j = deref(state_match).second.offset matches[j] = (ent_id,start,end) - matches_dict[ent_id] = (start,end,j) + new_match.start = start + new_match.end = end + new_match.offset = j + matches_dict[final_state] = new_match else: pass for i, (ent_id, start, end) in enumerate(matches):