mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Fix gradient in fine_tune
This commit is contained in:
parent
0ae045256d
commit
335fa8b05c
|
@ -377,16 +377,16 @@ def fine_tune(embedding, combine=None):
|
|||
lengths)
|
||||
|
||||
def fine_tune_bwd(d_output, sgd=None):
|
||||
bp_vecs(d_output, sgd=sgd)
|
||||
bp_vecs([d_o * model.d_mix[0] for d_o in d_output], sgd=sgd)
|
||||
flat_grad = model.ops.flatten(d_output)
|
||||
model.d_mix[1] += flat_tokvecs.dot(flat_grad.T).sum()
|
||||
model.d_mix[0] += flat_vecs.dot(flat_grad.T).sum()
|
||||
sgd(model._mem.weights, model._mem.gradient, key=model.id)
|
||||
return d_output
|
||||
return [d_o * model.d_mix[1] for d_o in d_output]
|
||||
return output, fine_tune_bwd
|
||||
model = wrap(fine_tune_fwd, embedding)
|
||||
model.mix = model._mem.add((model.id, 'mix'), (2,))
|
||||
model.mix.fill(1.)
|
||||
model.mix.fill(0.5)
|
||||
model.d_mix = model._mem.add_gradient((model.id, 'd_mix'), (model.id, 'mix'))
|
||||
return model
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user