Convert values to numpy for label smoothing tests (#12472)

This commit is contained in:
Adriane Boyd 2023-03-31 13:41:41 +02:00 committed by GitHub
parent ce258670b7
commit 140d53649d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 9 additions and 5 deletions

View File

@ -1,6 +1,8 @@
import pytest
from numpy.testing import assert_equal, assert_almost_equal
from thinc.api import get_current_ops
from spacy import util
from spacy.training import Example
from spacy.lang.en import English
@ -52,8 +54,9 @@ def test_label_smoothing():
tag_scores, bp_tag_scores = morph_ls.model.begin_update(
[eg.predicted for eg in train_examples]
)
no_ls_grads = morph_no_ls.get_loss(train_examples, tag_scores)[1][0]
ls_grads = morph_ls.get_loss(train_examples, tag_scores)[1][0]
ops = get_current_ops()
no_ls_grads = ops.to_numpy(morph_no_ls.get_loss(train_examples, tag_scores)[1][0])
ls_grads = ops.to_numpy(morph_ls.get_loss(train_examples, tag_scores)[1][0])
assert_almost_equal(ls_grads / no_ls_grads, 0.94285715)

View File

@ -6,7 +6,7 @@ from spacy import util
from spacy.training import Example
from spacy.lang.en import English
from spacy.language import Language
from thinc.api import compounding
from thinc.api import compounding, get_current_ops
from ..util import make_tempdir
@ -85,8 +85,9 @@ def test_label_smoothing():
tag_scores, bp_tag_scores = tagger_ls.model.begin_update(
[eg.predicted for eg in train_examples]
)
no_ls_grads = tagger_no_ls.get_loss(train_examples, tag_scores)[1][0]
ls_grads = tagger_ls.get_loss(train_examples, tag_scores)[1][0]
ops = get_current_ops()
no_ls_grads = ops.to_numpy(tagger_no_ls.get_loss(train_examples, tag_scores)[1][0])
ls_grads = ops.to_numpy(tagger_ls.get_loss(train_examples, tag_scores)[1][0])
assert_almost_equal(ls_grads / no_ls_grads, 0.925)