mirror of
https://github.com/explosion/spaCy.git
synced 2025-10-24 12:41:23 +03:00
Readd beam search after refactor
This commit is contained in:
parent
36b2c9bdd5
commit
5a0f26be0c
1
setup.py
1
setup.py
|
@ -31,6 +31,7 @@ MOD_NAMES = [
|
||||||
'spacy.tokenizer',
|
'spacy.tokenizer',
|
||||||
'spacy.syntax.nn_parser',
|
'spacy.syntax.nn_parser',
|
||||||
'spacy.syntax._parser_model',
|
'spacy.syntax._parser_model',
|
||||||
|
'spacy.syntax._beam_utils',
|
||||||
'spacy.syntax.nonproj',
|
'spacy.syntax.nonproj',
|
||||||
'spacy.syntax.transition_system',
|
'spacy.syntax.transition_system',
|
||||||
'spacy.syntax.arc_eager',
|
'spacy.syntax.arc_eager',
|
||||||
|
|
6
spacy/syntax/_beam_utils.pxd
Normal file
6
spacy/syntax/_beam_utils.pxd
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
from thinc.typedefs cimport class_t
|
||||||
|
|
||||||
|
# These are passed as callbacks to thinc.search.Beam
|
||||||
|
cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1
|
||||||
|
|
||||||
|
cdef int check_final_state(void* _state, void* extra_args) except -1
|
|
@ -15,7 +15,7 @@ from .stateclass cimport StateC, StateClass
|
||||||
|
|
||||||
|
|
||||||
# These are passed as callbacks to thinc.search.Beam
|
# These are passed as callbacks to thinc.search.Beam
|
||||||
cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
|
cdef int transition_state(void* _dest, void* _src, class_t clas, void* _moves) except -1:
|
||||||
dest = <StateC*>_dest
|
dest = <StateC*>_dest
|
||||||
src = <StateC*>_src
|
src = <StateC*>_src
|
||||||
moves = <const Transition*>_moves
|
moves = <const Transition*>_moves
|
||||||
|
@ -24,12 +24,12 @@ cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves)
|
||||||
dest.push_hist(clas)
|
dest.push_hist(clas)
|
||||||
|
|
||||||
|
|
||||||
cdef int _check_final_state(void* _state, void* extra_args) except -1:
|
cdef int check_final_state(void* _state, void* extra_args) except -1:
|
||||||
state = <StateC*>_state
|
state = <StateC*>_state
|
||||||
return state.is_final()
|
return state.is_final()
|
||||||
|
|
||||||
|
|
||||||
cdef hash_t _hash_state(void* _state, void* _) except 0:
|
cdef hash_t hash_state(void* _state, void* _) except 0:
|
||||||
state = <StateC*>_state
|
state = <StateC*>_state
|
||||||
if state.is_final():
|
if state.is_final():
|
||||||
return 1
|
return 1
|
||||||
|
@ -37,6 +37,20 @@ cdef hash_t _hash_state(void* _state, void* _) except 0:
|
||||||
return state.hash()
|
return state.hash()
|
||||||
|
|
||||||
|
|
||||||
|
def collect_states(beams):
|
||||||
|
cdef StateClass state
|
||||||
|
cdef Beam beam
|
||||||
|
states = []
|
||||||
|
for state_or_beam in beams:
|
||||||
|
if isinstance(state_or_beam, StateClass):
|
||||||
|
states.append(state_or_beam)
|
||||||
|
else:
|
||||||
|
beam = state_or_beam
|
||||||
|
state = StateClass.borrow(<StateC*>beam.at(0))
|
||||||
|
states.append(state)
|
||||||
|
return states
|
||||||
|
|
||||||
|
|
||||||
cdef class ParserBeam(object):
|
cdef class ParserBeam(object):
|
||||||
cdef public TransitionSystem moves
|
cdef public TransitionSystem moves
|
||||||
cdef public object states
|
cdef public object states
|
||||||
|
@ -82,8 +96,8 @@ cdef class ParserBeam(object):
|
||||||
self._set_scores(beam, scores[i])
|
self._set_scores(beam, scores[i])
|
||||||
if self.golds is not None:
|
if self.golds is not None:
|
||||||
self._set_costs(beam, self.golds[i], follow_gold=follow_gold)
|
self._set_costs(beam, self.golds[i], follow_gold=follow_gold)
|
||||||
beam.advance(_transition_state, NULL, <void*>self.moves.c)
|
beam.advance(transition_state, NULL, <void*>self.moves.c)
|
||||||
beam.check_done(_check_final_state, NULL)
|
beam.check_done(check_final_state, NULL)
|
||||||
# This handles the non-monotonic stuff for the parser.
|
# This handles the non-monotonic stuff for the parser.
|
||||||
if beam.is_done and self.golds is not None:
|
if beam.is_done and self.golds is not None:
|
||||||
for j in range(beam.size):
|
for j in range(beam.size):
|
||||||
|
@ -144,15 +158,12 @@ nr_update = 0
|
||||||
def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
|
def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
|
||||||
states, golds,
|
states, golds,
|
||||||
state2vec, vec2scores,
|
state2vec, vec2scores,
|
||||||
int width, float density, int hist_feats,
|
int width, losses=None, drop=0.):
|
||||||
losses=None, drop=0.):
|
|
||||||
global nr_update
|
global nr_update
|
||||||
cdef MaxViolation violn
|
cdef MaxViolation violn
|
||||||
nr_update += 1
|
nr_update += 1
|
||||||
pbeam = ParserBeam(moves, states, golds,
|
pbeam = ParserBeam(moves, states, golds, width=width)
|
||||||
width=width, density=density)
|
gbeam = ParserBeam(moves, states, golds, width=width)
|
||||||
gbeam = ParserBeam(moves, states, golds,
|
|
||||||
width=width, density=density)
|
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
beam_maps = []
|
beam_maps = []
|
||||||
backprops = []
|
backprops = []
|
||||||
|
@ -177,13 +188,7 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
|
||||||
# Now that we have our flat list of states, feed them through the model
|
# Now that we have our flat list of states, feed them through the model
|
||||||
token_ids = get_token_ids(states, nr_feature)
|
token_ids = get_token_ids(states, nr_feature)
|
||||||
vectors, bp_vectors = state2vec.begin_update(token_ids, drop=drop)
|
vectors, bp_vectors = state2vec.begin_update(token_ids, drop=drop)
|
||||||
if hist_feats:
|
scores, bp_scores = vec2scores.begin_update(vectors, drop=drop)
|
||||||
hists = numpy.asarray([st.history[:hist_feats] for st in states],
|
|
||||||
dtype='i')
|
|
||||||
scores, bp_scores = vec2scores.begin_update((vectors, hists),
|
|
||||||
drop=drop)
|
|
||||||
else:
|
|
||||||
scores, bp_scores = vec2scores.begin_update(vectors, drop=drop)
|
|
||||||
|
|
||||||
# Store the callbacks for the backward pass
|
# Store the callbacks for the backward pass
|
||||||
backprops.append((token_ids, bp_vectors, bp_scores))
|
backprops.append((token_ids, bp_vectors, bp_scores))
|
||||||
|
@ -291,3 +296,27 @@ def get_gradient(nr_class, beam_maps, histories, losses):
|
||||||
grads[j][i, clas] += loss
|
grads[j][i, clas] += loss
|
||||||
key = key + tuple([clas])
|
key = key + tuple([clas])
|
||||||
return grads
|
return grads
|
||||||
|
|
||||||
|
|
||||||
|
def cleanup_beam(Beam beam):
|
||||||
|
cdef StateC* state
|
||||||
|
# Once parsing has finished, states in beam may not be unique. Is this
|
||||||
|
# correct?
|
||||||
|
seen = set()
|
||||||
|
for i in range(beam.width):
|
||||||
|
addr = <size_t>beam._parents[i].content
|
||||||
|
if addr not in seen:
|
||||||
|
state = <StateC*>addr
|
||||||
|
del state
|
||||||
|
seen.add(addr)
|
||||||
|
else:
|
||||||
|
raise ValueError(Errors.E023.format(addr=addr, i=i))
|
||||||
|
addr = <size_t>beam._states[i].content
|
||||||
|
if addr not in seen:
|
||||||
|
state = <StateC*>addr
|
||||||
|
del state
|
||||||
|
seen.add(addr)
|
||||||
|
else:
|
||||||
|
raise ValueError(Errors.E023.format(addr=addr, i=i))
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user