mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-15 06:09:01 +03:00
Support bilstm_depth arg in spacy pretrain
This commit is contained in:
parent
615ebe584f
commit
157d3d769b
|
@ -35,6 +35,7 @@ from .train import _load_pretrained_tok2vec
|
||||||
output_dir=("Directory to write models to on each epoch", "positional", None, str),
|
output_dir=("Directory to write models to on each epoch", "positional", None, str),
|
||||||
width=("Width of CNN layers", "option", "cw", int),
|
width=("Width of CNN layers", "option", "cw", int),
|
||||||
depth=("Depth of CNN layers", "option", "cd", int),
|
depth=("Depth of CNN layers", "option", "cd", int),
|
||||||
|
bilstm_depth=("Depth of BiLSTM layers (requires PyTorch)", "option", "lstm", int),
|
||||||
embed_rows=("Number of embedding rows", "option", "er", int),
|
embed_rows=("Number of embedding rows", "option", "er", int),
|
||||||
loss_func=(
|
loss_func=(
|
||||||
"Loss function to use for the objective. Either 'L2' or 'cosine'",
|
"Loss function to use for the objective. Either 'L2' or 'cosine'",
|
||||||
|
@ -80,6 +81,7 @@ def pretrain(
|
||||||
output_dir,
|
output_dir,
|
||||||
width=96,
|
width=96,
|
||||||
depth=4,
|
depth=4,
|
||||||
|
bilstm_depth=2,
|
||||||
embed_rows=2000,
|
embed_rows=2000,
|
||||||
loss_func="cosine",
|
loss_func="cosine",
|
||||||
use_vectors=False,
|
use_vectors=False,
|
||||||
|
@ -116,6 +118,10 @@ def pretrain(
|
||||||
util.fix_random_seed(seed)
|
util.fix_random_seed(seed)
|
||||||
|
|
||||||
has_gpu = prefer_gpu()
|
has_gpu = prefer_gpu()
|
||||||
|
if has_gpu:
|
||||||
|
import torch
|
||||||
|
|
||||||
|
torch.set_default_tensor_type("torch.cuda.FloatTensor")
|
||||||
msg.info("Using GPU" if has_gpu else "Not using GPU")
|
msg.info("Using GPU" if has_gpu else "Not using GPU")
|
||||||
|
|
||||||
output_dir = Path(output_dir)
|
output_dir = Path(output_dir)
|
||||||
|
@ -151,7 +157,7 @@ def pretrain(
|
||||||
embed_rows,
|
embed_rows,
|
||||||
conv_depth=depth,
|
conv_depth=depth,
|
||||||
pretrained_vectors=pretrained_vectors,
|
pretrained_vectors=pretrained_vectors,
|
||||||
bilstm_depth=0, # Requires PyTorch. Experimental.
|
bilstm_depth=bilstm_depth, # Requires PyTorch. Experimental.
|
||||||
cnn_maxout_pieces=3, # You can try setting this higher
|
cnn_maxout_pieces=3, # You can try setting this higher
|
||||||
subword_features=True, # Set to False for Chinese etc
|
subword_features=True, # Set to False for Chinese etc
|
||||||
),
|
),
|
||||||
|
|
Loading…
Reference in New Issue
Block a user