mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-11 17:56:30 +03:00
Update tests
This commit is contained in:
parent
c35d6282fc
commit
20e9098e3f
|
@ -68,18 +68,18 @@ dropout = null
|
|||
@registry.architectures.register("my_test_parser")
|
||||
def my_parser():
|
||||
tok2vec = build_Tok2Vec_model(
|
||||
width=321,
|
||||
embed_size=5432,
|
||||
pretrained_vectors=None,
|
||||
window_size=3,
|
||||
maxout_pieces=4,
|
||||
subword_features=True,
|
||||
char_embed=True,
|
||||
nM=64,
|
||||
nC=8,
|
||||
conv_depth=2,
|
||||
bilstm_depth=0,
|
||||
dropout=None,
|
||||
MultiHashEmbed(
|
||||
width=321,
|
||||
embed_size=5432,
|
||||
also_embed_subwords=True,
|
||||
also_use_static_vectors=False
|
||||
),
|
||||
MaxoutWindowEncoder(
|
||||
width=321,
|
||||
window_size=3,
|
||||
maxout_pieces=4,
|
||||
depth=2
|
||||
)
|
||||
)
|
||||
parser = build_tb_parser_model(
|
||||
tok2vec=tok2vec, nr_feature_tokens=7, hidden_width=65, maxout_pieces=5
|
||||
|
|
|
@ -5,12 +5,32 @@ from thinc.api import fix_random_seed, Adam, set_dropout_rate
|
|||
from numpy.testing import assert_array_equal
|
||||
import numpy
|
||||
|
||||
from spacy.ml.models import build_Tok2Vec_model
|
||||
from spacy.ml.models import build_Tok2Vec_model, MultiHashEmbed, MaxoutWindowEncoder
|
||||
from spacy.ml.models import build_text_classifier, build_simple_cnn_text_classifier
|
||||
from spacy.lang.en import English
|
||||
from spacy.lang.en.examples import sentences as EN_SENTENCES
|
||||
|
||||
|
||||
def get_textcat_kwargs():
|
||||
return {
|
||||
"width": 64,
|
||||
"embed_size": 2000,
|
||||
"pretrained_vectors": None,
|
||||
"exclusive_classes": False,
|
||||
"ngram_size": 1,
|
||||
"window_size": 1,
|
||||
"conv_depth": 2,
|
||||
"dropout": None,
|
||||
"nO": 7,
|
||||
}
|
||||
|
||||
def get_textcat_cnn_kwargs():
|
||||
return {
|
||||
"tok2vec": test_tok2vec(),
|
||||
"exclusive_classes": False,
|
||||
"nO": 13,
|
||||
}
|
||||
|
||||
def get_all_params(model):
|
||||
params = []
|
||||
for node in model.walk():
|
||||
|
@ -35,50 +55,34 @@ def get_gradient(model, Y):
|
|||
raise ValueError(f"Could not get gradient for type {type(Y)}")
|
||||
|
||||
|
||||
def get_tok2vec_kwargs():
|
||||
# This actually creates models, so seems best to put it in a function.
|
||||
return {
|
||||
"embed": MultiHashEmbed(
|
||||
width=32,
|
||||
rows=500,
|
||||
also_embed_subwords=True,
|
||||
also_use_static_vectors=False
|
||||
),
|
||||
"encode": MaxoutWindowEncoder(
|
||||
width=32,
|
||||
depth=2,
|
||||
maxout_pieces=2,
|
||||
window_size=1,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def test_tok2vec():
|
||||
return build_Tok2Vec_model(**TOK2VEC_KWARGS)
|
||||
|
||||
|
||||
TOK2VEC_KWARGS = {
|
||||
"width": 96,
|
||||
"embed_size": 2000,
|
||||
"subword_features": True,
|
||||
"char_embed": False,
|
||||
"conv_depth": 4,
|
||||
"bilstm_depth": 0,
|
||||
"maxout_pieces": 4,
|
||||
"window_size": 1,
|
||||
"dropout": 0.1,
|
||||
"nM": 0,
|
||||
"nC": 0,
|
||||
"pretrained_vectors": None,
|
||||
}
|
||||
|
||||
TEXTCAT_KWARGS = {
|
||||
"width": 64,
|
||||
"embed_size": 2000,
|
||||
"pretrained_vectors": None,
|
||||
"exclusive_classes": False,
|
||||
"ngram_size": 1,
|
||||
"window_size": 1,
|
||||
"conv_depth": 2,
|
||||
"dropout": None,
|
||||
"nO": 7,
|
||||
}
|
||||
|
||||
TEXTCAT_CNN_KWARGS = {
|
||||
"tok2vec": test_tok2vec(),
|
||||
"exclusive_classes": False,
|
||||
"nO": 13,
|
||||
}
|
||||
return build_Tok2Vec_model(**get_tok2vec_kwargs())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"seed,model_func,kwargs",
|
||||
[
|
||||
(0, build_Tok2Vec_model, TOK2VEC_KWARGS),
|
||||
(0, build_text_classifier, TEXTCAT_KWARGS),
|
||||
(0, build_simple_cnn_text_classifier, TEXTCAT_CNN_KWARGS),
|
||||
(0, build_Tok2Vec_model, get_tok2vec_kwargs()),
|
||||
(0, build_text_classifier, get_textcat_kwargs()),
|
||||
(0, build_simple_cnn_text_classifier, get_textcat_cnn_kwargs()),
|
||||
],
|
||||
)
|
||||
def test_models_initialize_consistently(seed, model_func, kwargs):
|
||||
|
@ -96,9 +100,9 @@ def test_models_initialize_consistently(seed, model_func, kwargs):
|
|||
@pytest.mark.parametrize(
|
||||
"seed,model_func,kwargs,get_X",
|
||||
[
|
||||
(0, build_Tok2Vec_model, TOK2VEC_KWARGS, get_docs),
|
||||
(0, build_text_classifier, TEXTCAT_KWARGS, get_docs),
|
||||
(0, build_simple_cnn_text_classifier, TEXTCAT_CNN_KWARGS, get_docs),
|
||||
(0, build_Tok2Vec_model, get_tok2vec_kwargs(), get_docs),
|
||||
(0, build_text_classifier, get_textcat_kwargs(), get_docs),
|
||||
(0, build_simple_cnn_text_classifier, get_textcat_cnn_kwargs(), get_docs),
|
||||
],
|
||||
)
|
||||
def test_models_predict_consistently(seed, model_func, kwargs, get_X):
|
||||
|
@ -131,9 +135,9 @@ def test_models_predict_consistently(seed, model_func, kwargs, get_X):
|
|||
@pytest.mark.parametrize(
|
||||
"seed,dropout,model_func,kwargs,get_X",
|
||||
[
|
||||
(0, 0.2, build_Tok2Vec_model, TOK2VEC_KWARGS, get_docs),
|
||||
(0, 0.2, build_text_classifier, TEXTCAT_KWARGS, get_docs),
|
||||
(0, 0.2, build_simple_cnn_text_classifier, TEXTCAT_CNN_KWARGS, get_docs),
|
||||
(0, 0.2, build_Tok2Vec_model, get_tok2vec_kwargs(), get_docs),
|
||||
(0, 0.2, build_text_classifier, get_textcat_kwargs(), get_docs),
|
||||
(0, 0.2, build_simple_cnn_text_classifier, get_textcat_cnn_kwargs(), get_docs),
|
||||
],
|
||||
)
|
||||
def test_models_update_consistently(seed, dropout, model_func, kwargs, get_X):
|
||||
|
|
Loading…
Reference in New Issue
Block a user