mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
Enable history features for beam parser
This commit is contained in:
parent
fc06b0a333
commit
ca12764772
|
@ -21,6 +21,7 @@ cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves)
|
|||
moves = <const Transition*>_moves
|
||||
dest.clone(src)
|
||||
moves[clas].do(dest.c, moves[clas].label)
|
||||
dest.c.push_hist(clas)
|
||||
|
||||
|
||||
cdef int _check_final_state(void* _state, void* extra_args) except -1:
|
||||
|
@ -149,7 +150,7 @@ nr_update = 0
|
|||
def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
|
||||
states, golds,
|
||||
state2vec, vec2scores,
|
||||
int width, float density,
|
||||
int width, float density, int hist_feats,
|
||||
losses=None, drop=0.):
|
||||
global nr_update
|
||||
cdef MaxViolation violn
|
||||
|
@ -180,6 +181,10 @@ def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
|
|||
# Now that we have our flat list of states, feed them through the model
|
||||
token_ids = get_token_ids(states, nr_feature)
|
||||
vectors, bp_vectors = state2vec.begin_update(token_ids, drop=drop)
|
||||
if hist_feats:
|
||||
hists = numpy.asarray([st.history[:hist_feats] for st in states], dtype='i')
|
||||
scores, bp_scores = vec2scores.begin_update((vectors, hists), drop=drop)
|
||||
else:
|
||||
scores, bp_scores = vec2scores.begin_update(vectors, drop=drop)
|
||||
|
||||
# Store the callbacks for the backward pass
|
||||
|
|
|
@ -505,7 +505,12 @@ cdef class Parser:
|
|||
states.append(stcls)
|
||||
token_ids = self.get_token_ids(states)
|
||||
vectors = state2vec(token_ids)
|
||||
scores = vec2scores(vectors)
|
||||
if self.cfg.get('hist_size', 0):
|
||||
hists = numpy.asarray([st.history[:self.cfg['hist_size']]
|
||||
for st in states], dtype='i')
|
||||
scores = vec2scores(vectors, drop=drop)
|
||||
else:
|
||||
scores = vec2scores(vectors, drop=drop)
|
||||
j = 0
|
||||
c_scores = <float*>scores.data
|
||||
for i in range(beam.size):
|
||||
|
@ -537,6 +542,7 @@ cdef class Parser:
|
|||
guess = arg_maxout_if_valid(scores, is_valid, nr_class, nr_piece)
|
||||
action = self.moves.c[guess]
|
||||
action.do(state, action.label)
|
||||
state.push_hist(guess)
|
||||
|
||||
free(is_valid)
|
||||
free(scores)
|
||||
|
@ -634,7 +640,7 @@ cdef class Parser:
|
|||
states_d_scores, backprops = _beam_utils.update_beam(self.moves, self.nr_feature, 500,
|
||||
states, golds,
|
||||
state2vec, vec2scores,
|
||||
width, density,
|
||||
width, density, self.cfg.get('hist_size', 0),
|
||||
drop=drop, losses=losses)
|
||||
backprop_lower = []
|
||||
cdef float batch_size = len(docs)
|
||||
|
@ -967,6 +973,7 @@ cdef int _transition_state(void* _dest, void* _src, class_t clas, void* _moves)
|
|||
moves = <const Transition*>_moves
|
||||
dest.clone(src)
|
||||
moves[clas].do(dest.c, moves[clas].label)
|
||||
dest.c.push_hist(clas)
|
||||
|
||||
|
||||
cdef int _check_final_state(void* _state, void* extra_args) except -1:
|
||||
|
|
Loading…
Reference in New Issue
Block a user