mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 01:04:34 +03:00
Update tests
This commit is contained in:
parent
c35d6282fc
commit
20e9098e3f
|
@ -68,18 +68,18 @@ dropout = null
|
||||||
@registry.architectures.register("my_test_parser")
|
@registry.architectures.register("my_test_parser")
|
||||||
def my_parser():
|
def my_parser():
|
||||||
tok2vec = build_Tok2Vec_model(
|
tok2vec = build_Tok2Vec_model(
|
||||||
width=321,
|
MultiHashEmbed(
|
||||||
embed_size=5432,
|
width=321,
|
||||||
pretrained_vectors=None,
|
embed_size=5432,
|
||||||
window_size=3,
|
also_embed_subwords=True,
|
||||||
maxout_pieces=4,
|
also_use_static_vectors=False
|
||||||
subword_features=True,
|
),
|
||||||
char_embed=True,
|
MaxoutWindowEncoder(
|
||||||
nM=64,
|
width=321,
|
||||||
nC=8,
|
window_size=3,
|
||||||
conv_depth=2,
|
maxout_pieces=4,
|
||||||
bilstm_depth=0,
|
depth=2
|
||||||
dropout=None,
|
)
|
||||||
)
|
)
|
||||||
parser = build_tb_parser_model(
|
parser = build_tb_parser_model(
|
||||||
tok2vec=tok2vec, nr_feature_tokens=7, hidden_width=65, maxout_pieces=5
|
tok2vec=tok2vec, nr_feature_tokens=7, hidden_width=65, maxout_pieces=5
|
||||||
|
|
|
@ -5,12 +5,32 @@ from thinc.api import fix_random_seed, Adam, set_dropout_rate
|
||||||
from numpy.testing import assert_array_equal
|
from numpy.testing import assert_array_equal
|
||||||
import numpy
|
import numpy
|
||||||
|
|
||||||
from spacy.ml.models import build_Tok2Vec_model
|
from spacy.ml.models import build_Tok2Vec_model, MultiHashEmbed, MaxoutWindowEncoder
|
||||||
from spacy.ml.models import build_text_classifier, build_simple_cnn_text_classifier
|
from spacy.ml.models import build_text_classifier, build_simple_cnn_text_classifier
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
from spacy.lang.en.examples import sentences as EN_SENTENCES
|
from spacy.lang.en.examples import sentences as EN_SENTENCES
|
||||||
|
|
||||||
|
|
||||||
|
def get_textcat_kwargs():
|
||||||
|
return {
|
||||||
|
"width": 64,
|
||||||
|
"embed_size": 2000,
|
||||||
|
"pretrained_vectors": None,
|
||||||
|
"exclusive_classes": False,
|
||||||
|
"ngram_size": 1,
|
||||||
|
"window_size": 1,
|
||||||
|
"conv_depth": 2,
|
||||||
|
"dropout": None,
|
||||||
|
"nO": 7,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_textcat_cnn_kwargs():
|
||||||
|
return {
|
||||||
|
"tok2vec": test_tok2vec(),
|
||||||
|
"exclusive_classes": False,
|
||||||
|
"nO": 13,
|
||||||
|
}
|
||||||
|
|
||||||
def get_all_params(model):
|
def get_all_params(model):
|
||||||
params = []
|
params = []
|
||||||
for node in model.walk():
|
for node in model.walk():
|
||||||
|
@ -35,50 +55,34 @@ def get_gradient(model, Y):
|
||||||
raise ValueError(f"Could not get gradient for type {type(Y)}")
|
raise ValueError(f"Could not get gradient for type {type(Y)}")
|
||||||
|
|
||||||
|
|
||||||
|
def get_tok2vec_kwargs():
|
||||||
|
# This actually creates models, so seems best to put it in a function.
|
||||||
|
return {
|
||||||
|
"embed": MultiHashEmbed(
|
||||||
|
width=32,
|
||||||
|
rows=500,
|
||||||
|
also_embed_subwords=True,
|
||||||
|
also_use_static_vectors=False
|
||||||
|
),
|
||||||
|
"encode": MaxoutWindowEncoder(
|
||||||
|
width=32,
|
||||||
|
depth=2,
|
||||||
|
maxout_pieces=2,
|
||||||
|
window_size=1,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def test_tok2vec():
|
def test_tok2vec():
|
||||||
return build_Tok2Vec_model(**TOK2VEC_KWARGS)
|
return build_Tok2Vec_model(**get_tok2vec_kwargs())
|
||||||
|
|
||||||
|
|
||||||
TOK2VEC_KWARGS = {
|
|
||||||
"width": 96,
|
|
||||||
"embed_size": 2000,
|
|
||||||
"subword_features": True,
|
|
||||||
"char_embed": False,
|
|
||||||
"conv_depth": 4,
|
|
||||||
"bilstm_depth": 0,
|
|
||||||
"maxout_pieces": 4,
|
|
||||||
"window_size": 1,
|
|
||||||
"dropout": 0.1,
|
|
||||||
"nM": 0,
|
|
||||||
"nC": 0,
|
|
||||||
"pretrained_vectors": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
TEXTCAT_KWARGS = {
|
|
||||||
"width": 64,
|
|
||||||
"embed_size": 2000,
|
|
||||||
"pretrained_vectors": None,
|
|
||||||
"exclusive_classes": False,
|
|
||||||
"ngram_size": 1,
|
|
||||||
"window_size": 1,
|
|
||||||
"conv_depth": 2,
|
|
||||||
"dropout": None,
|
|
||||||
"nO": 7,
|
|
||||||
}
|
|
||||||
|
|
||||||
TEXTCAT_CNN_KWARGS = {
|
|
||||||
"tok2vec": test_tok2vec(),
|
|
||||||
"exclusive_classes": False,
|
|
||||||
"nO": 13,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"seed,model_func,kwargs",
|
"seed,model_func,kwargs",
|
||||||
[
|
[
|
||||||
(0, build_Tok2Vec_model, TOK2VEC_KWARGS),
|
(0, build_Tok2Vec_model, get_tok2vec_kwargs()),
|
||||||
(0, build_text_classifier, TEXTCAT_KWARGS),
|
(0, build_text_classifier, get_textcat_kwargs()),
|
||||||
(0, build_simple_cnn_text_classifier, TEXTCAT_CNN_KWARGS),
|
(0, build_simple_cnn_text_classifier, get_textcat_cnn_kwargs()),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_models_initialize_consistently(seed, model_func, kwargs):
|
def test_models_initialize_consistently(seed, model_func, kwargs):
|
||||||
|
@ -96,9 +100,9 @@ def test_models_initialize_consistently(seed, model_func, kwargs):
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"seed,model_func,kwargs,get_X",
|
"seed,model_func,kwargs,get_X",
|
||||||
[
|
[
|
||||||
(0, build_Tok2Vec_model, TOK2VEC_KWARGS, get_docs),
|
(0, build_Tok2Vec_model, get_tok2vec_kwargs(), get_docs),
|
||||||
(0, build_text_classifier, TEXTCAT_KWARGS, get_docs),
|
(0, build_text_classifier, get_textcat_kwargs(), get_docs),
|
||||||
(0, build_simple_cnn_text_classifier, TEXTCAT_CNN_KWARGS, get_docs),
|
(0, build_simple_cnn_text_classifier, get_textcat_cnn_kwargs(), get_docs),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_models_predict_consistently(seed, model_func, kwargs, get_X):
|
def test_models_predict_consistently(seed, model_func, kwargs, get_X):
|
||||||
|
@ -131,9 +135,9 @@ def test_models_predict_consistently(seed, model_func, kwargs, get_X):
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"seed,dropout,model_func,kwargs,get_X",
|
"seed,dropout,model_func,kwargs,get_X",
|
||||||
[
|
[
|
||||||
(0, 0.2, build_Tok2Vec_model, TOK2VEC_KWARGS, get_docs),
|
(0, 0.2, build_Tok2Vec_model, get_tok2vec_kwargs(), get_docs),
|
||||||
(0, 0.2, build_text_classifier, TEXTCAT_KWARGS, get_docs),
|
(0, 0.2, build_text_classifier, get_textcat_kwargs(), get_docs),
|
||||||
(0, 0.2, build_simple_cnn_text_classifier, TEXTCAT_CNN_KWARGS, get_docs),
|
(0, 0.2, build_simple_cnn_text_classifier, get_textcat_cnn_kwargs(), get_docs),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_models_update_consistently(seed, dropout, model_func, kwargs, get_X):
|
def test_models_update_consistently(seed, dropout, model_func, kwargs, get_X):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user