Fix GPU selection in spacy train

This commit is contained in:
Matthw Honnibal 2019-10-18 17:23:55 +02:00
parent 49c0adc706
commit f3e2aaea1e

View File

@ -11,6 +11,7 @@ import srsly
from wasabi import Printer
import contextlib
import random
from thinc.neural.util import require_gpu
from .._ml import create_default_optimizer
from ..attrs import PROB, IS_OOV, CLUSTER, LANG
@ -85,6 +86,8 @@ def train(
JSON format. To convert data from other formats, use the `spacy convert`
command.
"""
if use_gpu != -1:
require_gpu(use_gpu)
# temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
import tqdm