diff --git a/spacy/tests/test_models.py b/spacy/tests/test_models.py index ce074fe42..794f9ca87 100644 --- a/spacy/tests/test_models.py +++ b/spacy/tests/test_models.py @@ -2,11 +2,13 @@ from typing import List import pytest from thinc.api import fix_random_seed, Adam, set_dropout_rate from thinc.api import Ragged, reduce_mean, Logistic, chain, Relu +from thinc.util import has_torch from numpy.testing import assert_array_equal, assert_array_almost_equal import numpy from spacy.ml.models import build_Tok2Vec_model, MultiHashEmbed, MaxoutWindowEncoder from spacy.ml.models import build_bow_text_classifier, build_simple_cnn_text_classifier -from spacy.ml.models import build_spancat_model, build_wl_coref_model +if has_torch: + from spacy.ml.models import build_spancat_model, build_wl_coref_model from spacy.ml.staticvectors import StaticVectors from spacy.ml.extract_spans import extract_spans, _get_span_indices from spacy.lang.en import English @@ -271,6 +273,7 @@ def test_spancat_model_forward_backward(nO=5): backprop(Y) #TODO expand this +@pytest.mark.skipif(not has_torch, reason="Torch not available") def test_coref_model_init(): tok2vec = build_Tok2Vec_model(**get_tok2vec_kwargs()) model = build_wl_coref_model(tok2vec)