Remove all uses of threshold for multi_label=False

This commit is contained in:
Adriane Boyd 2022-10-25 15:22:09 +02:00
parent d9a6734526
commit 2c18864bc3

View File

@ -507,20 +507,18 @@ class Scorer:
# Get the highest-scoring for each. # Get the highest-scoring for each.
pred_label, pred_score = max(pred_cats.items(), key=lambda it: it[1]) pred_label, pred_score = max(pred_cats.items(), key=lambda it: it[1])
gold_label, gold_score = max(gold_cats.items(), key=lambda it: it[1]) gold_label, gold_score = max(gold_cats.items(), key=lambda it: it[1])
if pred_label == gold_label and pred_score >= threshold: if pred_label == gold_label:
f_per_type[pred_label].tp += 1 f_per_type[pred_label].tp += 1
else: else:
f_per_type[gold_label].fn += 1 f_per_type[gold_label].fn += 1
if pred_score >= threshold: f_per_type[pred_label].fp += 1
f_per_type[pred_label].fp += 1
elif gold_cats: elif gold_cats:
gold_label, gold_score = max(gold_cats, key=lambda it: it[1]) gold_label, gold_score = max(gold_cats, key=lambda it: it[1])
if gold_score > 0: if gold_score > 0:
f_per_type[gold_label].fn += 1 f_per_type[gold_label].fn += 1
elif pred_cats: elif pred_cats:
pred_label, pred_score = max(pred_cats.items(), key=lambda it: it[1]) pred_label, pred_score = max(pred_cats.items(), key=lambda it: it[1])
if pred_score >= threshold: f_per_type[pred_label].fp += 1
f_per_type[pred_label].fp += 1
micro_prf = PRFScore() micro_prf = PRFScore()
for label_prf in f_per_type.values(): for label_prf in f_per_type.values():
micro_prf.tp += label_prf.tp micro_prf.tp += label_prf.tp