Update tests

This commit is contained in:
Matthew Honnibal 2020-07-28 23:06:46 +02:00
parent 00de30bcc2
commit c7d1ece3eb
2 changed files with 28 additions and 25 deletions

View File

@ -5,6 +5,7 @@ from spacy.lang.en import English
from spacy.language import Language
from spacy.util import registry, deep_merge_configs, load_model_from_config
from spacy.ml.models import build_Tok2Vec_model, build_tb_parser_model
from spacy.ml.models import MultiHashEmbed, MaxoutWindowEncoder
from ..util import make_tempdir

View File

@ -1,6 +1,7 @@
import pytest
from spacy.ml.models.tok2vec import build_Tok2Vec_model
from spacy.ml.models.tok2vec import MultiHashEmbed, MaxoutWindowEncoder
from spacy.vocab import Vocab
from spacy.tokens import Doc
@ -13,18 +14,18 @@ def test_empty_doc():
vocab = Vocab()
doc = Doc(vocab, words=[])
tok2vec = build_Tok2Vec_model(
width,
embed_size,
pretrained_vectors=None,
conv_depth=4,
bilstm_depth=0,
window_size=1,
maxout_pieces=3,
subword_features=True,
char_embed=False,
nM=64,
nC=8,
dropout=None,
MultiHashEmbed(
width=width,
rows=embed_size,
also_use_static_vectors=False,
also_embed_subwords=True
),
MaxoutWindowEncoder(
width=width,
depth=4,
window_size=1,
maxout_pieces=3
)
)
tok2vec.initialize()
vectors, backprop = tok2vec.begin_update([doc])
@ -38,18 +39,18 @@ def test_empty_doc():
def test_tok2vec_batch_sizes(batch_size, width, embed_size):
batch = get_batch(batch_size)
tok2vec = build_Tok2Vec_model(
width,
embed_size,
pretrained_vectors=None,
conv_depth=4,
bilstm_depth=0,
window_size=1,
maxout_pieces=3,
subword_features=True,
char_embed=False,
nM=64,
nC=8,
dropout=None,
MultiHashEmbed(
width=width,
rows=embed_size,
also_use_static_vectors=False,
also_embed_subwords=True
),
MaxoutWindowEncoder(
width=width,
depth=4,
window_size=1,
maxout_pieces=3,
)
)
tok2vec.initialize()
vectors, backprop = tok2vec.begin_update(batch)
@ -59,6 +60,7 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size):
# fmt: off
@pytest.mark.xfail(reason="TODO: Update for new signature")
@pytest.mark.parametrize(
"tok2vec_config",
[
@ -75,7 +77,7 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size):
# fmt: on
def test_tok2vec_configs(tok2vec_config):
docs = get_batch(3)
tok2vec = build_Tok2Vec_model(**tok2vec_config)
tok2vec = build_Tok2Vec_model_from_old_args(**tok2vec_config)
tok2vec.initialize(docs)
vectors, backprop = tok2vec.begin_update(docs)
assert len(vectors) == len(docs)