Skip coref test if no torch

This commit is contained in:
Paul O'Leary McCann 2022-05-25 18:26:31 +09:00
parent 3807a1ba74
commit 303269c4b2

View File

@ -2,11 +2,13 @@ from typing import List
import pytest import pytest
from thinc.api import fix_random_seed, Adam, set_dropout_rate from thinc.api import fix_random_seed, Adam, set_dropout_rate
from thinc.api import Ragged, reduce_mean, Logistic, chain, Relu from thinc.api import Ragged, reduce_mean, Logistic, chain, Relu
from thinc.util import has_torch
from numpy.testing import assert_array_equal, assert_array_almost_equal from numpy.testing import assert_array_equal, assert_array_almost_equal
import numpy import numpy
from spacy.ml.models import build_Tok2Vec_model, MultiHashEmbed, MaxoutWindowEncoder from spacy.ml.models import build_Tok2Vec_model, MultiHashEmbed, MaxoutWindowEncoder
from spacy.ml.models import build_bow_text_classifier, build_simple_cnn_text_classifier from spacy.ml.models import build_bow_text_classifier, build_simple_cnn_text_classifier
from spacy.ml.models import build_spancat_model, build_wl_coref_model if has_torch:
from spacy.ml.models import build_spancat_model, build_wl_coref_model
from spacy.ml.staticvectors import StaticVectors from spacy.ml.staticvectors import StaticVectors
from spacy.ml.extract_spans import extract_spans, _get_span_indices from spacy.ml.extract_spans import extract_spans, _get_span_indices
from spacy.lang.en import English from spacy.lang.en import English
@ -271,6 +273,7 @@ def test_spancat_model_forward_backward(nO=5):
backprop(Y) backprop(Y)
#TODO expand this #TODO expand this
@pytest.mark.skipif(not has_torch, reason="Torch not available")
def test_coref_model_init(): def test_coref_model_init():
tok2vec = build_Tok2Vec_model(**get_tok2vec_kwargs()) tok2vec = build_Tok2Vec_model(**get_tok2vec_kwargs())
model = build_wl_coref_model(tok2vec) model = build_wl_coref_model(tok2vec)