diff --git a/spacy/tests/serialize/test_serialize_config.py b/spacy/tests/serialize/test_serialize_config.py index 90a79994e..25673b8c4 100644 --- a/spacy/tests/serialize/test_serialize_config.py +++ b/spacy/tests/serialize/test_serialize_config.py @@ -68,18 +68,18 @@ dropout = null @registry.architectures.register("my_test_parser") def my_parser(): tok2vec = build_Tok2Vec_model( - width=321, - embed_size=5432, - pretrained_vectors=None, - window_size=3, - maxout_pieces=4, - subword_features=True, - char_embed=True, - nM=64, - nC=8, - conv_depth=2, - bilstm_depth=0, - dropout=None, + MultiHashEmbed( + width=321, + embed_size=5432, + also_embed_subwords=True, + also_use_static_vectors=False + ), + MaxoutWindowEncoder( + width=321, + window_size=3, + maxout_pieces=4, + depth=2 + ) ) parser = build_tb_parser_model( tok2vec=tok2vec, nr_feature_tokens=7, hidden_width=65, maxout_pieces=5 diff --git a/spacy/tests/test_models.py b/spacy/tests/test_models.py index fc1988fcd..4c38ea6c6 100644 --- a/spacy/tests/test_models.py +++ b/spacy/tests/test_models.py @@ -5,12 +5,32 @@ from thinc.api import fix_random_seed, Adam, set_dropout_rate from numpy.testing import assert_array_equal import numpy -from spacy.ml.models import build_Tok2Vec_model +from spacy.ml.models import build_Tok2Vec_model, MultiHashEmbed, MaxoutWindowEncoder from spacy.ml.models import build_text_classifier, build_simple_cnn_text_classifier from spacy.lang.en import English from spacy.lang.en.examples import sentences as EN_SENTENCES +def get_textcat_kwargs(): + return { + "width": 64, + "embed_size": 2000, + "pretrained_vectors": None, + "exclusive_classes": False, + "ngram_size": 1, + "window_size": 1, + "conv_depth": 2, + "dropout": None, + "nO": 7, + } + +def get_textcat_cnn_kwargs(): + return { + "tok2vec": test_tok2vec(), + "exclusive_classes": False, + "nO": 13, + } + def get_all_params(model): params = [] for node in model.walk(): @@ -35,50 +55,34 @@ def get_gradient(model, Y): raise ValueError(f"Could not get gradient for type {type(Y)}") +def get_tok2vec_kwargs(): + # This actually creates models, so seems best to put it in a function. + return { + "embed": MultiHashEmbed( + width=32, + rows=500, + also_embed_subwords=True, + also_use_static_vectors=False + ), + "encode": MaxoutWindowEncoder( + width=32, + depth=2, + maxout_pieces=2, + window_size=1, + ) + } + + def test_tok2vec(): - return build_Tok2Vec_model(**TOK2VEC_KWARGS) - - -TOK2VEC_KWARGS = { - "width": 96, - "embed_size": 2000, - "subword_features": True, - "char_embed": False, - "conv_depth": 4, - "bilstm_depth": 0, - "maxout_pieces": 4, - "window_size": 1, - "dropout": 0.1, - "nM": 0, - "nC": 0, - "pretrained_vectors": None, -} - -TEXTCAT_KWARGS = { - "width": 64, - "embed_size": 2000, - "pretrained_vectors": None, - "exclusive_classes": False, - "ngram_size": 1, - "window_size": 1, - "conv_depth": 2, - "dropout": None, - "nO": 7, -} - -TEXTCAT_CNN_KWARGS = { - "tok2vec": test_tok2vec(), - "exclusive_classes": False, - "nO": 13, -} + return build_Tok2Vec_model(**get_tok2vec_kwargs()) @pytest.mark.parametrize( "seed,model_func,kwargs", [ - (0, build_Tok2Vec_model, TOK2VEC_KWARGS), - (0, build_text_classifier, TEXTCAT_KWARGS), - (0, build_simple_cnn_text_classifier, TEXTCAT_CNN_KWARGS), + (0, build_Tok2Vec_model, get_tok2vec_kwargs()), + (0, build_text_classifier, get_textcat_kwargs()), + (0, build_simple_cnn_text_classifier, get_textcat_cnn_kwargs()), ], ) def test_models_initialize_consistently(seed, model_func, kwargs): @@ -96,9 +100,9 @@ def test_models_initialize_consistently(seed, model_func, kwargs): @pytest.mark.parametrize( "seed,model_func,kwargs,get_X", [ - (0, build_Tok2Vec_model, TOK2VEC_KWARGS, get_docs), - (0, build_text_classifier, TEXTCAT_KWARGS, get_docs), - (0, build_simple_cnn_text_classifier, TEXTCAT_CNN_KWARGS, get_docs), + (0, build_Tok2Vec_model, get_tok2vec_kwargs(), get_docs), + (0, build_text_classifier, get_textcat_kwargs(), get_docs), + (0, build_simple_cnn_text_classifier, get_textcat_cnn_kwargs(), get_docs), ], ) def test_models_predict_consistently(seed, model_func, kwargs, get_X): @@ -131,9 +135,9 @@ def test_models_predict_consistently(seed, model_func, kwargs, get_X): @pytest.mark.parametrize( "seed,dropout,model_func,kwargs,get_X", [ - (0, 0.2, build_Tok2Vec_model, TOK2VEC_KWARGS, get_docs), - (0, 0.2, build_text_classifier, TEXTCAT_KWARGS, get_docs), - (0, 0.2, build_simple_cnn_text_classifier, TEXTCAT_CNN_KWARGS, get_docs), + (0, 0.2, build_Tok2Vec_model, get_tok2vec_kwargs(), get_docs), + (0, 0.2, build_text_classifier, get_textcat_kwargs(), get_docs), + (0, 0.2, build_simple_cnn_text_classifier, get_textcat_cnn_kwargs(), get_docs), ], ) def test_models_update_consistently(seed, dropout, model_func, kwargs, get_X):