Mirror of https://github.com/explosion/spaCy.git (synced 2024-11-15 06:09:01 +03:00)
Add option for GPU ID to pretrain

parent 1dce86c555
commit 6c8785a238

@@ -10,10 +10,11 @@ from collections import Counter
 from pathlib import Path
 from thinc.v2v import Affine, Maxout
 from thinc.misc import LayerNorm as LN
-from thinc.neural.util import prefer_gpu
+from thinc.neural.util import require_gpu
 from wasabi import Printer
 import srsly
 from thinc.neural.util import to_categorical
+from thinc.rates import cyclic_triangular_rate
 
 from ..errors import Errors
 from ..tokens import Doc
@@ -80,6 +81,13 @@ from .train import _load_pretrained_tok2vec
         "es",
         int,
     ),
+    gpu_id=(
+        "Index of GPU to use, e.g. 0. -1 for CPU.",
+        "option",
+        "gpu",
+        int,
+    ),
+
 )
 def pretrain(
     texts_loc,
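
Like the existing arguments, the new gpu_id option is declared through a plac annotation tuple of (help text, kind, abbreviation, type). The sketch below is a minimal, hypothetical stand-alone example of that mechanism, not spaCy's actual CLI wiring:

    import plac


    @plac.annotations(
        gpu_id=("Index of GPU to use, e.g. 0. -1 for CPU.", "option", "gpu", int)
    )
    def demo(gpu_id=-1):
        # Toy command, not spaCy code: just echoes the parsed option.
        print("gpu_id =", gpu_id)


    if __name__ == "__main__":
        plac.call(demo)  # plac builds the flags from the name and the "gpu" abbreviation

With -1 as the default, existing invocations keep running on CPU unless a GPU index is passed explicitly.
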
@@ -104,6 +112,7 @@ def pretrain(
     n_save_every=None,
     init_tok2vec=None,
     epoch_start=None,
+    gpu_id=-1,
 ):
     """
     Pre-train the 'token-to-vector' (tok2vec) layer of pipeline components,
@@ -126,10 +135,9 @@ def pretrain(
         config[key] = str(config[key])
     msg = Printer()
     util.fix_random_seed(seed)
-    has_gpu = prefer_gpu(gpu_id=1)
-    msg.info("Using GPU" if has_gpu else "Not using GPU")
+    if gpu_id != -1:
+        has_gpu = require_gpu(gpu_id=gpu_id)
+        msg.info("Using GPU {}".format(gpu_id) if has_gpu else "Not using GPU")
 
     output_dir = Path(output_dir)
     if not output_dir.exists():
         output_dir.mkdir()
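
Two behaviours change here: the device index is no longer hard-coded to 1, and the call switches from prefer_gpu to require_gpu. Assuming Thinc 7.x semantics, prefer_gpu silently falls back to CPU when no GPU is available, while require_gpu fails loudly if the requested device cannot be activated. A hedged sketch of the resulting selection logic:

    from thinc.neural.util import require_gpu


    def activate_device(gpu_id=-1):
        # Mirror of the logic above: an explicit -1 means "stay on CPU".
        if gpu_id == -1:
            return False
        # require_gpu raises if the requested device cannot be used,
        # instead of silently falling back to CPU like prefer_gpu.
        return require_gpu(gpu_id=gpu_id)
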
@@ -206,7 +214,8 @@ def pretrain(
 
     def _save_model(epoch, is_temp=False):
         is_temp_str = ".temp" if is_temp else ""
-        with model.use_params(optimizer.averages):
+        #with model.use_params(optimizer.averages):
+        if True:
             with (output_dir / ("model%d%s.bin" % (epoch, is_temp_str))).open(
                 "wb"
             ) as file_:
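
Model.use_params is a context manager that temporarily substitutes the optimizer's running parameter averages for the live weights while the block runs. Commenting it out, with `if True:` kept so the indented block still parses, means the raw, non-averaged weights are what get serialized. The snippet below illustrates the general swap-and-restore pattern with plain dicts; it is not Thinc's implementation:

    from contextlib import contextmanager


    @contextmanager
    def use_params(params, averages):
        # Swap in the averaged weights for the duration of the block ...
        backup = dict(params)
        params.update(averages)
        try:
            yield
        finally:
            # ... and restore the live weights afterwards.
            params.update(backup)


    weights = {"W": 1.0}
    with use_params(weights, {"W": 0.9}):
        print(weights["W"])  # 0.9, the averaged value
    print(weights["W"])      # 1.0, restored
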
@@ -221,6 +230,10 @@ def pretrain(
             file_.write(srsly.json_dumps(log) + "\n")
 
     skip_counter = 0
+    min_lr = optimizer.alpha / 3
+    max_lr = optimizer.alpha * 2
+    period = 10000
+    learn_rates = cyclic_triangular_rate(min_lr, max_lr, period)
     for epoch in range(epoch_start, n_iter + epoch_start):
         for batch_id, batch in enumerate(
             util.minibatch_by_words(((text, None) for text in texts), size=batch_size)
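
The pretraining loop now takes its learning rate from a cyclic triangular schedule that oscillates linearly between min_lr (a third of the configured alpha) and max_lr (twice alpha), repeating every 10,000 steps. The commit uses thinc.rates.cyclic_triangular_rate; the rough re-implementation below is for illustration only, and details such as whether `period` covers a half or a full cycle may differ from Thinc's version:

    import itertools


    def cyclic_triangular(min_lr, max_lr, period):
        # Yield rates that trace a triangle wave between min_lr and max_lr.
        half = period / 2
        for step in itertools.count(1):
            frac = abs((step % period) - half) / half  # 1.0 at the ends, 0.0 mid-cycle
            yield max_lr - (max_lr - min_lr) * frac


    rates = cyclic_triangular(0.001 / 3, 0.001 * 2, 10000)
    print([next(rates) for _ in range(3)])  # starts near min_lr and climbs slowly
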
@@ -232,6 +245,7 @@ def pretrain(
                 min_length=min_length,
             )
             skip_counter += count
+            optimizer.alpha = next(learn_rates)
             loss = make_update(
                 model, docs, optimizer, objective=loss_func, drop=dropout
             )
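
Finally, the schedule is consumed one value per batch: optimizer.alpha, the attribute Thinc's optimizer reads as its learning rate, is overwritten just before each weight update, so every call to make_update sees the next point on the triangle wave. A tiny stand-alone sketch of that pattern, with stand-in names rather than spaCy objects:

    class _Opt:
        # Stand-in for Thinc's optimizer; alpha is read as the learning rate.
        alpha = 0.001


    optimizer = _Opt()
    learn_rates = iter([0.00034, 0.00051, 0.00068])  # made-up scheduled values

    for batch_id in range(3):
        optimizer.alpha = next(learn_rates)  # one scheduled rate per batch
        # loss = make_update(model, docs, optimizer, ...) would run here
        print(batch_id, optimizer.alpha)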