mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 02:06:31 +03:00
Bugfix textcat reproducibility on GPU (#6411)
* add seed argument to ParametricAttention layer
* bump thinc to 7.4.3
* set thinc version range

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
parent
cdca44ac11
commit
2af31a8c8d
|
@ -646,7 +646,7 @@ def build_text_classifier(nr_class, width=64, **cfg):
|
||||||
SpacyVectors
|
SpacyVectors
|
||||||
>> flatten_add_lengths
|
>> flatten_add_lengths
|
||||||
>> with_getitem(0, Affine(width, pretrained_dims))
|
>> with_getitem(0, Affine(width, pretrained_dims))
|
||||||
>> ParametricAttention(width)
|
>> ParametricAttention(width, seed=100)
|
||||||
>> Pooling(sum_pool)
|
>> Pooling(sum_pool)
|
||||||
>> Residual(ReLu(width, width)) ** 2
|
>> Residual(ReLu(width, width)) ** 2
|
||||||
>> zero_init(Affine(nr_class, width, drop_factor=0.0))
|
>> zero_init(Affine(nr_class, width, drop_factor=0.0))
|
||||||
|
@ -688,7 +688,7 @@ def build_text_classifier(nr_class, width=64, **cfg):
|
||||||
cnn_model = (
|
cnn_model = (
|
||||||
tok2vec
|
tok2vec
|
||||||
>> flatten_add_lengths
|
>> flatten_add_lengths
|
||||||
>> ParametricAttention(width)
|
>> ParametricAttention(width, seed=99)
|
||||||
>> Pooling(sum_pool)
|
>> Pooling(sum_pool)
|
||||||
>> Residual(zero_init(Maxout(width, width)))
|
>> Residual(zero_init(Maxout(width, width)))
|
||||||
>> zero_init(Affine(nr_class, width, drop_factor=0.0))
|
>> zero_init(Affine(nr_class, width, drop_factor=0.0))
|
||||||
|
|
|
@ -11,7 +11,6 @@ def test_issue6177():
|
||||||
# NOTE: no need to transform this code to v3 when 'master' is merged into 'develop'.
|
# NOTE: no need to transform this code to v3 when 'master' is merged into 'develop'.
|
||||||
# A similar test exists already for v3: test_issue5551
|
# A similar test exists already for v3: test_issue5551
|
||||||
# This is just a backport
|
# This is just a backport
|
||||||
|
|
||||||
results = []
|
results = []
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
fix_random_seed(0)
|
fix_random_seed(0)
|
||||||
|
@ -24,12 +23,15 @@ def test_issue6177():
|
||||||
nlp.add_pipe(textcat)
|
nlp.add_pipe(textcat)
|
||||||
for label in set(example[1]["cats"]):
|
for label in set(example[1]["cats"]):
|
||||||
textcat.add_label(label)
|
textcat.add_label(label)
|
||||||
nlp.begin_training()
|
# Train
|
||||||
|
optimizer = nlp.begin_training()
|
||||||
|
text, annots = example
|
||||||
|
nlp.update([text], [annots], sgd=optimizer)
|
||||||
# Store the result of each iteration
|
# Store the result of each iteration
|
||||||
result = textcat.model.predict([nlp.make_doc(example[0])])
|
result = textcat.model.predict([nlp.make_doc(text)])
|
||||||
results.append(list(result[0]))
|
results.append(list(result[0]))
|
||||||
|
|
||||||
# All results should be the same because of the fixed seed
|
# All results should be the same because of the fixed seed
|
||||||
assert len(results) == 3
|
assert len(results) == 3
|
||||||
assert results[0] == results[1]
|
assert results[0] == results[1]
|
||||||
assert results[0] == results[2]
|
assert results[0] == results[2]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user