This commit is contained in:
vinit 2023-02-20 17:01:47 +05:30
parent bc2e082192
commit 1507e357dc

View File

@ -73,9 +73,7 @@ def test_label_smoothing():
tagger_no_ls = nlp.add_pipe( tagger_no_ls = nlp.add_pipe(
"tagger", "no_label_smoothing", config=dict(label_smoothing=0.0) "tagger", "no_label_smoothing", config=dict(label_smoothing=0.0)
) )
tagger_ls = nlp.add_pipe( tagger_ls = nlp.add_pipe("tagger", "label_smoothing")
"tagger", "label_smoothing"
)
train_examples = [] train_examples = []
losses = {} losses = {}
for tag in TAGS: for tag in TAGS:
@ -88,8 +86,8 @@ def test_label_smoothing():
tag_scores, bp_tag_scores = tagger_ls.model.begin_update( tag_scores, bp_tag_scores = tagger_ls.model.begin_update(
[eg.predicted for eg in train_examples] [eg.predicted for eg in train_examples]
) )
no_ls_grads= tagger_no_ls.get_loss(train_examples, tag_scores)[1][0] no_ls_grads = tagger_no_ls.get_loss(train_examples, tag_scores)[1][0]
ls_grads= tagger_ls.get_loss(train_examples, tag_scores)[1][0] ls_grads = tagger_ls.get_loss(train_examples, tag_scores)[1][0]
assert_array_almost_equal((ls_grads - no_ls_grads)[0], [0.05, -0.025, -0.025]) assert_array_almost_equal((ls_grads - no_ls_grads)[0], [0.05, -0.025, -0.025])