mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-14 21:57:15 +03:00
Enable GPU in pytorch n use_gpu functon
This commit is contained in:
parent
9dbaea1ab4
commit
63ff233ba2
|
@ -707,6 +707,11 @@ def use_gpu(gpu_id):
|
|||
device.use()
|
||||
Model.ops = CupyOps()
|
||||
Model.Ops = CupyOps
|
||||
try:
|
||||
import torch
|
||||
torch.set_default_tensor_type("torch.cuda.FloatTensor")
|
||||
except ImportError:
|
||||
pass
|
||||
return device
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user