mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Convert values to numpy for label smoothing tests (#12472)
This commit is contained in:
parent
ce258670b7
commit
140d53649d
|
@ -1,6 +1,8 @@
|
|||
import pytest
|
||||
from numpy.testing import assert_equal, assert_almost_equal
|
||||
|
||||
from thinc.api import get_current_ops
|
||||
|
||||
from spacy import util
|
||||
from spacy.training import Example
|
||||
from spacy.lang.en import English
|
||||
|
@ -52,8 +54,9 @@ def test_label_smoothing():
|
|||
tag_scores, bp_tag_scores = morph_ls.model.begin_update(
|
||||
[eg.predicted for eg in train_examples]
|
||||
)
|
||||
no_ls_grads = morph_no_ls.get_loss(train_examples, tag_scores)[1][0]
|
||||
ls_grads = morph_ls.get_loss(train_examples, tag_scores)[1][0]
|
||||
ops = get_current_ops()
|
||||
no_ls_grads = ops.to_numpy(morph_no_ls.get_loss(train_examples, tag_scores)[1][0])
|
||||
ls_grads = ops.to_numpy(morph_ls.get_loss(train_examples, tag_scores)[1][0])
|
||||
assert_almost_equal(ls_grads / no_ls_grads, 0.94285715)
|
||||
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ from spacy import util
|
|||
from spacy.training import Example
|
||||
from spacy.lang.en import English
|
||||
from spacy.language import Language
|
||||
from thinc.api import compounding
|
||||
from thinc.api import compounding, get_current_ops
|
||||
|
||||
from ..util import make_tempdir
|
||||
|
||||
|
@ -85,8 +85,9 @@ def test_label_smoothing():
|
|||
tag_scores, bp_tag_scores = tagger_ls.model.begin_update(
|
||||
[eg.predicted for eg in train_examples]
|
||||
)
|
||||
no_ls_grads = tagger_no_ls.get_loss(train_examples, tag_scores)[1][0]
|
||||
ls_grads = tagger_ls.get_loss(train_examples, tag_scores)[1][0]
|
||||
ops = get_current_ops()
|
||||
no_ls_grads = ops.to_numpy(tagger_no_ls.get_loss(train_examples, tag_scores)[1][0])
|
||||
ls_grads = ops.to_numpy(tagger_ls.get_loss(train_examples, tag_scores)[1][0])
|
||||
assert_almost_equal(ls_grads / no_ls_grads, 0.925)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user