Enable GPU in pytorch n use_gpu functon

This commit is contained in:
Matthw Honnibal 2019-10-06 19:24:21 +02:00
parent 9dbaea1ab4
commit 63ff233ba2

View File

@ -707,6 +707,11 @@ def use_gpu(gpu_id):
device.use()
Model.ops = CupyOps()
Model.Ops = CupyOps
try:
import torch
torch.set_default_tensor_type("torch.cuda.FloatTensor")
except ImportError:
pass
return device