Reset is_valid and costs during beam training

This commit is contained in:
Matthew Honnibal 2016-07-31 19:02:45 +02:00
parent 3e46b491b9
commit 2f09b041d1

View File

@ -143,8 +143,8 @@ cdef class ParserNeuralNet(NeuralNet):
def _update_from_history(self, TransitionSystem moves, Doc doc, history, weight_t grad): def _update_from_history(self, TransitionSystem moves, Doc doc, history, weight_t grad):
cdef Pool mem = Pool() cdef Pool mem = Pool()
features = <FeatureC*>mem.alloc(self.nr_feat, sizeof(FeatureC)) features = <FeatureC*>mem.alloc(self.nr_feat, sizeof(FeatureC))
is_valid = <int*>mem.alloc(self.moves.n_moves, sizeof(int)) is_valid = <int*>mem.alloc(moves.n_moves, sizeof(int))
costs = <weight_t*>mem.alloc(self.moves.n_moves, sizeof(weight_t)) costs = <weight_t*>mem.alloc(moves.n_moves, sizeof(weight_t))
stcls = StateClass.init(doc.c, doc.length) stcls = StateClass.init(doc.c, doc.length)
moves.initialize_state(stcls.c) moves.initialize_state(stcls.c)
@ -153,6 +153,9 @@ cdef class ParserNeuralNet(NeuralNet):
key[1] = 0 key[1] = 0
cdef uint64_t clas cdef uint64_t clas
for clas in history: for clas in history:
memset(costs, 0, moves.n_moves * sizeof(costs[0]))
for i in range(moves.n_moves):
is_valid[i] = 1
nr_feat = self._set_featuresC(features, stcls.c) nr_feat = self._set_featuresC(features, stcls.c)
moves.set_valid(is_valid, stcls.c) moves.set_valid(is_valid, stcls.c)
# Update with a sparse gradient: everything's 0, except our class. # Update with a sparse gradient: everything's 0, except our class.
@ -163,11 +166,10 @@ cdef class ParserNeuralNet(NeuralNet):
# We therefore have a key that indicates the current sequence, so that # We therefore have a key that indicates the current sequence, so that
# the model can merge updates that refer to the same state together, # the model can merge updates that refer to the same state together,
# by summing their gradients. # by summing their gradients.
memset(costs, 0, self.moves.n_moves)
costs[clas] = grad costs[clas] = grad
self.updateC(features, self.updateC(features,
nr_feat, costs, is_valid, False, key=key[0]) nr_feat, costs, is_valid, False, key[0])
moves.c[clas].do(stcls.c, self.moves.c[clas].label) moves.c[clas].do(stcls.c, moves.c[clas].label)
# Build a hash of the state sequence. # Build a hash of the state sequence.
# Position 0 represents the previous sequence, position 1 the new class. # Position 0 represents the previous sequence, position 1 the new class.
# So we want to do: # So we want to do: