Make parser update less hacky

This commit is contained in:
Matthew Honnibal 2017-05-25 06:49:00 -05:00
parent 8500d9b1da
commit 679efe79c8

View File

@ -438,7 +438,7 @@ cdef class Parser:
backprops = []
d_tokvecs = state2vec.ops.allocate(tokvecs.shape)
cdef float loss = 0.
while len(todo) >= 2:
while todo:
states, golds = zip(*todo)
token_ids = self.get_token_ids(states)
@ -465,15 +465,10 @@ cdef class Parser:
backprops.append((token_ids, d_vector, bp_vector))
self.transition_batch(states, scores)
todo = [st for st in todo if not st[0].is_final()]
if len(backprops) >= 20:
self._make_updates(d_tokvecs,
backprops, sgd, cuda_stream)
backprops = []
if losses is not None:
losses[self.name] += (d_scores**2).sum()
if backprops:
self._make_updates(d_tokvecs,
backprops, sgd, cuda_stream)
self._make_updates(d_tokvecs,
backprops, sgd, cuda_stream)
return self.model[0].ops.unflatten(d_tokvecs, [len(d) for d in docs])
def _make_updates(self, d_tokvecs, backprops, sgd, cuda_stream=None):