Mirror of https://github.com/explosion/spaCy.git

make dropout in embed layers configurable

commit eac12cbb77 (parent e91485dfc4)
@@ -49,13 +49,13 @@ def build_bow_text_classifier(exclusive_classes, ngram_size, no_output_layer, nO
 
 @registry.architectures.register("spacy.TextCat.v1")
 def build_text_classifier(width, embed_size, pretrained_vectors, exclusive_classes, ngram_size,
-                          window_size, conv_depth, nO=None):
+                          window_size, conv_depth, dropout, nO=None):
     cols = [ORTH, LOWER, PREFIX, SUFFIX, SHAPE, ID]
     with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
-        lower = HashEmbed(nO=width, nV=embed_size, column=cols.index(LOWER))
-        prefix = HashEmbed(nO=width // 2, nV=embed_size, column=cols.index(PREFIX))
-        suffix = HashEmbed(nO=width // 2, nV=embed_size, column=cols.index(SUFFIX))
-        shape = HashEmbed(nO=width // 2, nV=embed_size, column=cols.index(SHAPE))
+        lower = HashEmbed(nO=width, nV=embed_size, column=cols.index(LOWER), dropout=dropout)
+        prefix = HashEmbed(nO=width // 2, nV=embed_size, column=cols.index(PREFIX), dropout=dropout)
+        suffix = HashEmbed(nO=width // 2, nV=embed_size, column=cols.index(SUFFIX), dropout=dropout)
+        shape = HashEmbed(nO=width // 2, nV=embed_size, column=cols.index(SHAPE), dropout=dropout)
 
         width_nI = sum(layer.get_dim("nO") for layer in [lower, prefix, suffix, shape])
         trained_vectors = FeatureExtractor(cols) >> with_array(
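
The pattern repeated across these hunks: each embedding table now receives the dropout rate as an argument instead of leaving it unset or hard-coding 0.0. A minimal, self-contained sketch of that pattern (assuming thinc's HashEmbed API as used above; the helper and its column indices are illustrative, not spaCy code):

from thinc.api import HashEmbed, concatenate

def make_embed_layers(width, embed_size, dropout=None):
    # One hashed embedding table per lexical feature column; all of them
    # share the same, now configurable, dropout rate. dropout=None means
    # no dropout is applied inside the embedding layers.
    lower = HashEmbed(nO=width, nV=embed_size, column=1, dropout=dropout)
    prefix = HashEmbed(nO=width // 2, nV=embed_size, column=2, dropout=dropout)
    suffix = HashEmbed(nO=width // 2, nV=embed_size, column=3, dropout=dropout)
    return concatenate(lower, prefix, suffix)
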
@@ -114,7 +114,7 @@ def build_text_classifier(width, embed_size, pretrained_vectors, exclusive_class
 
 
 @registry.architectures.register("spacy.TextCatLowData.v1")
-def build_text_classifier_lowdata(width, pretrained_vectors, nO=None):
+def build_text_classifier_lowdata(width, pretrained_vectors, dropout, nO=None):
     nlp = util.load_model(pretrained_vectors)
     vectors = nlp.vocab.vectors
     vector_dim = vectors.data.shape[1]
@@ -129,7 +129,8 @@ def build_text_classifier_lowdata(width, pretrained_vectors, nO=None):
             >> reduce_sum()
             >> residual(Relu(width, width)) ** 2
             >> Linear(nO, width)
-            >> Dropout(0.0)
-            >> Logistic()
         )
+        if dropout:
+            model = model >> Dropout(dropout)
+        model = model >> Logistic()
     return model
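
In the low-data text classifier the tail of the network changes shape: the always-present Dropout(0.0) layer is gone, and a Dropout layer is only chained in when a rate is actually configured. A hedged sketch of that pattern in plain thinc (not the spaCy function itself):

from thinc.api import Dropout, Linear, Logistic, chain

def add_output_head(model, nO, width, dropout=None):
    # Same shape as the hunk above: linear projection, optional dropout
    # (only when a non-zero rate is given), then a logistic output layer.
    model = chain(model, Linear(nO, width))
    if dropout:
        model = chain(model, Dropout(dropout))
    return chain(model, Logistic())
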
@@ -49,6 +49,7 @@ def hash_embed_cnn(
     maxout_pieces,
     window_size,
     subword_features,
+    dropout,
 ):
     # Does not use character embeddings: set to False by default
     return build_Tok2Vec_model(
@@ -63,6 +64,7 @@ def hash_embed_cnn(
         char_embed=False,
         nM=0,
         nC=0,
+        dropout=dropout,
     )
 
 
@@ -76,6 +78,7 @@ def hash_charembed_cnn(
     window_size,
     nM,
     nC,
+    dropout,
 ):
     # Allows using character embeddings by setting nC, nM and char_embed=True
     return build_Tok2Vec_model(
@@ -90,12 +93,13 @@ def hash_charembed_cnn(
         char_embed=True,
         nM=nM,
         nC=nC,
+        dropout=dropout,
     )
 
 
 @registry.architectures.register("spacy.HashEmbedBiLSTM.v1")
 def hash_embed_bilstm_v1(
-    pretrained_vectors, width, depth, embed_size, subword_features, maxout_pieces
+    pretrained_vectors, width, depth, embed_size, subword_features, maxout_pieces, dropout
 ):
     # Does not use character embeddings: set to False by default
     return build_Tok2Vec_model(
@@ -110,12 +114,13 @@ def hash_embed_bilstm_v1(
         char_embed=False,
         nM=0,
         nC=0,
+        dropout=dropout,
     )
 
 
 @registry.architectures.register("spacy.HashCharEmbedBiLSTM.v1")
 def hash_char_embed_bilstm_v1(
-    pretrained_vectors, width, depth, embed_size, maxout_pieces, nM, nC
+    pretrained_vectors, width, depth, embed_size, maxout_pieces, nM, nC, dropout
 ):
     # Allows using character embeddings by setting nC, nM and char_embed=True
     return build_Tok2Vec_model(
@@ -130,6 +135,7 @@ def hash_char_embed_bilstm_v1(
         char_embed=True,
         nM=nM,
         nC=nC,
+        dropout=dropout,
     )
 
 
@@ -144,19 +150,19 @@ def LayerNormalizedMaxout(width, maxout_pieces):
 
 
 @registry.architectures.register("spacy.MultiHashEmbed.v1")
-def MultiHashEmbed(columns, width, rows, use_subwords, pretrained_vectors, mix):
-    norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"))
+def MultiHashEmbed(columns, width, rows, use_subwords, pretrained_vectors, mix, dropout):
+    norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout)
     if use_subwords:
-        prefix = HashEmbed(nO=width, nV=rows // 2, column=columns.index("PREFIX"))
-        suffix = HashEmbed(nO=width, nV=rows // 2, column=columns.index("SUFFIX"))
-        shape = HashEmbed(nO=width, nV=rows // 2, column=columns.index("SHAPE"))
+        prefix = HashEmbed(nO=width, nV=rows // 2, column=columns.index("PREFIX"), dropout=dropout)
+        suffix = HashEmbed(nO=width, nV=rows // 2, column=columns.index("SUFFIX"), dropout=dropout)
+        shape = HashEmbed(nO=width, nV=rows // 2, column=columns.index("SHAPE"), dropout=dropout)
 
     if pretrained_vectors:
         glove = StaticVectors(
             vectors=pretrained_vectors.data,
             nO=width,
             column=columns.index(ID),
-            dropout=0.0,
+            dropout=dropout,
         )
 
     with Model.define_operators({">>": chain, "|": concatenate}):
@@ -164,13 +170,10 @@ def MultiHashEmbed(columns, width, rows, use_subwords, pretrained_vectors, mix):
             embed_layer = norm
         else:
             if use_subwords and pretrained_vectors:
-                nr_columns = 5
                 concat_columns = glove | norm | prefix | suffix | shape
             elif use_subwords:
-                nr_columns = 4
                 concat_columns = norm | prefix | suffix | shape
             else:
-                nr_columns = 2
                 concat_columns = glove | norm
 
             embed_layer = uniqued(concat_columns >> mix, column=columns.index("ORTH"))
@@ -179,8 +182,8 @@ def MultiHashEmbed(columns, width, rows, use_subwords, pretrained_vectors, mix):
 
 
 @registry.architectures.register("spacy.CharacterEmbed.v1")
-def CharacterEmbed(columns, width, rows, nM, nC, features):
-    norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"))
+def CharacterEmbed(columns, width, rows, nM, nC, features, dropout):
+    norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout)
     chr_embed = _character_embed.CharacterEmbed(nM=nM, nC=nC)
     with Model.define_operators({">>": chain, "|": concatenate}):
         embed_layer = chr_embed | features >> with_array(norm)
@@ -238,16 +241,17 @@ def build_Tok2Vec_model(
     nC,
     conv_depth,
     bilstm_depth,
+    dropout,
 ) -> Model:
     if char_embed:
         subword_features = False
     cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH]
     with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
-        norm = HashEmbed(nO=width, nV=embed_size, column=cols.index(NORM))
+        norm = HashEmbed(nO=width, nV=embed_size, column=cols.index(NORM), dropout=dropout)
         if subword_features:
-            prefix = HashEmbed(nO=width, nV=embed_size // 2, column=cols.index(PREFIX))
-            suffix = HashEmbed(nO=width, nV=embed_size // 2, column=cols.index(SUFFIX))
-            shape = HashEmbed(nO=width, nV=embed_size // 2, column=cols.index(SHAPE))
+            prefix = HashEmbed(nO=width, nV=embed_size // 2, column=cols.index(PREFIX), dropout=dropout)
+            suffix = HashEmbed(nO=width, nV=embed_size // 2, column=cols.index(SUFFIX), dropout=dropout)
+            shape = HashEmbed(nO=width, nV=embed_size // 2, column=cols.index(SHAPE), dropout=dropout)
         else:
            prefix, suffix, shape = (None, None, None)
         if pretrained_vectors is not None:
@@ -255,7 +259,7 @@ def build_Tok2Vec_model(
                 vectors=pretrained_vectors.data,
                 nO=width,
                 column=cols.index(ID),
-                dropout=0.0,
+                dropout=dropout,
             )
 
         if subword_features:
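
Taken together, build_Tok2Vec_model now threads a single dropout value into every embedding table, and into StaticVectors, which previously pinned it to 0.0. A usage sketch with keyword arguments only; the import path and the exact set of required arguments are assumptions about this development branch, mirroring the test hunks further down:

# Assumed import path; arguments are passed by keyword to avoid relying
# on positional order. dropout=None keeps the previous behaviour of
# applying no dropout inside the embedding layers.
from spacy.ml.models.tok2vec import build_Tok2Vec_model

tok2vec = build_Tok2Vec_model(
    width=96,
    embed_size=2000,
    pretrained_vectors=None,
    window_size=1,
    maxout_pieces=3,
    subword_features=True,
    char_embed=False,
    nM=0,
    nC=0,
    conv_depth=4,
    bilstm_depth=0,
    dropout=None,
)
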
@@ -10,3 +10,4 @@ embed_size = 300
 window_size = 1
 maxout_pieces = 3
 subword_features = true
+dropout = null
@@ -11,3 +11,4 @@ window_size = 1
 maxout_pieces = 3
 nM = 64
 nC = 8
+dropout = null
@@ -13,3 +13,4 @@ embed_size = 2000
 window_size = 1
 maxout_pieces = 3
 subword_features = true
+dropout = null
@@ -13,3 +13,4 @@ embed_size = 2000
 window_size = 1
 maxout_pieces = 3
 subword_features = true
+dropout = null
@@ -10,3 +10,4 @@ embed_size = 2000
 window_size = 1
 maxout_pieces = 2
 subword_features = true
+dropout = null
@@ -10,3 +10,4 @@ embed_size = 7000
 window_size = 1
 maxout_pieces = 3
 subword_features = true
+dropout = null
@@ -10,3 +10,4 @@ embed_size = 2000
 window_size = 1
 maxout_pieces = 3
 subword_features = true
+dropout = null
@@ -11,3 +11,4 @@ embed_size = 2000
 window_size = 1
 maxout_pieces = 3
 subword_features = true
+dropout = null
@@ -7,3 +7,4 @@ conv_depth = 2
 embed_size = 2000
 window_size = 1
 ngram_size = 1
+dropout = null
@@ -7,3 +7,4 @@ embed_size = 2000
 window_size = 1
 maxout_pieces = 3
 subword_features = true
+dropout = null
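
The defaults files above all gain "dropout = null". A small sketch of why null is the right spelling there (assuming thinc's Config reader, which parses values as JSON): null becomes Python None, which the layers above interpret as "no embedding dropout".

from thinc.api import Config

# Illustrative config fragment, not one of the actual defaults files.
cfg = Config().from_str("""
[model]
width = 96
embed_size = 2000
dropout = null
""")
assert cfg["model"]["dropout"] is None  # null -> None
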
@@ -123,9 +123,9 @@ def test_overfitting_IO():
         {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "ngram_size": 4, "no_output_layer": False},
         {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "ngram_size": 3, "no_output_layer": True},
         {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "ngram_size": 2, "no_output_layer": True},
-        {"@architectures": "spacy.TextCat.v1", "exclusive_classes": False, "ngram_size": 1, "pretrained_vectors": False, "width": 64, "conv_depth": 2, "embed_size": 2000, "window_size": 2},
-        {"@architectures": "spacy.TextCat.v1", "exclusive_classes": True, "ngram_size": 5, "pretrained_vectors": False, "width": 128, "conv_depth": 2, "embed_size": 2000, "window_size": 1},
-        {"@architectures": "spacy.TextCat.v1", "exclusive_classes": True, "ngram_size": 2, "pretrained_vectors": False, "width": 32, "conv_depth": 3, "embed_size": 500, "window_size": 3},
+        {"@architectures": "spacy.TextCat.v1", "exclusive_classes": False, "ngram_size": 1, "pretrained_vectors": False, "width": 64, "conv_depth": 2, "embed_size": 2000, "window_size": 2, "dropout": None},
+        {"@architectures": "spacy.TextCat.v1", "exclusive_classes": True, "ngram_size": 5, "pretrained_vectors": False, "width": 128, "conv_depth": 2, "embed_size": 2000, "window_size": 1, "dropout": None},
+        {"@architectures": "spacy.TextCat.v1", "exclusive_classes": True, "ngram_size": 2, "pretrained_vectors": False, "width": 32, "conv_depth": 3, "embed_size": 500, "window_size": 3, "dropout": None},
         {"@architectures": "spacy.TextCatCNN.v1", "tok2vec": default_tok2vec(), "exclusive_classes": True},
         {"@architectures": "spacy.TextCatCNN.v1", "tok2vec": default_tok2vec(), "exclusive_classes": False},
     ],
@@ -24,6 +24,7 @@ window_size = 1
 embed_size = 2000
 maxout_pieces = 3
 subword_features = true
+dropout = null
 
 [nlp.pipeline.tagger]
 factory = "tagger"
@@ -53,6 +54,7 @@ embed_size = 5555
 window_size = 1
 maxout_pieces = 7
 subword_features = false
+dropout = null
 """
 
 
@@ -70,6 +72,7 @@ def my_parser():
         nC=8,
         conv_depth=2,
         bilstm_depth=0,
+        dropout=None,
     )
     parser = build_tb_parser_model(
         tok2vec=tok2vec, nr_feature_tokens=7, hidden_width=65, maxout_pieces=5
@@ -15,7 +15,7 @@ def test_empty_doc():
     vocab = Vocab()
     doc = Doc(vocab, words=[])
     # TODO: fix tok2vec arguments
-    tok2vec = build_Tok2Vec_model(width, embed_size)
+    tok2vec = build_Tok2Vec_model(width, embed_size, dropout=None)
     vectors, backprop = tok2vec.begin_update([doc])
     assert len(vectors) == 1
     assert vectors[0].shape == (0, width)
@@ -38,6 +38,7 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size):
         char_embed=False,
         nM=64,
         nC=8,
+        dropout=None,
     )
     tok2vec.initialize()
     vectors, backprop = tok2vec.begin_update(batch)
@@ -50,14 +51,14 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size):
 @pytest.mark.parametrize(
     "tok2vec_config",
     [
-        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True},
-        {"width": 8, "embed_size": 100, "char_embed": True, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True},
-        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 6, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True},
-        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 6, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True},
-        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False},
-        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False},
-        {"width": 8, "embed_size": 100, "char_embed": True, "nM": 81, "nC": 8, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False},
-        {"width": 8, "embed_size": 100, "char_embed": True, "nM": 81, "nC": 9, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False},
+        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None},
+        {"width": 8, "embed_size": 100, "char_embed": True, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None},
+        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 6, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None},
+        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 6, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None},
+        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None},
+        {"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None},
+        {"width": 8, "embed_size": 100, "char_embed": True, "nM": 81, "nC": 8, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None},
+        {"width": 8, "embed_size": 100, "char_embed": True, "nM": 81, "nC": 9, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None},
     ],
 )
 # fmt: on