diff --git a/spacy/tests/test_misc.py b/spacy/tests/test_misc.py index d8743d322..1c9b045ac 100644 --- a/spacy/tests/test_misc.py +++ b/spacy/tests/test_misc.py @@ -10,7 +10,8 @@ from spacy.ml._precomputable_affine import _backprop_precomputable_affine_paddin from spacy.util import dot_to_object, SimpleFrozenList, import_file from spacy.util import to_ternary_int from thinc.api import Config, Optimizer, ConfigValidationError -from thinc.api import set_current_ops +from thinc.api import get_current_ops, set_current_ops, NumpyOps, CupyOps, MPSOps +from thinc.compat import has_cupy_gpu, has_torch_mps_gpu from spacy.training.batchers import minibatch_by_words from spacy.lang.en import English from spacy.lang.nl import Dutch @@ -18,7 +19,6 @@ from spacy.language import DEFAULT_CONFIG_PATH from spacy.schemas import ConfigSchemaTraining, TokenPattern, TokenPatternSchema from pydantic import ValidationError -from thinc.api import get_current_ops, NumpyOps, CupyOps from .util import get_random_doc, make_tempdir @@ -111,26 +111,25 @@ def test_PrecomputableAffine(nO=4, nI=5, nF=3, nP=2): def test_prefer_gpu(): current_ops = get_current_ops() - try: - import cupy # noqa: F401 - - prefer_gpu() + if has_cupy_gpu: + assert prefer_gpu() assert isinstance(get_current_ops(), CupyOps) - except ImportError: + elif has_torch_mps_gpu: + assert prefer_gpu() + assert isinstance(get_current_ops(), MPSOps) + else: assert not prefer_gpu() set_current_ops(current_ops) def test_require_gpu(): current_ops = get_current_ops() - try: - import cupy # noqa: F401 - + if has_cupy_gpu: require_gpu() assert isinstance(get_current_ops(), CupyOps) - except ImportError: - with pytest.raises(ValueError): - require_gpu() + elif has_torch_mps_gpu: + require_gpu() + assert isinstance(get_current_ops(), MPSOps) set_current_ops(current_ops)