fix types in scorer + black

This commit is contained in:
svlandeg 2022-05-25 13:12:37 +02:00
parent 015050f42c
commit b8bdf998ad

View File

@ -477,29 +477,57 @@ class Scorer:
score_per_type[label] = PRFScore()
# Find all instances, for all and per type
gold_instances = set()
gold_per_type = {label: set() for label in labels}
gold_per_type: Dict[str, Set] = {label: set() for label in labels}
for gold_cluster in gold_clusters:
for span1 in gold_cluster:
for span2 in gold_cluster:
# only record pairs where span1 comes before span2
if (span1.start < span2.start) or (span1.start == span2.start and span1.end < span2.end):
if (span1.start < span2.start) or (
span1.start == span2.start and span1.end < span2.end
):
if include_label:
gold_rel = (span1.label_, span1.start, span1.end - 1, span2.label_, span2.start, span2.end - 1)
gold_rel: Tuple = (
span1.label_,
span1.start,
span1.end - 1,
span2.label_,
span2.start,
span2.end - 1,
)
else:
gold_rel = (span1.start, span1.end - 1, span2.start, span2.end - 1)
gold_rel = (
span1.start,
span1.end - 1,
span2.start,
span2.end - 1,
)
gold_instances.add(gold_rel)
if span1.label_ == span2.label_:
gold_per_type[span1.label_].add(gold_rel)
pred_instances = set()
pred_per_type = {label: set() for label in labels}
pred_per_type: Dict[str, Set] = {label: set() for label in labels}
for pred_cluster in pred_clusters:
for span1 in pred_cluster:
for span2 in pred_cluster:
if (span1.start < span2.start) or (span1.start == span2.start and span1.end < span2.end):
if (span1.start < span2.start) or (
span1.start == span2.start and span1.end < span2.end
):
if include_label:
pred_rel = (span1.label_, span1.start, span1.end - 1, span2.label_, span2.start, span2.end - 1)
pred_rel: Tuple = (
span1.label_,
span1.start,
span1.end - 1,
span2.label_,
span2.start,
span2.end - 1,
)
else:
pred_rel = (span1.start, span1.end - 1, span2.start, span2.end - 1)
pred_rel = (
span1.start,
span1.end - 1,
span2.start,
span2.end - 1,
)
pred_instances.add(pred_rel)
if span1.label_ == span2.label_:
pred_per_type[span1.label_].add(pred_rel)
@ -511,11 +539,10 @@ class Scorer:
# Score for all labels
score.score_set(pred_instances, gold_instances)
# Assemble final result
final_scores = {
final_scores: Dict[str, Optional[float]] = {
f"{attr}_p": None,
f"{attr}_r": None,
f"{attr}_f": None,
}
if include_label:
final_scores[f"{attr}_per_type"] = None