mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-19 20:52:23 +03:00
Skip coref test if no torch
This commit is contained in:
parent
3807a1ba74
commit
303269c4b2
|
@ -2,11 +2,13 @@ from typing import List
|
||||||
import pytest
|
import pytest
|
||||||
from thinc.api import fix_random_seed, Adam, set_dropout_rate
|
from thinc.api import fix_random_seed, Adam, set_dropout_rate
|
||||||
from thinc.api import Ragged, reduce_mean, Logistic, chain, Relu
|
from thinc.api import Ragged, reduce_mean, Logistic, chain, Relu
|
||||||
|
from thinc.util import has_torch
|
||||||
from numpy.testing import assert_array_equal, assert_array_almost_equal
|
from numpy.testing import assert_array_equal, assert_array_almost_equal
|
||||||
import numpy
|
import numpy
|
||||||
from spacy.ml.models import build_Tok2Vec_model, MultiHashEmbed, MaxoutWindowEncoder
|
from spacy.ml.models import build_Tok2Vec_model, MultiHashEmbed, MaxoutWindowEncoder
|
||||||
from spacy.ml.models import build_bow_text_classifier, build_simple_cnn_text_classifier
|
from spacy.ml.models import build_bow_text_classifier, build_simple_cnn_text_classifier
|
||||||
from spacy.ml.models import build_spancat_model, build_wl_coref_model
|
if has_torch:
|
||||||
|
from spacy.ml.models import build_spancat_model, build_wl_coref_model
|
||||||
from spacy.ml.staticvectors import StaticVectors
|
from spacy.ml.staticvectors import StaticVectors
|
||||||
from spacy.ml.extract_spans import extract_spans, _get_span_indices
|
from spacy.ml.extract_spans import extract_spans, _get_span_indices
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
|
@ -271,6 +273,7 @@ def test_spancat_model_forward_backward(nO=5):
|
||||||
backprop(Y)
|
backprop(Y)
|
||||||
|
|
||||||
#TODO expand this
|
#TODO expand this
|
||||||
|
@pytest.mark.skipif(not has_torch, reason="Torch not available")
|
||||||
def test_coref_model_init():
|
def test_coref_model_init():
|
||||||
tok2vec = build_Tok2Vec_model(**get_tok2vec_kwargs())
|
tok2vec = build_Tok2Vec_model(**get_tok2vec_kwargs())
|
||||||
model = build_wl_coref_model(tok2vec)
|
model = build_wl_coref_model(tok2vec)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user