Mirror of https://github.com/explosion/spaCy.git, synced 2024-12-25 17:36:30 +03:00
Add resume_training function
This commit is contained in:
parent f5144f04be
commit b832f89ff8
@@ -342,7 +342,28 @@ class Language(object):
         for doc, gold in docs_golds:
             yield doc, gold
 
-    def begin_training(self, get_gold_tuples, **cfg):
+    def resume_training(self, **cfg):
+        if cfg.get('device', -1) >= 0:
+            device = util.use_gpu(cfg['device'])
+            if self.vocab.vectors.data.shape[1] >= 1:
+                self.vocab.vectors.data = Model.ops.asarray(
+                    self.vocab.vectors.data)
+        else:
+            device = None
+        learn_rate = util.env_opt('learn_rate', 0.001)
+        beta1 = util.env_opt('optimizer_B1', 0.9)
+        beta2 = util.env_opt('optimizer_B2', 0.999)
+        eps = util.env_opt('optimizer_eps', 1e-08)
+        L2 = util.env_opt('L2_penalty', 1e-6)
+        max_grad_norm = util.env_opt('grad_norm_clip', 1.)
+        self._optimizer = Adam(Model.ops, learn_rate, L2=L2, beta1=beta1,
+                               beta2=beta2, eps=eps)
+        self._optimizer.max_grad_norm = max_grad_norm
+        self._optimizer.device = device
+        return self._optimizer
+
+
+    def begin_training(self, get_gold_tuples=None, **cfg):
         """Allocate models, pre-process training data and acquire a trainer and
         optimizer. Used as a contextmanager.
 
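
For readers skimming the patch: resume_training builds and returns an Adam optimizer for a pipeline whose models are already allocated, so it can be called on a loaded model to continue training without going through begin_training. A minimal usage sketch under that reading — the model name and batch source are illustrative, and the update() call assumes the spaCy 2 Language.update(docs, golds, sgd=...) API rather than anything introduced by this commit:

    import spacy

    nlp = spacy.load('en_core_web_sm')     # illustrative model name
    optimizer = nlp.resume_training()      # returns the Adam optimizer configured above
    for docs, golds in load_batches():     # load_batches() is a placeholder for your own data
        nlp.update(docs, golds, sgd=optimizer, drop=0.2)
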
@@ -353,17 +374,14 @@ class Language(object):
         if self.parser:
             self.pipeline.append(NeuralLabeller(self.vocab))
         # Populate vocab
-        for _, annots_brackets in get_gold_tuples():
-            for annots, _ in annots_brackets:
-                for word in annots[1]:
-                    _ = self.vocab[word]
+        if get_gold_tuples is not None:
+            for _, annots_brackets in get_gold_tuples():
+                for annots, _ in annots_brackets:
+                    for word in annots[1]:
+                        _ = self.vocab[word]
         contexts = []
         if cfg.get('device', -1) >= 0:
-            import cupy.cuda.device
-            device = cupy.cuda.device.Device(cfg['device'])
-            device.use()
-            Model.ops = CupyOps()
-            Model.Ops = CupyOps
+            device = util.use_gpu(cfg['device'])
             if self.vocab.vectors.data.shape[1] >= 1:
                 self.vocab.vectors.data = Model.ops.asarray(
                     self.vocab.vectors.data)
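
Two practical notes on the code above, both hedged: the optimizer hyperparameters are read through util.env_opt, which is assumed here to consult an environment variable of the same name before falling back to the default, and a non-negative device value routes through util.use_gpu, which needs CuPy and a CUDA device. A sketch continuing from the snippet after the first hunk:

    import os

    # Assumption about util.env_opt: an env var named like the option overrides the default.
    os.environ['learn_rate'] = '0.0005'

    # From the diff: cfg.get('device', -1) >= 0 triggers util.use_gpu(cfg['device']).
    # Assumes CuPy and a CUDA device are available; omit device to stay on CPU.
    optimizer = nlp.resume_training(device=0)
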