Support bilstm_depth arg in spacy pretrain

This commit is contained in:
Matthew Honnibal 2019-10-06 19:22:26 +02:00
parent 615ebe584f
commit 157d3d769b

View File

@ -35,6 +35,7 @@ from .train import _load_pretrained_tok2vec
output_dir=("Directory to write models to on each epoch", "positional", None, str),
width=("Width of CNN layers", "option", "cw", int),
depth=("Depth of CNN layers", "option", "cd", int),
bilstm_depth=("Depth of BiLSTM layers (requires PyTorch)", "option", "lstm", int),
embed_rows=("Number of embedding rows", "option", "er", int),
loss_func=(
    "Loss function to use for the objective. Either 'L2' or 'cosine'",
@ -80,6 +81,7 @@ def pretrain(
output_dir,
width=96,
depth=4,
bilstm_depth=2,
embed_rows=2000,
loss_func="cosine",
use_vectors=False,
@ -116,6 +118,10 @@ def pretrain(
util.fix_random_seed(seed)
has_gpu = prefer_gpu()
if has_gpu:
    import torch

    torch.set_default_tensor_type("torch.cuda.FloatTensor")
msg.info("Using GPU" if has_gpu else "Not using GPU")
output_dir = Path(output_dir)
@ -151,7 +157,7 @@ def pretrain(
embed_rows,
conv_depth=depth,
pretrained_vectors=pretrained_vectors,
bilstm_depth=bilstm_depth,  # Requires PyTorch. Experimental. (was: bilstm_depth=0)
cnn_maxout_pieces=3,  # You can try setting this higher
subword_features=True,  # Set to False for Chinese etc
),