mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-27 01:34:30 +03:00
Fix parser learning rates
This commit is contained in:
parent
bae59bf92f
commit
ab28f911b4
|
@ -524,6 +524,7 @@ cdef class Parser:
|
||||||
scores, bp_scores = vec2scores.begin_update(vector, drop=drop)
|
scores, bp_scores = vec2scores.begin_update(vector, drop=drop)
|
||||||
|
|
||||||
d_scores = self.get_batch_loss(states, golds, scores)
|
d_scores = self.get_batch_loss(states, golds, scores)
|
||||||
|
d_scores /= len(docs)
|
||||||
d_vector = bp_scores(d_scores, sgd=sgd)
|
d_vector = bp_scores(d_scores, sgd=sgd)
|
||||||
if drop != 0:
|
if drop != 0:
|
||||||
d_vector *= mask
|
d_vector *= mask
|
||||||
|
@ -582,7 +583,9 @@ cdef class Parser:
|
||||||
width, density,
|
width, density,
|
||||||
sgd=sgd, drop=drop, losses=losses)
|
sgd=sgd, drop=drop, losses=losses)
|
||||||
backprop_lower = []
|
backprop_lower = []
|
||||||
|
cdef float batch_size = len(docs)
|
||||||
for i, d_scores in enumerate(states_d_scores):
|
for i, d_scores in enumerate(states_d_scores):
|
||||||
|
d_scores /= batch_size
|
||||||
if losses is not None:
|
if losses is not None:
|
||||||
losses[self.name] += (d_scores**2).sum()
|
losses[self.name] += (d_scores**2).sum()
|
||||||
ids, bp_vectors, bp_scores = backprops[i]
|
ids, bp_vectors, bp_scores = backprops[i]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user