fix types in scorer + black

This commit is contained in:
svlandeg 2022-05-25 13:12:37 +02:00
parent 015050f42c
commit b8bdf998ad

View File

@ -477,29 +477,57 @@ class Scorer:
score_per_type[label] = PRFScore() score_per_type[label] = PRFScore()
# Find all instances, for all and per type # Find all instances, for all and per type
gold_instances = set() gold_instances = set()
gold_per_type = {label: set() for label in labels} gold_per_type: Dict[str, Set] = {label: set() for label in labels}
for gold_cluster in gold_clusters: for gold_cluster in gold_clusters:
for span1 in gold_cluster: for span1 in gold_cluster:
for span2 in gold_cluster: for span2 in gold_cluster:
# only record pairs where span1 comes before span2 # only record pairs where span1 comes before span2
if (span1.start < span2.start) or (span1.start == span2.start and span1.end < span2.end): if (span1.start < span2.start) or (
span1.start == span2.start and span1.end < span2.end
):
if include_label: if include_label:
gold_rel = (span1.label_, span1.start, span1.end - 1, span2.label_, span2.start, span2.end - 1) gold_rel: Tuple = (
span1.label_,
span1.start,
span1.end - 1,
span2.label_,
span2.start,
span2.end - 1,
)
else: else:
gold_rel = (span1.start, span1.end - 1, span2.start, span2.end - 1) gold_rel = (
span1.start,
span1.end - 1,
span2.start,
span2.end - 1,
)
gold_instances.add(gold_rel) gold_instances.add(gold_rel)
if span1.label_ == span2.label_: if span1.label_ == span2.label_:
gold_per_type[span1.label_].add(gold_rel) gold_per_type[span1.label_].add(gold_rel)
pred_instances = set() pred_instances = set()
pred_per_type = {label: set() for label in labels} pred_per_type: Dict[str, Set] = {label: set() for label in labels}
for pred_cluster in pred_clusters: for pred_cluster in pred_clusters:
for span1 in pred_cluster: for span1 in pred_cluster:
for span2 in pred_cluster: for span2 in pred_cluster:
if (span1.start < span2.start) or (span1.start == span2.start and span1.end < span2.end): if (span1.start < span2.start) or (
span1.start == span2.start and span1.end < span2.end
):
if include_label: if include_label:
pred_rel = (span1.label_, span1.start, span1.end - 1, span2.label_, span2.start, span2.end - 1) pred_rel: Tuple = (
span1.label_,
span1.start,
span1.end - 1,
span2.label_,
span2.start,
span2.end - 1,
)
else: else:
pred_rel = (span1.start, span1.end - 1, span2.start, span2.end - 1) pred_rel = (
span1.start,
span1.end - 1,
span2.start,
span2.end - 1,
)
pred_instances.add(pred_rel) pred_instances.add(pred_rel)
if span1.label_ == span2.label_: if span1.label_ == span2.label_:
pred_per_type[span1.label_].add(pred_rel) pred_per_type[span1.label_].add(pred_rel)
@ -511,11 +539,10 @@ class Scorer:
# Score for all labels # Score for all labels
score.score_set(pred_instances, gold_instances) score.score_set(pred_instances, gold_instances)
# Assemble final result # Assemble final result
final_scores = { final_scores: Dict[str, Optional[float]] = {
f"{attr}_p": None, f"{attr}_p": None,
f"{attr}_r": None, f"{attr}_r": None,
f"{attr}_f": None, f"{attr}_f": None,
} }
if include_label: if include_label:
final_scores[f"{attr}_per_type"] = None final_scores[f"{attr}_per_type"] = None