Convert values to numpy for label smoothing tests (#12472)
parent ce258670b7
commit 140d53649d
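Both tests apply the same pattern: the gradients returned by get_loss() live on whatever backend thinc is currently using (numpy arrays on CPU, cupy arrays when a GPU has been activated), so they are converted with ops.to_numpy() before being handed to numpy.testing assertions. A minimal sketch of that pattern, with made-up stand-in values rather than spaCy's actual gradients:

    import numpy
    from numpy.testing import assert_almost_equal
    from thinc.api import get_current_ops

    # Ops object for the active backend: typically NumpyOps on CPU,
    # CupyOps when a GPU has been activated.
    ops = get_current_ops()

    # Stand-in for a gradient row such as get_loss(...)[1][0]; on GPU this
    # would be a cupy array, which numpy.testing cannot compare directly.
    grads = ops.asarray([0.1, 0.2, 0.7], dtype="f")

    # ops.to_numpy() always returns a plain numpy array, so the assertion
    # behaves the same on CPU and GPU.
    assert_almost_equal(ops.to_numpy(grads), numpy.asarray([0.1, 0.2, 0.7], dtype="f"))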
spacy/tests/pipeline/test_morphologizer.py

@@ -1,6 +1,8 @@
 import pytest
 from numpy.testing import assert_equal, assert_almost_equal
 
+from thinc.api import get_current_ops
+
 from spacy import util
 from spacy.training import Example
 from spacy.lang.en import English
@@ -52,8 +54,9 @@ def test_label_smoothing():
     tag_scores, bp_tag_scores = morph_ls.model.begin_update(
         [eg.predicted for eg in train_examples]
     )
-    no_ls_grads = morph_no_ls.get_loss(train_examples, tag_scores)[1][0]
-    ls_grads = morph_ls.get_loss(train_examples, tag_scores)[1][0]
+    ops = get_current_ops()
+    no_ls_grads = ops.to_numpy(morph_no_ls.get_loss(train_examples, tag_scores)[1][0])
+    ls_grads = ops.to_numpy(morph_ls.get_loss(train_examples, tag_scores)[1][0])
     assert_almost_equal(ls_grads / no_ls_grads, 0.94285715)
spacy/tests/pipeline/test_tagger.py

@@ -6,7 +6,7 @@ from spacy import util
 from spacy.training import Example
 from spacy.lang.en import English
 from spacy.language import Language
-from thinc.api import compounding
+from thinc.api import compounding, get_current_ops
 
 from ..util import make_tempdir
@@ -85,8 +85,9 @@ def test_label_smoothing():
     tag_scores, bp_tag_scores = tagger_ls.model.begin_update(
         [eg.predicted for eg in train_examples]
     )
-    no_ls_grads = tagger_no_ls.get_loss(train_examples, tag_scores)[1][0]
-    ls_grads = tagger_ls.get_loss(train_examples, tag_scores)[1][0]
+    ops = get_current_ops()
+    no_ls_grads = ops.to_numpy(tagger_no_ls.get_loss(train_examples, tag_scores)[1][0])
+    ls_grads = ops.to_numpy(tagger_ls.get_loss(train_examples, tag_scores)[1][0])
     assert_almost_equal(ls_grads / no_ls_grads, 0.925)
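The conversion is confined to the tests: get_loss() itself is untouched, and the expected gradient ratios (0.94285715 for the morphologizer, 0.925 for the tagger) stay the same. Converting both gradients to numpy before dividing simply keeps the assertions behaving identically whether spaCy is running on CPU or GPU.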