Fixing reproducible training (#5735)

* Add initial reproducibility tests

* failing test for default_text_classifier (WIP)

* track trouble to underlying tok2vec layer

* add regression test for Issue 5551

* tests go green with https://github.com/explosion/thinc/pull/359

* update test

* adding fixed seeds to HashEmbed layers, seems to fix the reproducility issue

Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com>
This commit is contained in:
Sofie Van Landeghem 2020-07-09 19:39:31 +02:00 committed by GitHub
parent 1827f22f56
commit c1ea55307b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 196 additions and 9 deletions

View File

@ -87,16 +87,16 @@ def build_text_classifier(
cols = [ORTH, LOWER, PREFIX, SUFFIX, SHAPE, ID]
with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
lower = HashEmbed(
nO=width, nV=embed_size, column=cols.index(LOWER), dropout=dropout
nO=width, nV=embed_size, column=cols.index(LOWER), dropout=dropout, seed=10
)
prefix = HashEmbed(
nO=width // 2, nV=embed_size, column=cols.index(PREFIX), dropout=dropout
nO=width // 2, nV=embed_size, column=cols.index(PREFIX), dropout=dropout, seed=11
)
suffix = HashEmbed(
nO=width // 2, nV=embed_size, column=cols.index(SUFFIX), dropout=dropout
nO=width // 2, nV=embed_size, column=cols.index(SUFFIX), dropout=dropout, seed=12
)
shape = HashEmbed(
nO=width // 2, nV=embed_size, column=cols.index(SHAPE), dropout=dropout
nO=width // 2, nV=embed_size, column=cols.index(SHAPE), dropout=dropout, seed=13
)
width_nI = sum(layer.get_dim("nO") for layer in [lower, prefix, suffix, shape])

View File

@ -154,16 +154,16 @@ def LayerNormalizedMaxout(width, maxout_pieces):
def MultiHashEmbed(
columns, width, rows, use_subwords, pretrained_vectors, mix, dropout
):
norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout)
norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=6)
if use_subwords:
prefix = HashEmbed(
nO=width, nV=rows // 2, column=columns.index("PREFIX"), dropout=dropout
nO=width, nV=rows // 2, column=columns.index("PREFIX"), dropout=dropout, seed=7
)
suffix = HashEmbed(
nO=width, nV=rows // 2, column=columns.index("SUFFIX"), dropout=dropout
nO=width, nV=rows // 2, column=columns.index("SUFFIX"), dropout=dropout, seed=8
)
shape = HashEmbed(
nO=width, nV=rows // 2, column=columns.index("SHAPE"), dropout=dropout
nO=width, nV=rows // 2, column=columns.index("SHAPE"), dropout=dropout, seed=9
)
if pretrained_vectors:
@ -192,7 +192,7 @@ def MultiHashEmbed(
@registry.architectures.register("spacy.CharacterEmbed.v1")
def CharacterEmbed(columns, width, rows, nM, nC, features, dropout):
norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout)
norm = HashEmbed(nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=5)
chr_embed = _character_embed.CharacterEmbed(nM=nM, nC=nC)
with Model.define_operators({">>": chain, "|": concatenate}):
embed_layer = chr_embed | features >> with_array(norm)

View File

@ -0,0 +1,31 @@
from spacy.lang.en import English
from spacy.util import fix_random_seed
def test_issue5551():
"""Test that after fixing the random seed, the results of the pipeline are truly identical"""
component = "textcat"
pipe_cfg = {"exclusive_classes": False}
results = []
for i in range(3):
fix_random_seed(0)
nlp = English()
example = (
"Once hot, form ping-pong-ball-sized balls of the mixture, each weighing roughly 25 g.",
{"cats": {"Labe1": 1.0, "Label2": 0.0, "Label3": 0.0}},
)
nlp.add_pipe(nlp.create_pipe(component, config=pipe_cfg), last=True)
pipe = nlp.get_pipe(component)
for label in set(example[1]["cats"]):
pipe.add_label(label)
nlp.begin_training(component_cfg={component: pipe_cfg})
# Store the result of each iteration
result = pipe.model.predict([nlp.make_doc(example[0])])
results.append(list(result[0]))
# All results should be the same because of the fixed seed
assert len(results) == 3
assert results[0] == results[1]
assert results[0] == results[2]

156
spacy/tests/test_models.py Normal file
View File

