Fix beam update

This commit is contained in:
Matthew Honnibal 2017-08-12 17:15:16 -05:00
parent d4308d2363
commit 4638f4b869
2 changed files with 58 additions and 38 deletions

View File

@ -41,21 +41,24 @@ cdef hash_t _hash_state(void* _state, void* _) except 0:
cdef class ParserBeam(object): cdef class ParserBeam(object):
cdef public TransitionSystem moves cdef public TransitionSystem moves
cdef public object docs cdef public object states
cdef public object golds cdef public object golds
cdef public object beams cdef public object beams
def __init__(self, TransitionSystem moves, docs, golds, def __init__(self, TransitionSystem moves, states, golds,
int width=4, float density=0.001): int width=4, float density=0.001):
self.moves = moves self.moves = moves
self.docs = docs self.states = states
self.golds = golds self.golds = golds
self.beams = [] self.beams = []
cdef Doc doc
cdef Beam beam cdef Beam beam
for doc in docs: cdef StateClass state, st
for state in states:
beam = Beam(self.moves.n_moves, width, density) beam = Beam(self.moves.n_moves, width, density)
beam.initialize(self.moves.init_beam_state, doc.length, doc.c) beam.initialize(self.moves.init_beam_state, state.c.length, state.c._sent)
for i in range(beam.size):
st = <StateClass>beam.at(i)
st.c.offset = state.c.offset
self.beams.append(beam) self.beams.append(beam)
@property @property
@ -100,34 +103,38 @@ cdef class ParserBeam(object):
def get_token_ids(states, int n_tokens): def get_token_ids(states, int n_tokens):
cdef StateClass state cdef StateClass state
cdef np.ndarray ids = numpy.zeros((len(states), n_tokens), cdef np.ndarray ids = numpy.zeros((len(states), n_tokens),
dtype='i', order='C') dtype='int32', order='C')
c_ids = <int*>ids.data c_ids = <int*>ids.data
for i, state in enumerate(states): for i, state in enumerate(states):
if not state.is_final(): if not state.is_final():
state.c.set_context_tokens(c_ids, n_tokens) state.c.set_context_tokens(c_ids, n_tokens)
else:
ids[i] = -1
c_ids += ids.shape[1] c_ids += ids.shape[1]
return ids return ids
def update_beam(TransitionSystem moves, int nr_feature, def update_beam(TransitionSystem moves, int nr_feature, int max_steps,
docs, tokvecs, golds, states, tokvecs, golds,
state2vec, vec2scores, drop=0., sgd=None, state2vec, vec2scores, drop=0., sgd=None,
losses=None, int width=4, float density=0.001): losses=None, int width=4, float density=0.001):
pbeam = ParserBeam(moves, docs, golds, pbeam = ParserBeam(moves, states, golds,
width=width, density=density) width=width, density=density)
gbeam = ParserBeam(moves, docs, golds, gbeam = ParserBeam(moves, states, golds,
width=width, density=density) width=width, density=density)
beam_map = {} beam_maps = []
backprops = [] backprops = []
violns = [MaxViolation() for _ in range(len(docs))] violns = [MaxViolation() for _ in range(len(states))]
example_ids = list(range(len(docs))) for t in range(max_steps):
while not pbeam.is_done and not gbeam.is_done: if pbeam.is_done and gbeam.is_done:
states, p_indices, g_indices = get_states(example_ids, pbeam, gbeam, beam_map) break
beam_maps.append({})
states, p_indices, g_indices = get_states(pbeam, gbeam, beam_maps[-1])
token_ids = get_token_ids(states, nr_feature) token_ids = get_token_ids(states, nr_feature)
vectors, bp_vectors = state2vec.begin_update(token_ids, drop=drop) vectors, bp_vectors = state2vec.begin_update(token_ids, drop=drop)
scores, bp_scores = vec2scores.begin_update(vectors, drop=drop) scores, bp_scores = vec2scores.begin_update(vectors, drop=drop)
backprops.append((token_ids, bp_vectors, bp_scores)) backprops.append((token_ids, bp_vectors, bp_scores))
p_scores = [scores[indices] for indices in p_indices] p_scores = [scores[indices] for indices in p_indices]
@ -140,18 +147,18 @@ def update_beam(TransitionSystem moves, int nr_feature,
histories = [(v.p_hist + v.g_hist) for v in violns] histories = [(v.p_hist + v.g_hist) for v in violns]
losses = [(v.p_probs + v.g_probs) for v in violns] losses = [(v.p_probs + v.g_probs) for v in violns]
states_d_scores = get_gradient(moves.n_moves, beam_map, states_d_scores = get_gradient(moves.n_moves, beam_maps,
histories, losses) histories, losses)
return states_d_scores, backprops return states_d_scores, backprops
def get_states(example_ids, pbeams, gbeams, beam_map): def get_states(pbeams, gbeams, beam_map):
states = []
seen = {} seen = {}
states = []
p_indices = [] p_indices = []
g_indices = [] g_indices = []
cdef Beam pbeam, gbeam cdef Beam pbeam, gbeam
for eg_id, pbeam, gbeam in zip(example_ids, pbeams, gbeams): for eg_id, (pbeam, gbeam) in enumerate(zip(pbeams, gbeams)):
p_indices.append([]) p_indices.append([])
for j in range(pbeam.size): for j in range(pbeam.size):
key = tuple([eg_id] + pbeam.histories[j]) key = tuple([eg_id] + pbeam.histories[j])
@ -174,23 +181,30 @@ def get_states(example_ids, pbeams, gbeams, beam_map):
return states, p_indices, g_indices return states, p_indices, g_indices
def get_gradient(nr_class, beam_map, histories, losses): def get_gradient(nr_class, beam_maps, histories, losses):
""" """
The global model assigns a loss to each parse. The beam scores The global model assigns a loss to each parse. The beam scores
are additive, so the same gradient is applied to each action are additive, so the same gradient is applied to each action
in the history. This gives the gradient of a single *action* in the history. This gives the gradient of a single *action*
for a beam state -- so we have "the gradient of loss for taking for a beam state -- so we have "the gradient of loss for taking
action i given history H." action i given history H."
Histories: Each history is a list of actions
Each candidate has a history
Each beam has multiple candidates
Each batch has multiple beams
So history is list of lists of lists of ints
""" """
nr_step = max(len(hist) for hist in histories) nr_step = len(beam_maps)
nr_beam = len(histories) grads = [numpy.zeros((max(beam_map.values())+1, nr_class), dtype='f')
grads = [numpy.zeros((nr_beam, nr_class), dtype='f') for _ in range(nr_step)] for beam_map in beam_maps]
for hist, loss in zip(histories, losses): for eg_id, hists in enumerate(histories):
key = tuple() for loss, hist in zip(losses[eg_id], hists):
for j, clas in enumerate(hist): key = tuple([eg_id])
grads[j][i, clas] = loss for j, clas in enumerate(hist):
key = key + clas i = beam_maps[j][key]
i = beam_map[key] grads[j][i, clas] = loss
key = key + tuple([clas])
return grads return grads

View File

@ -529,23 +529,29 @@ cdef class Parser:
def update_beam(self, docs_tokvecs, golds, drop=0., sgd=None, losses=None): def update_beam(self, docs_tokvecs, golds, drop=0., sgd=None, losses=None):
docs, tokvecs = docs_tokvecs docs, tokvecs = docs_tokvecs
lengths = [len(d) for d in docs]
tokvecs = self.model[0].ops.flatten(tokvecs) tokvecs = self.model[0].ops.flatten(tokvecs)
states, golds, max_moves = self._init_gold_batch(docs, golds)
cuda_stream = get_cuda_stream() cuda_stream = get_cuda_stream()
state2vec, vec2scores = self.get_batch_model(len(docs), tokvecs, cuda_stream, 0.0) state2vec, vec2scores = self.get_batch_model(len(states), tokvecs, cuda_stream, 0.0)
states_d_scores, backprops = _beam_utils.update_beam(self.moves, self.nr_feature, states_d_scores, backprops = _beam_utils.update_beam(self.moves, self.nr_feature, max_moves,
docs, tokvecs, golds, states, tokvecs, golds,
state2vec, vec2scores, state2vec, vec2scores,
drop, sgd, losses) drop, sgd, losses)
backprop_lower = [] backprop_lower = []
for i, d_scores in enumerate(states_d_scores): for i, d_scores in enumerate(states_d_scores):
ids, bp_vectors, bp_scores = backprops[i] ids, bp_vectors, bp_scores = backprops[i]
d_vector = bp_scores(d_scores, sgd=sgd) d_vector = bp_scores(d_scores, sgd=sgd)
backprop_lower.append(( if isinstance(self.model[0].ops, CupyOps) \
get_async(cuda_stream, ids), and not isinstance(ids, state2vec.ops.xp.ndarray):
get_async(cuda_stream, d_vector), backprop_lower.append((
bp_vectors)) get_async(cuda_stream, ids),
get_async(cuda_stream, d_vector),
bp_vectors))
else:
backprop_lower.append((ids, d_vector, bp_vectors))
d_tokvecs = self.model[0].ops.allocate(tokvecs.shape) d_tokvecs = self.model[0].ops.allocate(tokvecs.shape)
self._make_updates(d_tokvecs, backprop_lower, sgd, cuda_stream) self._make_updates(d_tokvecs, backprop_lower, sgd, cuda_stream)
lengths = [len(doc) for doc in docs] lengths = [len(doc) for doc in docs]