import pytest

from spacy.ml.models.tok2vec import (
    build_Tok2Vec_model,
    MultiHashEmbed,
    CharacterEmbed,
    MishWindowEncoder,
    MaxoutWindowEncoder,
)
from spacy.tokens import Doc
from spacy.vocab import Vocab

from .util import get_batch
|
def test_empty_doc():
    """A Tok2Vec model must handle a zero-token Doc without crashing.

    The forward pass over an empty Doc should yield one output array of
    shape (0, width) — i.e. zero token rows, full embedding width.
    """
    width = 128
    embed_size = 2000
    vocab = Vocab()
    # Doc with an empty word list: the degenerate input under test.
    doc = Doc(vocab, words=[])
    tok2vec = build_Tok2Vec_model(
        MultiHashEmbed(
            width=width,
            rows=embed_size,
            also_use_static_vectors=False,
            also_embed_subwords=True,
        ),
        MaxoutWindowEncoder(
            width=width,
            depth=4,
            window_size=1,
            maxout_pieces=3,
        ),
    )
    tok2vec.initialize()
    vectors, backprop = tok2vec.begin_update([doc])
    # One output per input doc, even when the doc is empty.
    assert len(vectors) == 1
    assert vectors[0].shape == (0, width)
|
|
|
@pytest.mark.parametrize(
    "batch_size,width,embed_size", [[1, 128, 2000], [2, 128, 2000], [3, 8, 63]]
)
def test_tok2vec_batch_sizes(batch_size, width, embed_size):
    """Tok2Vec output must match the batch: one (len(doc), width) array per doc.

    Parametrized over several batch sizes and model widths to check that the
    model handles variable-length batches consistently.
    """
    batch = get_batch(batch_size)
    tok2vec = build_Tok2Vec_model(
        MultiHashEmbed(
            width=width,
            rows=embed_size,
            also_use_static_vectors=False,
            also_embed_subwords=True,
        ),
        MaxoutWindowEncoder(
            width=width,
            depth=4,
            window_size=1,
            maxout_pieces=3,
        ),
    )
    tok2vec.initialize()
    vectors, backprop = tok2vec.begin_update(batch)
    assert len(vectors) == len(batch)
    # Each doc's output has one row per token and `width` columns.
    for doc_vec, doc in zip(vectors, batch):
        assert doc_vec.shape == (len(doc), width)
2020-02-27 20:42:27 +03:00
|
|
|
# fmt: off
@pytest.mark.parametrize(
    "width,embed_arch,embed_config,encode_arch,encode_config",
    [
        (8, MultiHashEmbed, {"rows": 100, "also_embed_subwords": True, "also_use_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 2}),
        (8, MultiHashEmbed, {"rows": 100, "also_embed_subwords": True, "also_use_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 6}),
        (8, CharacterEmbed, {"rows": 100, "nM": 64, "nC": 8}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 3}),
        (8, CharacterEmbed, {"rows": 100, "nM": 16, "nC": 2}, MishWindowEncoder, {"window_size": 1, "depth": 3}),
    ],
)
# fmt: on
def test_tok2vec_configs(width, embed_arch, embed_config, encode_arch, encode_config):
    """Smoke-test forward and backward passes across embed/encode combinations.

    Each parametrized case pairs an embedding architecture (MultiHashEmbed or
    CharacterEmbed) with an encoder (MaxoutWindowEncoder or MishWindowEncoder),
    then checks output shapes and that the backprop callback runs.
    """
    # The shared width is injected here rather than repeated in every config dict.
    embed_config["width"] = width
    encode_config["width"] = width
    docs = get_batch(3)
    tok2vec = build_Tok2Vec_model(
        embed_arch(**embed_config),
        encode_arch(**encode_config),
    )
    tok2vec.initialize(docs)
    vectors, backprop = tok2vec.begin_update(docs)
    assert len(vectors) == len(docs)
    assert vectors[0].shape == (len(docs[0]), width)
    # Exercise the backward pass; shape/validity errors would raise here.
    backprop(vectors)