Add resume_training function

This commit is contained in:
Matthew Honnibal 2017-09-20 19:15:20 -05:00
parent f5144f04be
commit b832f89ff8

View File

@ -342,7 +342,28 @@ class Language(object):
for doc, gold in docs_golds: for doc, gold in docs_golds:
yield doc, gold yield doc, gold
def begin_training(self, get_gold_tuples, **cfg): def resume_training(self, **cfg):
if cfg.get('device', -1) >= 0:
device = util.use_gpu(cfg['device'])
if self.vocab.vectors.data.shape[1] >= 1:
self.vocab.vectors.data = Model.ops.asarray(
self.vocab.vectors.data)
else:
device = None
learn_rate = util.env_opt('learn_rate', 0.001)
beta1 = util.env_opt('optimizer_B1', 0.9)
beta2 = util.env_opt('optimizer_B2', 0.999)
eps = util.env_opt('optimizer_eps', 1e-08)
L2 = util.env_opt('L2_penalty', 1e-6)
max_grad_norm = util.env_opt('grad_norm_clip', 1.)
self._optimizer = Adam(Model.ops, learn_rate, L2=L2, beta1=beta1,
beta2=beta2, eps=eps)
self._optimizer.max_grad_norm = max_grad_norm
self._optimizer.device = device
return self._optimizer
def begin_training(self, get_gold_tuples=None, **cfg):
"""Allocate models, pre-process training data and acquire a trainer and """Allocate models, pre-process training data and acquire a trainer and
optimizer. Used as a contextmanager. optimizer. Used as a contextmanager.
@ -353,17 +374,14 @@ class Language(object):
if self.parser: if self.parser:
self.pipeline.append(NeuralLabeller(self.vocab)) self.pipeline.append(NeuralLabeller(self.vocab))
# Populate vocab # Populate vocab
for _, annots_brackets in get_gold_tuples(): if get_gold_tuples is not None:
for annots, _ in annots_brackets: for _, annots_brackets in get_gold_tuples():
for word in annots[1]: for annots, _ in annots_brackets:
_ = self.vocab[word] for word in annots[1]:
_ = self.vocab[word]
contexts = [] contexts = []
if cfg.get('device', -1) >= 0: if cfg.get('device', -1) >= 0:
import cupy.cuda.device device = util.use_gpu(cfg['device'])
device = cupy.cuda.device.Device(cfg['device'])
device.use()
Model.ops = CupyOps()
Model.Ops = CupyOps
if self.vocab.vectors.data.shape[1] >= 1: if self.vocab.vectors.data.shape[1] >= 1:
self.vocab.vectors.data = Model.ops.asarray( self.vocab.vectors.data = Model.ops.asarray(
self.vocab.vectors.data) self.vocab.vectors.data)