mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Merge branch 'develop' of https://github.com/explosion/spaCy into develop
This commit is contained in:
commit
7a6edeea68
21
spacy/_ml.py
21
spacy/_ml.py
|
@ -359,8 +359,6 @@ def get_token_vectors(tokens_attrs_vectors, drop=0.):
|
|||
def backward(d_output, sgd=None):
|
||||
return (tokens, d_output)
|
||||
return vectors, backward
|
||||
|
||||
|
||||
def fine_tune(embedding, combine=None):
|
||||
if combine is not None:
|
||||
raise NotImplementedError(
|
||||
|
@ -372,22 +370,25 @@ def fine_tune(embedding, combine=None):
|
|||
vecs, bp_vecs = embedding.begin_update(docs, drop=drop)
|
||||
flat_tokvecs = embedding.ops.flatten(tokvecs)
|
||||
flat_vecs = embedding.ops.flatten(vecs)
|
||||
alpha = model.mix
|
||||
minus = 1-model.mix
|
||||
output = embedding.ops.unflatten(
|
||||
(model.mix[0] * flat_vecs + model.mix[1] * flat_tokvecs),
|
||||
lengths)
|
||||
(alpha * flat_tokvecs + minus * flat_vecs), lengths)
|
||||
|
||||
def fine_tune_bwd(d_output, sgd=None):
|
||||
bp_vecs(d_output, sgd=sgd)
|
||||
flat_grad = model.ops.flatten(d_output)
|
||||
model.d_mix[1] += flat_tokvecs.dot(flat_grad.T).sum()
|
||||
model.d_mix[0] += flat_vecs.dot(flat_grad.T).sum()
|
||||
if sgd is not None:
|
||||
model.d_mix += flat_tokvecs.dot(flat_grad.T).sum()
|
||||
model.d_mix += 1-flat_vecs.dot(flat_grad.T).sum()
|
||||
|
||||
bp_vecs([d_o * minus for d_o in d_output], sgd=sgd)
|
||||
d_output = [d_o * alpha for d_o in d_output]
|
||||
sgd(model._mem.weights, model._mem.gradient, key=model.id)
|
||||
model.mix = model.ops.xp.minimum(model.mix, 1.0)
|
||||
return d_output
|
||||
return output, fine_tune_bwd
|
||||
model = wrap(fine_tune_fwd, embedding)
|
||||
model.mix = model._mem.add((model.id, 'mix'), (2,))
|
||||
model.mix.fill(1.)
|
||||
model.mix = model._mem.add((model.id, 'mix'), (1,))
|
||||
model.mix.fill(0.0)
|
||||
model.d_mix = model._mem.add_gradient((model.id, 'd_mix'), (model.id, 'mix'))
|
||||
return model
|
||||
|
||||
|
|
|
@ -94,7 +94,7 @@ def train(cmd, lang, output_dir, train_data, dev_data, n_iter=20, n_sents=0,
|
|||
docs, golds = zip(*batch)
|
||||
nlp.update(docs, golds, sgd=optimizer,
|
||||
drop=next(dropout_rates), losses=losses,
|
||||
update_tensors=True)
|
||||
update_shared=True)
|
||||
pbar.update(sum(len(doc) for doc in docs))
|
||||
|
||||
with nlp.use_params(optimizer.averages):
|
||||
|
|
Loading…
Reference in New Issue
Block a user