diff --git a/spacy/util.py b/spacy/util.py index c7ce38c3f..032c1741a 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -707,6 +707,11 @@ def use_gpu(gpu_id): device.use() Model.ops = CupyOps() Model.Ops = CupyOps + try: + import torch + torch.set_default_tensor_type("torch.cuda.FloatTensor") + except ImportError: + pass return device