mirror of
https://github.com/explosion/spaCy.git
synced 2025-05-31 19:23:05 +03:00
Make parser update less hacky
This commit is contained in:
parent
8500d9b1da
commit
679efe79c8
|
@ -438,7 +438,7 @@ cdef class Parser:
|
||||||
backprops = []
|
backprops = []
|
||||||
d_tokvecs = state2vec.ops.allocate(tokvecs.shape)
|
d_tokvecs = state2vec.ops.allocate(tokvecs.shape)
|
||||||
cdef float loss = 0.
|
cdef float loss = 0.
|
||||||
while len(todo) >= 2:
|
while todo:
|
||||||
states, golds = zip(*todo)
|
states, golds = zip(*todo)
|
||||||
|
|
||||||
token_ids = self.get_token_ids(states)
|
token_ids = self.get_token_ids(states)
|
||||||
|
@ -465,15 +465,10 @@ cdef class Parser:
|
||||||
backprops.append((token_ids, d_vector, bp_vector))
|
backprops.append((token_ids, d_vector, bp_vector))
|
||||||
self.transition_batch(states, scores)
|
self.transition_batch(states, scores)
|
||||||
todo = [st for st in todo if not st[0].is_final()]
|
todo = [st for st in todo if not st[0].is_final()]
|
||||||
if len(backprops) >= 20:
|
|
||||||
self._make_updates(d_tokvecs,
|
|
||||||
backprops, sgd, cuda_stream)
|
|
||||||
backprops = []
|
|
||||||
if losses is not None:
|
if losses is not None:
|
||||||
losses[self.name] += (d_scores**2).sum()
|
losses[self.name] += (d_scores**2).sum()
|
||||||
if backprops:
|
self._make_updates(d_tokvecs,
|
||||||
self._make_updates(d_tokvecs,
|
backprops, sgd, cuda_stream)
|
||||||
backprops, sgd, cuda_stream)
|
|
||||||
return self.model[0].ops.unflatten(d_tokvecs, [len(d) for d in docs])
|
return self.model[0].ops.unflatten(d_tokvecs, [len(d) for d in docs])
|
||||||
|
|
||||||
def _make_updates(self, d_tokvecs, backprops, sgd, cuda_stream=None):
|
def _make_updates(self, d_tokvecs, backprops, sgd, cuda_stream=None):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user