Update tests

Matthew Honnibal 2020-07-28 22:43:19 +02:00
parent c35d6282fc
commit 20e9098e3f
2 changed files with 61 additions and 57 deletions

Changed file 1 of 2:

@@ -68,18 +68,18 @@ dropout = null
 @registry.architectures.register("my_test_parser")
 def my_parser():
     tok2vec = build_Tok2Vec_model(
-        width=321,
-        embed_size=5432,
-        pretrained_vectors=None,
-        window_size=3,
-        maxout_pieces=4,
-        subword_features=True,
-        char_embed=True,
-        nM=64,
-        nC=8,
-        conv_depth=2,
-        bilstm_depth=0,
-        dropout=None,
+        MultiHashEmbed(
+            width=321,
+            embed_size=5432,
+            also_embed_subwords=True,
+            also_use_static_vectors=False
+        ),
+        MaxoutWindowEncoder(
+            width=321,
+            window_size=3,
+            maxout_pieces=4,
+            depth=2
+        )
     )
     parser = build_tb_parser_model(
         tok2vec=tok2vec, nr_feature_tokens=7, hidden_width=65, maxout_pieces=5

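Note on the hunk above: build_Tok2Vec_model no longer takes a flat set of keyword
arguments; it is composed from an embedding sub-model and an encoding sub-model.
A minimal standalone sketch of the new pattern (not part of the commit; the values
are illustrative, and the MultiHashEmbed parameters follow the ones used in the
test_models.py changes below):

    # Sketch, not code from this commit.
    from spacy.lang.en import English
    from spacy.ml.models import build_Tok2Vec_model, MultiHashEmbed, MaxoutWindowEncoder

    # Embedding and encoding are now built separately and passed into the tok2vec.
    tok2vec = build_Tok2Vec_model(
        embed=MultiHashEmbed(
            width=96, rows=2000, also_embed_subwords=True, also_use_static_vectors=False
        ),
        encode=MaxoutWindowEncoder(width=96, window_size=1, maxout_pieces=3, depth=4),
    )

    nlp = English()
    docs = [nlp("hello world"), nlp("this is another test")]
    tok2vec.initialize(X=docs)       # thinc Model API
    vectors = tok2vec.predict(docs)  # one (n_tokens, width) array per doc
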
Changed file 2 of 2:

@@ -5,12 +5,32 @@ from thinc.api import fix_random_seed, Adam, set_dropout_rate
 from numpy.testing import assert_array_equal
 import numpy
-from spacy.ml.models import build_Tok2Vec_model
+from spacy.ml.models import build_Tok2Vec_model, MultiHashEmbed, MaxoutWindowEncoder
 from spacy.ml.models import build_text_classifier, build_simple_cnn_text_classifier
 from spacy.lang.en import English
 from spacy.lang.en.examples import sentences as EN_SENTENCES


+def get_textcat_kwargs():
+    return {
+        "width": 64,
+        "embed_size": 2000,
+        "pretrained_vectors": None,
+        "exclusive_classes": False,
+        "ngram_size": 1,
+        "window_size": 1,
+        "conv_depth": 2,
+        "dropout": None,
+        "nO": 7,
+    }
+
+
+def get_textcat_cnn_kwargs():
+    return {
+        "tok2vec": test_tok2vec(),
+        "exclusive_classes": False,
+        "nO": 13,
+    }
+
+
 def get_all_params(model):
     params = []
     for node in model.walk():
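
The get_textcat_kwargs and get_textcat_cnn_kwargs helpers added above replace the
module-level TEXTCAT_KWARGS and TEXTCAT_CNN_KWARGS dicts removed in the next hunk.
The practical difference is sharing: with a module-level dict, test_tok2vec() runs
once at import time and the same built tok2vec object is handed to every test that
parametrizes over it, while a factory builds a fresh set of kwargs per call, so each
parametrized test gets its own model. An illustrative contrast (not code from the
commit):

    # Old style: one dict, and one already-built tok2vec, shared by every test case.
    TEXTCAT_CNN_KWARGS = {"tok2vec": test_tok2vec(), "exclusive_classes": False, "nO": 13}

    # New style: each call constructs a fresh tok2vec for its caller.
    def get_textcat_cnn_kwargs():
        return {"tok2vec": test_tok2vec(), "exclusive_classes": False, "nO": 13}

Construction still happens when the parametrize lists are evaluated at collection
time, but the three tests no longer share a single model instance.
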
@@ -35,50 +55,34 @@ def get_gradient(model, Y):
     raise ValueError(f"Could not get gradient for type {type(Y)}")


+def get_tok2vec_kwargs():
+    # This actually creates models, so seems best to put it in a function.
+    return {
+        "embed": MultiHashEmbed(
+            width=32,
+            rows=500,
+            also_embed_subwords=True,
+            also_use_static_vectors=False
+        ),
+        "encode": MaxoutWindowEncoder(
+            width=32,
+            depth=2,
+            maxout_pieces=2,
+            window_size=1,
+        )
+    }
+
+
 def test_tok2vec():
-    return build_Tok2Vec_model(**TOK2VEC_KWARGS)
+    return build_Tok2Vec_model(**get_tok2vec_kwargs())
-
-
-TOK2VEC_KWARGS = {
-    "width": 96,
-    "embed_size": 2000,
-    "subword_features": True,
-    "char_embed": False,
-    "conv_depth": 4,
-    "bilstm_depth": 0,
-    "maxout_pieces": 4,
-    "window_size": 1,
-    "dropout": 0.1,
-    "nM": 0,
-    "nC": 0,
-    "pretrained_vectors": None,
-}
-
-TEXTCAT_KWARGS = {
-    "width": 64,
-    "embed_size": 2000,
-    "pretrained_vectors": None,
-    "exclusive_classes": False,
-    "ngram_size": 1,
-    "window_size": 1,
-    "conv_depth": 2,
-    "dropout": None,
-    "nO": 7,
-}
-
-TEXTCAT_CNN_KWARGS = {
-    "tok2vec": test_tok2vec(),
-    "exclusive_classes": False,
-    "nO": 13,
-}


 @pytest.mark.parametrize(
     "seed,model_func,kwargs",
     [
-        (0, build_Tok2Vec_model, TOK2VEC_KWARGS),
-        (0, build_text_classifier, TEXTCAT_KWARGS),
-        (0, build_simple_cnn_text_classifier, TEXTCAT_CNN_KWARGS),
+        (0, build_Tok2Vec_model, get_tok2vec_kwargs()),
+        (0, build_text_classifier, get_textcat_kwargs()),
+        (0, build_simple_cnn_text_classifier, get_textcat_cnn_kwargs()),
     ],
 )
 def test_models_initialize_consistently(seed, model_func, kwargs):
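
The bodies of the parametrized tests are unchanged by this commit and fall outside
these hunks. For orientation, a sketch of the pattern the helpers support (assumed
shape, not the committed code): fix the seed, build and initialize the model twice,
and compare the collected parameters.

    # Sketch, assumed shape; check_* is a hypothetical name, not the real test.
    def check_initialize_consistently(seed, model_func, kwargs):
        fix_random_seed(seed)
        model1 = model_func(**kwargs)
        model1.initialize()
        fix_random_seed(seed)
        model2 = model_func(**kwargs)
        model2.initialize()
        assert_array_equal(get_all_params(model1), get_all_params(model2))
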
@@ -96,9 +100,9 @@ def test_models_initialize_consistently(seed, model_func, kwargs):
 @pytest.mark.parametrize(
     "seed,model_func,kwargs,get_X",
     [
-        (0, build_Tok2Vec_model, TOK2VEC_KWARGS, get_docs),
-        (0, build_text_classifier, TEXTCAT_KWARGS, get_docs),
-        (0, build_simple_cnn_text_classifier, TEXTCAT_CNN_KWARGS, get_docs),
+        (0, build_Tok2Vec_model, get_tok2vec_kwargs(), get_docs),
+        (0, build_text_classifier, get_textcat_kwargs(), get_docs),
+        (0, build_simple_cnn_text_classifier, get_textcat_cnn_kwargs(), get_docs),
     ],
 )
 def test_models_predict_consistently(seed, model_func, kwargs, get_X):
@@ -131,9 +135,9 @@ def test_models_predict_consistently(seed, model_func, kwargs, get_X):
 @pytest.mark.parametrize(
     "seed,dropout,model_func,kwargs,get_X",
     [
-        (0, 0.2, build_Tok2Vec_model, TOK2VEC_KWARGS, get_docs),
-        (0, 0.2, build_text_classifier, TEXTCAT_KWARGS, get_docs),
-        (0, 0.2, build_simple_cnn_text_classifier, TEXTCAT_CNN_KWARGS, get_docs),
+        (0, 0.2, build_Tok2Vec_model, get_tok2vec_kwargs(), get_docs),
+        (0, 0.2, build_text_classifier, get_textcat_kwargs(), get_docs),
+        (0, 0.2, build_simple_cnn_text_classifier, get_textcat_cnn_kwargs(), get_docs),
     ],
 )
 def test_models_update_consistently(seed, dropout, model_func, kwargs, get_X):
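
The update test likewise keeps its original body. A sketch of how such a check can
be written with the imports at the top of the file (fix_random_seed, Adam,
set_dropout_rate) and the get_gradient / get_all_params helpers (assumed shape, not
the committed code):

    # Sketch, assumed shape; check_* is a hypothetical name, not the real test.
    def check_update_consistently(seed, dropout, model_func, kwargs, get_X):
        def updated_params():
            fix_random_seed(seed)
            optimizer = Adam(0.001)
            model = model_func(**kwargs)
            model.initialize()
            set_dropout_rate(model, dropout)
            for _ in range(3):
                Y, backprop = model.begin_update(get_X())
                backprop(get_gradient(model, Y))
                model.finish_update(optimizer)
            return get_all_params(model)

        assert_array_equal(updated_params(), updated_params())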