mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-23 15:54:13 +03:00
Update tests
This commit is contained in:
parent
07b47eaac8
commit
f0cf4a2dca
|
@ -41,7 +41,7 @@ factory = "tagger"
|
|||
@architectures = "spacy.Tagger.v1"
|
||||
|
||||
[components.tagger.model.tok2vec]
|
||||
@architectures = "spacy.Tok2VecTensors.v1"
|
||||
@architectures = "spacy.Tok2VecListener.v1"
|
||||
width = ${components.tok2vec.model:width}
|
||||
"""
|
||||
|
||||
|
@ -71,7 +71,7 @@ def my_parser():
|
|||
tok2vec = build_Tok2Vec_model(
|
||||
MultiHashEmbed(
|
||||
width=321,
|
||||
embed_size=5432,
|
||||
rows=5432,
|
||||
also_embed_subwords=True,
|
||||
also_use_static_vectors=False
|
||||
),
|
||||
|
|
|
@ -1,7 +1,8 @@
|
|||
import pytest
|
||||
|
||||
from spacy.ml.models.tok2vec import build_Tok2Vec_model
|
||||
from spacy.ml.models.tok2vec import MultiHashEmbed, MaxoutWindowEncoder
|
||||
from spacy.ml.models.tok2vec import MultiHashEmbed, CharacterEmbed
|
||||
from spacy.ml.models.tok2vec import MishWindowEncoder, MaxoutWindowEncoder
|
||||
from spacy.vocab import Vocab
|
||||
from spacy.tokens import Doc
|
||||
|
||||
|
@ -60,26 +61,26 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size):
|
|||
|
||||
|
||||
# fmt: off
|
||||
@pytest.mark.xfail(reason="TODO: Update for new signature")
|
||||
@pytest.mark.parametrize(
|
||||
"tok2vec_config",
|
||||
"width,embed_arch,embed_config,encode_arch,encode_config",
|
||||
[
|
||||
{"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None},
|
||||
{"width": 8, "embed_size": 100, "char_embed": True, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None},
|
||||
{"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 6, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None},
|
||||
{"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 6, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None},
|
||||
{"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None},
|
||||
{"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None},
|
||||
{"width": 8, "embed_size": 100, "char_embed": True, "nM": 81, "nC": 8, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None},
|
||||
{"width": 8, "embed_size": 100, "char_embed": True, "nM": 81, "nC": 9, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None},
|
||||
(8, MultiHashEmbed, {"rows": 100, "also_embed_subwords": True, "also_use_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 2}),
|
||||
(8, MultiHashEmbed, {"rows": 100, "also_embed_subwords": True, "also_use_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 6}),
|
||||
(8, CharacterEmbed, {"rows": 100, "nM": 64, "nC": 8}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 3}),
|
||||
(8, CharacterEmbed, {"rows": 100, "nM": 16, "nC": 2}, MishWindowEncoder, {"window_size": 1, "depth": 3}),
|
||||
],
|
||||
)
|
||||
# fmt: on
|
||||
def test_tok2vec_configs(tok2vec_config):
|
||||
def test_tok2vec_configs(width, embed_arch, embed_config, encode_arch, encode_config):
|
||||
embed_config["width"] = width
|
||||
encode_config["width"] = width
|
||||
docs = get_batch(3)
|
||||
tok2vec = build_Tok2Vec_model_from_old_args(**tok2vec_config)
|
||||
tok2vec = build_Tok2Vec_model(
|
||||
embed_arch(**embed_config),
|
||||
encode_arch(**encode_config)
|
||||
)
|
||||
tok2vec.initialize(docs)
|
||||
vectors, backprop = tok2vec.begin_update(docs)
|
||||
assert len(vectors) == len(docs)
|
||||
assert vectors[0].shape == (len(docs[0]), tok2vec_config["width"])
|
||||
assert vectors[0].shape == (len(docs[0]), width)
|
||||
backprop(vectors)
|
||||
|
|
Loading…
Reference in New Issue
Block a user