mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
Don't randomise pipeline for training, and don't update if no gradient
This commit is contained in:
parent
3d22fcaf0b
commit
73a643d32a
|
@ -212,17 +212,16 @@ class Language(object):
|
||||||
"""
|
"""
|
||||||
tok2vec = self.pipeline[0]
|
tok2vec = self.pipeline[0]
|
||||||
feats = tok2vec.doc2feats(docs)
|
feats = tok2vec.doc2feats(docs)
|
||||||
procs = list(self.pipeline[1:])
|
|
||||||
random.shuffle(procs)
|
|
||||||
grads = {}
|
grads = {}
|
||||||
def get_grads(W, dW, key=None):
|
def get_grads(W, dW, key=None):
|
||||||
grads[key] = (W, dW)
|
grads[key] = (W, dW)
|
||||||
for proc in procs:
|
for proc in self.pipeline[1:]:
|
||||||
if not hasattr(proc, 'update'):
|
if not hasattr(proc, 'update'):
|
||||||
continue
|
continue
|
||||||
tokvecses, bp_tokvecses = tok2vec.model.begin_update(feats, drop=drop)
|
tokvecses, bp_tokvecses = tok2vec.model.begin_update(feats, drop=drop)
|
||||||
d_tokvecses = proc.update((docs, tokvecses), golds,
|
d_tokvecses = proc.update((docs, tokvecses), golds,
|
||||||
drop=drop, sgd=get_grads, losses=losses)
|
drop=drop, sgd=get_grads, losses=losses)
|
||||||
|
if d_tokvecses is not None:
|
||||||
bp_tokvecses(d_tokvecses, sgd=sgd)
|
bp_tokvecses(d_tokvecses, sgd=sgd)
|
||||||
for key, (W, dW) in grads.items():
|
for key, (W, dW) in grads.items():
|
||||||
sgd(W, dW, key=key)
|
sgd(W, dW, key=key)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user