mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-15 02:32:37 +03:00
change test to check difference in distributions
This commit is contained in:
parent
3b9e147b28
commit
b0929271a8
|
@ -1,5 +1,5 @@
|
|||
import pytest
|
||||
from numpy.testing import assert_equal
|
||||
from numpy.testing import assert_equal, assert_array_almost_equal
|
||||
from spacy.attrs import TAG
|
||||
|
||||
from spacy import util
|
||||
|
@ -68,7 +68,6 @@ PARTIAL_DATA = [
|
|||
|
||||
|
||||
def test_label_smoothing():
|
||||
util.fix_random_seed()
|
||||
nlp = Language()
|
||||
tagger_no_ls = nlp.add_pipe(
|
||||
"tagger", "no_label_smoothing", config=dict(label_smoothing=False)
|
||||
|
@ -83,14 +82,14 @@ def test_label_smoothing():
|
|||
tagger_ls.add_label(tag)
|
||||
for t in TRAIN_DATA:
|
||||
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
|
||||
optimizer = nlp.initialize(get_examples=lambda: train_examples)
|
||||
for i in range(5):
|
||||
losses = {}
|
||||
nlp.update(train_examples, sgd=optimizer, losses=losses)
|
||||
assert losses == {
|
||||
"no_label_smoothing": 1.4892945885658264,
|
||||
"label_smoothing": 1.1432453989982605,
|
||||
}
|
||||
|
||||
nlp.initialize(get_examples=lambda: train_examples)
|
||||
tag_scores, bp_tag_scores = tagger_ls.model.begin_update(
|
||||
[eg.predicted for eg in train_examples]
|
||||
)
|
||||
no_ls_probs = tagger_no_ls.get_loss(train_examples, tag_scores)[1][0]
|
||||
ls_probs = tagger_ls.get_loss(train_examples, tag_scores)[1][0]
|
||||
assert_array_almost_equal((ls_probs - no_ls_probs)[0], [0.05, -0.025, -0.025])
|
||||
|
||||
|
||||
def test_no_label():
|
||||
|
|
Loading…
Reference in New Issue
Block a user