mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-06 21:30:22 +03:00
don't add more labels to the given set
This commit is contained in:
parent
89bee0ed77
commit
8deabcf7ef
|
@ -480,14 +480,12 @@ class Scorer:
|
||||||
f_per_type = {label: PRFScore() for label in labels}
|
f_per_type = {label: PRFScore() for label in labels}
|
||||||
auc_per_type = {label: ROCAUCScore() for label in labels}
|
auc_per_type = {label: ROCAUCScore() for label in labels}
|
||||||
labels = set(labels)
|
labels = set(labels)
|
||||||
if labels:
|
|
||||||
for eg in examples:
|
|
||||||
labels.update(eg.predicted.cats.keys())
|
|
||||||
labels.update(eg.reference.cats.keys())
|
|
||||||
for example in examples:
|
for example in examples:
|
||||||
# Through this loop, None in the gold_cats indicates missing label.
|
# Through this loop, None in the gold_cats indicates missing label.
|
||||||
pred_cats = getter(example.predicted, attr)
|
pred_cats = getter(example.predicted, attr)
|
||||||
|
pred_cats = {k: v for k, v in pred_cats.items() if k in labels}
|
||||||
gold_cats = getter(example.reference, attr)
|
gold_cats = getter(example.reference, attr)
|
||||||
|
gold_cats = {k: v for k, v in gold_cats.items() if k in labels}
|
||||||
|
|
||||||
for label in labels:
|
for label in labels:
|
||||||
pred_score = pred_cats.get(label, 0.0)
|
pred_score = pred_cats.get(label, 0.0)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user