Fix test_{prefer,require}_gpu (#11390)

* Fix `test_{prefer,require}_gpu`

These tests assumed that GPUs are only supported with CuPy, but since Thinc 8.1
we also support Metal Performance Shaders.

* test_misc: arrange thinc imports to be together
This commit is contained in:
Daniël de Kok 2022-08-30 14:21:02 +02:00 committed by GitHub
parent 5ae63b1fbd
commit 3f4b4b7b4f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -10,7 +10,8 @@ from spacy.ml._precomputable_affine import _backprop_precomputable_affine_paddin
from spacy.util import dot_to_object, SimpleFrozenList, import_file
from spacy.util import to_ternary_int
from thinc.api import Config, Optimizer, ConfigValidationError
from thinc.api import set_current_ops
from thinc.api import get_current_ops, set_current_ops, NumpyOps, CupyOps, MPSOps
from thinc.compat import has_cupy_gpu, has_torch_mps_gpu
from spacy.training.batchers import minibatch_by_words
from spacy.lang.en import English
from spacy.lang.nl import Dutch
@ -18,7 +19,6 @@ from spacy.language import DEFAULT_CONFIG_PATH
from spacy.schemas import ConfigSchemaTraining, TokenPattern, TokenPatternSchema
from pydantic import ValidationError
from thinc.api import get_current_ops, NumpyOps, CupyOps
from .util import get_random_doc, make_tempdir
@ -111,26 +111,25 @@ def test_PrecomputableAffine(nO=4, nI=5, nF=3, nP=2):
def test_prefer_gpu():
current_ops = get_current_ops()
try:
import cupy # noqa: F401
prefer_gpu()
if has_cupy_gpu:
assert prefer_gpu()
assert isinstance(get_current_ops(), CupyOps)
except ImportError:
elif has_torch_mps_gpu:
assert prefer_gpu()
assert isinstance(get_current_ops(), MPSOps)
else:
assert not prefer_gpu()
set_current_ops(current_ops)
def test_require_gpu():
current_ops = get_current_ops()
try:
import cupy # noqa: F401
if has_cupy_gpu:
require_gpu()
assert isinstance(get_current_ops(), CupyOps)
except ImportError:
with pytest.raises(ValueError):
elif has_torch_mps_gpu:
require_gpu()
assert isinstance(get_current_ops(), MPSOps)
set_current_ops(current_ops)