change test to check difference in distributions

This commit is contained in:
vinit 2023-02-17 18:17:58 +05:30
parent 3b9e147b28
commit b0929271a8

View File

@ -1,5 +1,5 @@
import pytest
from numpy.testing import assert_equal
from numpy.testing import assert_equal, assert_array_almost_equal
from spacy.attrs import TAG
from spacy import util
@ -68,7 +68,6 @@ PARTIAL_DATA = [
def test_label_smoothing():
util.fix_random_seed()
nlp = Language()
tagger_no_ls = nlp.add_pipe(
"tagger", "no_label_smoothing", config=dict(label_smoothing=False)
@ -83,14 +82,14 @@ def test_label_smoothing():
tagger_ls.add_label(tag)
for t in TRAIN_DATA:
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
optimizer = nlp.initialize(get_examples=lambda: train_examples)
for i in range(5):
losses = {}
nlp.update(train_examples, sgd=optimizer, losses=losses)
assert losses == {
"no_label_smoothing": 1.4892945885658264,
"label_smoothing": 1.1432453989982605,
}
nlp.initialize(get_examples=lambda: train_examples)
tag_scores, bp_tag_scores = tagger_ls.model.begin_update(
[eg.predicted for eg in train_examples]
)
no_ls_probs = tagger_no_ls.get_loss(train_examples, tag_scores)[1][0]
ls_probs = tagger_ls.get_loss(train_examples, tag_scores)[1][0]
assert_array_almost_equal((ls_probs - no_ls_probs)[0], [0.05, -0.025, -0.025])
def test_no_label():