mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-15 10:42:34 +03:00
black
This commit is contained in:
parent
bc2e082192
commit
1507e357dc
|
@ -73,9 +73,7 @@ def test_label_smoothing():
|
|||
tagger_no_ls = nlp.add_pipe(
|
||||
"tagger", "no_label_smoothing", config=dict(label_smoothing=0.0)
|
||||
)
|
||||
tagger_ls = nlp.add_pipe(
|
||||
"tagger", "label_smoothing"
|
||||
)
|
||||
tagger_ls = nlp.add_pipe("tagger", "label_smoothing")
|
||||
train_examples = []
|
||||
losses = {}
|
||||
for tag in TAGS:
|
||||
|
@ -88,8 +86,8 @@ def test_label_smoothing():
|
|||
tag_scores, bp_tag_scores = tagger_ls.model.begin_update(
|
||||
[eg.predicted for eg in train_examples]
|
||||
)
|
||||
no_ls_grads= tagger_no_ls.get_loss(train_examples, tag_scores)[1][0]
|
||||
ls_grads= tagger_ls.get_loss(train_examples, tag_scores)[1][0]
|
||||
no_ls_grads = tagger_no_ls.get_loss(train_examples, tag_scores)[1][0]
|
||||
ls_grads = tagger_ls.get_loss(train_examples, tag_scores)[1][0]
|
||||
assert_array_almost_equal((ls_grads - no_ls_grads)[0], [0.05, -0.025, -0.025])
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user