From f3e2aaea1e551a899a6fd1a97393bdf801c38a19 Mon Sep 17 00:00:00 2001 From: Matthw Honnibal Date: Fri, 18 Oct 2019 17:23:55 +0200 Subject: [PATCH] Fix GPU selection in spacy train --- spacy/cli/train.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/spacy/cli/train.py b/spacy/cli/train.py index 15e19433c..8f93e3d7e 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -11,6 +11,7 @@ import srsly from wasabi import Printer import contextlib import random +from thinc.neural.util import require_gpu from .._ml import create_default_optimizer from ..attrs import PROB, IS_OOV, CLUSTER, LANG @@ -85,6 +86,8 @@ def train( JSON format. To convert data from other formats, use the `spacy convert` command. """ + if use_gpu != -1: + require_gpu(use_gpu) # temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200 import tqdm