mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-13 05:07:03 +03:00
Refactor matcher2, hopefully making it faster
This commit is contained in:
parent
00261eea27
commit
7885b92b45
|
@ -1,6 +1,7 @@
|
||||||
# cython: infer_types=True
|
# cython: infer_types=True
|
||||||
|
# cython: profile=True
|
||||||
from libcpp.vector cimport vector
|
from libcpp.vector cimport vector
|
||||||
from libc.stdint cimport int32_t, uint64_t
|
from libc.stdint cimport int32_t, uint64_t, uint16_t
|
||||||
from preshed.maps cimport PreshMap
|
from preshed.maps cimport PreshMap
|
||||||
from cymem.cymem cimport Pool
|
from cymem.cymem cimport Pool
|
||||||
from murmurhash.mrmr cimport hash64
|
from murmurhash.mrmr cimport hash64
|
||||||
|
@ -41,6 +42,15 @@ from .attrs import FLAG36 as L9_ENT
|
||||||
from .attrs import FLAG35 as L10_ENT
|
from .attrs import FLAG35 as L10_ENT
|
||||||
|
|
||||||
|
|
||||||
|
cdef enum action_t:
|
||||||
|
REJECT = 0000
|
||||||
|
MATCH = 1000
|
||||||
|
ADVANCE = 0100
|
||||||
|
RETRY = 0010
|
||||||
|
RETRY_EXTEND = 0011
|
||||||
|
MATCH_EXTEND = 1001
|
||||||
|
MATCH_REJECT = 2000
|
||||||
|
|
||||||
|
|
||||||
cdef enum quantifier_t:
|
cdef enum quantifier_t:
|
||||||
ZERO
|
ZERO
|
||||||
|
@ -82,39 +92,18 @@ cdef struct MatchC:
|
||||||
|
|
||||||
|
|
||||||
cdef find_matches(TokenPatternC** patterns, int n, Doc doc):
|
cdef find_matches(TokenPatternC** patterns, int n, Doc doc):
|
||||||
cdef vector[PatternStateC] init_states
|
cdef vector[PatternStateC] states
|
||||||
cdef ActionC null_action = ActionC(-1, -1, -1, -1)
|
|
||||||
for i in range(n):
|
|
||||||
init_states.push_back(PatternStateC(patterns[i], -1, 0))
|
|
||||||
cdef vector[PatternStateC] curr_states
|
|
||||||
cdef vector[PatternStateC] nexts
|
|
||||||
cdef vector[MatchC] matches
|
cdef vector[MatchC] matches
|
||||||
cdef PreshMap cache
|
|
||||||
cdef Pool mem = Pool()
|
cdef Pool mem = Pool()
|
||||||
# TODO: Prefill this with the extra attribute values.
|
# TODO: Prefill this with the extra attribute values.
|
||||||
extra_attrs = <attr_t**>mem.alloc(len(doc), sizeof(attr_t*))
|
extra_attrs = <attr_t**>mem.alloc(len(doc), sizeof(attr_t*))
|
||||||
|
# Main loop
|
||||||
for i in range(doc.length):
|
for i in range(doc.length):
|
||||||
nexts.clear()
|
for j in range(n):
|
||||||
cache = PreshMap()
|
states.push_back(PatternStateC(patterns[j], i, 0))
|
||||||
for j in range(curr_states.size()):
|
transition_states(states, matches, &doc.c[i], extra_attrs[i])
|
||||||
transition(matches, nexts,
|
# Handle matches that end in 0-width patterns
|
||||||
curr_states[j], i, &doc.c[i], extra_attrs[i], cache)
|
finish_states(matches, states)
|
||||||
for j in range(init_states.size()):
|
|
||||||
transition(matches, nexts,
|
|
||||||
init_states[j], i, &doc.c[i], extra_attrs[i], cache)
|
|
||||||
nexts, curr_states = curr_states, nexts
|
|
||||||
# Handle patterns that end with zero-width
|
|
||||||
for j in range(curr_states.size()):
|
|
||||||
state = curr_states[j]
|
|
||||||
while get_quantifier(state) in (ZERO_PLUS, ZERO_ONE):
|
|
||||||
is_final = get_is_final(state)
|
|
||||||
if is_final:
|
|
||||||
ent_id = state.pattern[1].attrs.value
|
|
||||||
matches.push_back(
|
|
||||||
MatchC(pattern_id=ent_id, start=state.start, length=state.length))
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
state.pattern += 1
|
|
||||||
# Filter out matches that have a longer equivalent.
|
# Filter out matches that have a longer equivalent.
|
||||||
longest_matches = {}
|
longest_matches = {}
|
||||||
for i in range(matches.size()):
|
for i in range(matches.size()):
|
||||||
|
@ -126,37 +115,67 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc):
|
||||||
for (pattern_id, start), length in longest_matches.items()]
|
for (pattern_id, start), length in longest_matches.items()]
|
||||||
|
|
||||||
|
|
||||||
cdef void transition(vector[MatchC]& matches, vector[PatternStateC]& nexts,
|
cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& matches,
|
||||||
PatternStateC state, int i, const TokenC* token, const attr_t* extra_attrs,
|
const TokenC* token, const attr_t* extra_attrs) except *:
|
||||||
PreshMap cache) except *:
|
cdef int q = 0
|
||||||
action = get_action(state, token, extra_attrs, cache)
|
cdef vector[PatternStateC] new_states
|
||||||
if state.start == -1:
|
for i in range(states.size()):
|
||||||
state.start = i
|
action = get_action(states[i], token, extra_attrs)
|
||||||
if action.emit_match == 1:
|
if action == REJECT:
|
||||||
|
continue
|
||||||
|
state = states[i]
|
||||||
|
states[q] = state
|
||||||
|
while action in (RETRY, RETRY_EXTEND):
|
||||||
|
if action == RETRY_EXTEND:
|
||||||
|
new_states.push_back(
|
||||||
|
PatternStateC(pattern=state.pattern, start=state.start,
|
||||||
|
length=state.length+1))
|
||||||
|
states[q].pattern += 1
|
||||||
|
action = get_action(states[q], token, extra_attrs)
|
||||||
|
if action == REJECT:
|
||||||
|
pass
|
||||||
|
elif action == ADVANCE:
|
||||||
|
states[q].pattern += 1
|
||||||
|
states[q].length += 1
|
||||||
|
q += 1
|
||||||
|
else:
|
||||||
ent_id = state.pattern[1].attrs.value
|
ent_id = state.pattern[1].attrs.value
|
||||||
|
if action == MATCH:
|
||||||
matches.push_back(
|
matches.push_back(
|
||||||
MatchC(pattern_id=ent_id, start=state.start, length=state.length+1))
|
MatchC(pattern_id=ent_id, start=state.start,
|
||||||
elif action.emit_match == 2:
|
length=state.length+1))
|
||||||
|
elif action == MATCH_REJECT:
|
||||||
|
matches.push_back(
|
||||||
|
MatchC(pattern_id=ent_id, start=state.start,
|
||||||
|
length=state.length))
|
||||||
|
elif action == MATCH_EXTEND:
|
||||||
|
matches.push_back(
|
||||||
|
MatchC(pattern_id=ent_id, start=state.start,
|
||||||
|
length=state.length))
|
||||||
|
states[q].length += 1
|
||||||
|
q += 1
|
||||||
|
states.resize(q)
|
||||||
|
for i in range(new_states.size()):
|
||||||
|
states.push_back(new_states[i])
|
||||||
|
|
||||||
|
|
||||||
|
cdef void finish_states(vector[MatchC]& matches, vector[PatternStateC]& states) except *:
|
||||||
|
'''Handle states that end in zero-width patterns.'''
|
||||||
|
cdef PatternStateC state
|
||||||
|
for i in range(states.size()):
|
||||||
|
state = states[i]
|
||||||
|
while get_quantifier(state) in (ZERO_PLUS, ZERO_ONE):
|
||||||
|
is_final = get_is_final(state)
|
||||||
|
if is_final:
|
||||||
ent_id = state.pattern[1].attrs.value
|
ent_id = state.pattern[1].attrs.value
|
||||||
matches.push_back(
|
matches.push_back(
|
||||||
MatchC(pattern_id=ent_id, start=state.start, length=state.length))
|
MatchC(pattern_id=ent_id, start=state.start, length=state.length))
|
||||||
if action.next_state_next_token:
|
break
|
||||||
nexts.push_back(PatternStateC(start=state.start,
|
else:
|
||||||
pattern=&state.pattern[1], length=state.length+1))
|
state.pattern += 1
|
||||||
if action.same_state_next_token:
|
|
||||||
nexts.push_back(PatternStateC(start=state.start,
|
|
||||||
pattern=state.pattern, length=state.length+1))
|
|
||||||
cdef PatternStateC next_state
|
|
||||||
if action.next_state_same_token:
|
|
||||||
# 0+ and ? non-matches need to not consume a token, so we call transition
|
|
||||||
# with the same state
|
|
||||||
next_state = PatternStateC(start=state.start, pattern=&state.pattern[1],
|
|
||||||
length=state.length)
|
|
||||||
transition(matches, nexts, next_state, i, token, extra_attrs, cache)
|
|
||||||
|
|
||||||
|
|
||||||
cdef ActionC get_action(PatternStateC state, const TokenC* token, const attr_t* extra_attrs,
|
cdef action_t get_action(PatternStateC state, const TokenC* token, const attr_t* extra_attrs) except *:
|
||||||
PreshMap cache) except *:
|
|
||||||
'''We need to consider:
|
'''We need to consider:
|
||||||
|
|
||||||
a) Does the token match the specification? [Yes, No]
|
a) Does the token match the specification? [Yes, No]
|
||||||
|
@ -201,18 +220,21 @@ cdef ActionC get_action(PatternStateC state, const TokenC* token, const attr_t*
|
||||||
No, non-final:
|
No, non-final:
|
||||||
0010
|
0010
|
||||||
|
|
||||||
|
Possible combinations: 1000, 0100, 0000, 1001, 0011, 0010,
|
||||||
|
|
||||||
|
We'll name the bits "match", "advance", "retry", "extend"
|
||||||
|
REJECT = 0000
|
||||||
|
MATCH = 1000
|
||||||
|
ADVANCE = 0100
|
||||||
|
RETRY = 0010
|
||||||
|
MATCH_EXTEND = 1001
|
||||||
|
RETRY_EXTEND = 0011
|
||||||
|
MATCH_REJECT = 2000 # Match, but don't include last token
|
||||||
|
|
||||||
Problem: If a quantifier is matching, we're adding a lot of open partials
|
Problem: If a quantifier is matching, we're adding a lot of open partials
|
||||||
'''
|
'''
|
||||||
#cached_match = <uint64_t>cache.get(state.pattern.key)
|
|
||||||
cdef char is_match
|
cdef char is_match
|
||||||
#if cached_match == 0:
|
|
||||||
is_match = get_is_match(state, token, extra_attrs)
|
is_match = get_is_match(state, token, extra_attrs)
|
||||||
# cached_match = is_match + 1
|
|
||||||
# cache.set(state.pattern.key, <void*>cached_match)
|
|
||||||
#elif cached_match == 1:
|
|
||||||
# is_match = 0
|
|
||||||
#else:
|
|
||||||
# is_match = 1
|
|
||||||
quantifier = get_quantifier(state)
|
quantifier = get_quantifier(state)
|
||||||
is_final = get_is_final(state)
|
is_final = get_is_final(state)
|
||||||
if quantifier == ZERO:
|
if quantifier == ZERO:
|
||||||
|
@ -221,46 +243,41 @@ cdef ActionC get_action(PatternStateC state, const TokenC* token, const attr_t*
|
||||||
if quantifier == ONE:
|
if quantifier == ONE:
|
||||||
if is_match and is_final:
|
if is_match and is_final:
|
||||||
# Yes, final: 1000
|
# Yes, final: 1000
|
||||||
return ActionC(1, 0, 0, 0)
|
return MATCH
|
||||||
elif is_match and not is_final:
|
elif is_match and not is_final:
|
||||||
# Yes, non-final: 0100
|
# Yes, non-final: 0100
|
||||||
return ActionC(0, 1, 0, 0)
|
return ADVANCE
|
||||||
elif not is_match and is_final:
|
elif not is_match and is_final:
|
||||||
# No, final: 0000
|
# No, final: 0000
|
||||||
return ActionC(0, 0, 0, 0)
|
return REJECT
|
||||||
else:
|
else:
|
||||||
# No, non-final 0000
|
return REJECT
|
||||||
return ActionC(0, 0, 0, 0)
|
|
||||||
|
|
||||||
elif quantifier == ZERO_PLUS:
|
elif quantifier == ZERO_PLUS:
|
||||||
if is_match and is_final:
|
if is_match and is_final:
|
||||||
# Yes, final: 1001
|
# Yes, final: 1001
|
||||||
return ActionC(1, 0, 0, 1)
|
return MATCH_EXTEND
|
||||||
elif is_match and not is_final:
|
elif is_match and not is_final:
|
||||||
# Yes, non-final: 0011
|
# Yes, non-final: 0011
|
||||||
return ActionC(0, 0, 1, 1)
|
return RETRY_EXTEND
|
||||||
elif not is_match and is_final:
|
elif not is_match and is_final:
|
||||||
# No, final 1000 (note: Don't include last token!)
|
# No, final 2000 (note: Don't include last token!)
|
||||||
return ActionC(2, 0, 0, 0)
|
return MATCH_REJECT
|
||||||
else:
|
else:
|
||||||
# No, non-final 0010
|
# No, non-final 0010
|
||||||
return ActionC(0, 0, 1, 0)
|
return RETRY
|
||||||
elif quantifier == ZERO_ONE:
|
elif quantifier == ZERO_ONE:
|
||||||
if is_match and is_final:
|
if is_match and is_final:
|
||||||
# Yes, final: 1000
|
# Yes, final: 1000
|
||||||
return ActionC(1, 0, 0, 0)
|
return MATCH
|
||||||
elif is_match and not is_final:
|
elif is_match and not is_final:
|
||||||
# Yes, non-final: 0100
|
# Yes, non-final: 0100
|
||||||
return ActionC(0, 1, 0, 0)
|
return ADVANCE
|
||||||
elif not is_match and is_final:
|
elif not is_match and is_final:
|
||||||
# No, final 1000 (note: Don't include last token!)
|
# No, final 2000 (note: Don't include last token!)
|
||||||
return ActionC(2, 0, 0, 0)
|
return MATCH_REJECT
|
||||||
else:
|
else:
|
||||||
# No, non-final 0010
|
# No, non-final 0010
|
||||||
return ActionC(0, 0, 1, 0)
|
return RETRY
|
||||||
else:
|
|
||||||
print(quantifier, is_match, is_final)
|
|
||||||
raise ValueError
|
|
||||||
|
|
||||||
|
|
||||||
cdef char get_is_match(PatternStateC state, const TokenC* token, const attr_t* extra_attrs) nogil:
|
cdef char get_is_match(PatternStateC state, const TokenC* token, const attr_t* extra_attrs) nogil:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user