mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 02:06:31 +03:00
Bugfix textcat reproducibility on GPU (#6411)
* add seed argument to ParametricAttention layer * bump thinc to 7.4.3 * set thinc version range Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
parent
cdca44ac11
commit
2af31a8c8d
|
@ -646,7 +646,7 @@ def build_text_classifier(nr_class, width=64, **cfg):
|
|||
SpacyVectors
|
||||
>> flatten_add_lengths
|
||||
>> with_getitem(0, Affine(width, pretrained_dims))
|
||||
>> ParametricAttention(width)
|
||||
>> ParametricAttention(width, seed=100)
|
||||
>> Pooling(sum_pool)
|
||||
>> Residual(ReLu(width, width)) ** 2
|
||||
>> zero_init(Affine(nr_class, width, drop_factor=0.0))
|
||||
|
@ -688,7 +688,7 @@ def build_text_classifier(nr_class, width=64, **cfg):
|
|||
cnn_model = (
|
||||
tok2vec
|
||||
>> flatten_add_lengths
|
||||
>> ParametricAttention(width)
|
||||
>> ParametricAttention(width, seed=99)
|
||||
>> Pooling(sum_pool)
|
||||
>> Residual(zero_init(Maxout(width, width)))
|
||||
>> zero_init(Affine(nr_class, width, drop_factor=0.0))
|
||||
|
|
|
@ -11,7 +11,6 @@ def test_issue6177():
|
|||
# NOTE: no need to transform this code to v3 when 'master' is merged into 'develop'.
|
||||
# A similar test exists already for v3: test_issue5551
|
||||
# This is just a backport
|
||||
|
||||
results = []
|
||||
for i in range(3):
|
||||
fix_random_seed(0)
|
||||
|
@ -24,12 +23,15 @@ def test_issue6177():
|
|||
nlp.add_pipe(textcat)
|
||||
for label in set(example[1]["cats"]):
|
||||
textcat.add_label(label)
|
||||
nlp.begin_training()
|
||||
# Train
|
||||
optimizer = nlp.begin_training()
|
||||
text, annots = example
|
||||
nlp.update([text], [annots], sgd=optimizer)
|
||||
# Store the result of each iteration
|
||||
result = textcat.model.predict([nlp.make_doc(example[0])])
|
||||
result = textcat.model.predict([nlp.make_doc(text)])
|
||||
results.append(list(result[0]))
|
||||
|
||||
# All results should be the same because of the fixed seed
|
||||
assert len(results) == 3
|
||||
assert results[0] == results[1]
|
||||
assert results[0] == results[2]
|
||||
assert results[0] == results[2]
|
||||
|
|
Loading…
Reference in New Issue
Block a user