@ -0,0 +1,156 @@
from typing import List
import pytest
from thinc.api import fix_random_seed, Adam, set_dropout_rate
from numpy.testing import assert_array_equal
import numpy
from spacy.ml.models import build_Tok2Vec_model
from spacy.ml.models import build_text_classifier, build_simple_cnn_text_classifier
from spacy.lang.en import English
from spacy.lang.en.examples import sentences as EN_SENTENCES
def get_all_params(model):
params = []
for node in model.walk():
for name in node.param_names:
params.append(node.get_param(name).ravel())
return node.ops.xp.concatenate(params)
def get_docs():
nlp = English()
return list(nlp.pipe(EN_SENTENCES + [" ".join(EN_SENTENCES)]))
def get_gradient(model, Y):
if isinstance(Y, model.ops.xp.ndarray):
dY = model.ops.alloc(Y.shape, dtype=Y.dtype)
dY += model.ops.xp.random.uniform(-1.0, 1.0, Y.shape)
return dY
elif isinstance(Y, List):
return [get_gradient(model, y) for y in Y]
else:
raise ValueError(f"Could not compare type {type(Y)}")
def default_tok2vec():
return build_Tok2Vec_model(**TOK2VEC_KWARGS)
TOK2VEC_KWARGS = {
"width": 96,
"embed_size": 2000,
"subword_features": True,
"char_embed": False,
"conv_depth": 4,
"bilstm_depth": 0,
"maxout_pieces": 4,
"window_size": 1,
"dropout": 0.1,
"nM": 0,
"nC": 0,
"pretrained_vectors": None,
}
TEXTCAT_KWARGS = {
"width": 64,
"embed_size": 2000,
"pretrained_vectors": None,
"exclusive_classes": False,
"ngram_size": 1,
"window_size": 1,
"conv_depth": 2,
"dropout": None,
"nO": 7
}
TEXTCAT_CNN_KWARGS = {
"tok2vec": default_tok2vec(),
"exclusive_classes": False,
"nO": 13,
}
@pytest.mark.parametrize(
"seed,model_func,kwargs",
[
(0, build_Tok2Vec_model, TOK2VEC_KWARGS),
(0, build_text_classifier, TEXTCAT_KWARGS),
(0, build_simple_cnn_text_classifier, TEXTCAT_CNN_KWARGS),
],
)
def test_models_initialize_consistently(seed, model_func, kwargs):
fix_random_seed(seed)
model1 = model_func(**kwargs)
model1.initialize()
fix_random_seed(seed)
model2 = model_func(**kwargs)
model2.initialize()
params1 = get_all_params(model1)
params2 = get_all_params(model2)
assert_array_equal(params1, params2)
@pytest.mark.parametrize(
"seed,model_func,kwargs,get_X",
[
(0, build_Tok2Vec_model, TOK2VEC_KWARGS, get_docs),
(0, build_text_classifier, TEXTCAT_KWARGS, get_docs),
(0, build_simple_cnn_text_classifier, TEXTCAT_CNN_KWARGS, get_docs),
],
)
def test_models_predict_consistently(seed, model_func, kwargs, get_X):
fix_random_seed(seed)
model1 = model_func(**kwargs).initialize()
Y1 = model1.predict(get_X())
fix_random_seed(seed)
model2 = model_func(**kwargs).initialize()
Y2 = model2.predict(get_X())
if model1.has_ref("tok2vec"):
tok2vec1 = model1.get_ref("tok2vec").predict(get_X())
tok2vec2 = model2.get_ref("tok2vec").predict(get_X())
for i in range(len(tok2vec1)):
for j in range(len(tok2vec1[i])):
assert_array_equal(numpy.asarray(tok2vec1[i][j]), numpy.asarray(tok2vec2[i][j]))
if isinstance(Y1, numpy.ndarray):
assert_array_equal(Y1, Y2)
elif isinstance(Y1, List):
assert len(Y1) == len(Y2)
for y1, y2 in zip(Y1, Y2):
assert_array_equal(y1, y2)
else:
raise ValueError(f"Could not compare type {type(Y1)}")
@pytest.mark.parametrize(
"seed,dropout,model_func,kwargs,get_X",
[
(0, 0.2, build_Tok2Vec_model, TOK2VEC_KWARGS, get_docs),
(0, 0.2, build_text_classifier, TEXTCAT_KWARGS, get_docs),
(0, 0.2, build_simple_cnn_text_classifier, TEXTCAT_CNN_KWARGS, get_docs),
],
)
def test_models_update_consistently(seed, dropout, model_func, kwargs, get_X):
def get_updated_model():
fix_random_seed(seed)
optimizer = Adam(0.001)
model = model_func(**kwargs).initialize()
initial_params = get_all_params(model)
set_dropout_rate(model, dropout)
for _ in range(5):
Y, get_dX = model.begin_update(get_X())
dY = get_gradient(model, Y)
_ = get_dX(dY)
model.finish_update(optimizer)
updated_params = get_all_params(model)
with pytest.raises(AssertionError):
assert_array_equal(initial_params, updated_params)
return model
model1 = get_updated_model()
model2 = get_updated_model()
assert_array_equal(get_all_params(model1), get_all_params(model2))