From b8bdf998ade11ba14c7bebcfe764322c6e654622 Mon Sep 17 00:00:00 2001 From: svlandeg Date: Wed, 25 May 2022 13:12:37 +0200 Subject: [PATCH] fix types in scorer + black --- spacy/scorer.py | 47 +++++++++++++++++++++++++++++++++++++---------- 1 file changed, 37 insertions(+), 10 deletions(-) diff --git a/spacy/scorer.py b/spacy/scorer.py index 4856bfc0d..8ee6294ad 100644 --- a/spacy/scorer.py +++ b/spacy/scorer.py @@ -477,29 +477,57 @@ class Scorer: score_per_type[label] = PRFScore() # Find all instances, for all and per type gold_instances = set() - gold_per_type = {label: set() for label in labels} + gold_per_type: Dict[str, Set] = {label: set() for label in labels} for gold_cluster in gold_clusters: for span1 in gold_cluster: for span2 in gold_cluster: # only record pairs where span1 comes before span2 - if (span1.start < span2.start) or (span1.start == span2.start and span1.end < span2.end): + if (span1.start < span2.start) or ( + span1.start == span2.start and span1.end < span2.end + ): if include_label: - gold_rel = (span1.label_, span1.start, span1.end - 1, span2.label_, span2.start, span2.end - 1) + gold_rel: Tuple = ( + span1.label_, + span1.start, + span1.end - 1, + span2.label_, + span2.start, + span2.end - 1, + ) else: - gold_rel = (span1.start, span1.end - 1, span2.start, span2.end - 1) + gold_rel = ( + span1.start, + span1.end - 1, + span2.start, + span2.end - 1, + ) gold_instances.add(gold_rel) if span1.label_ == span2.label_: gold_per_type[span1.label_].add(gold_rel) pred_instances = set() - pred_per_type = {label: set() for label in labels} + pred_per_type: Dict[str, Set] = {label: set() for label in labels} for pred_cluster in pred_clusters: for span1 in pred_cluster: for span2 in pred_cluster: - if (span1.start < span2.start) or (span1.start == span2.start and span1.end < span2.end): + if (span1.start < span2.start) or ( + span1.start == span2.start and span1.end < span2.end + ): if include_label: - pred_rel = (span1.label_, span1.start, span1.end - 1, span2.label_, span2.start, span2.end - 1) + pred_rel: Tuple = ( + span1.label_, + span1.start, + span1.end - 1, + span2.label_, + span2.start, + span2.end - 1, + ) else: - pred_rel = (span1.start, span1.end - 1, span2.start, span2.end - 1) + pred_rel = ( + span1.start, + span1.end - 1, + span2.start, + span2.end - 1, + ) pred_instances.add(pred_rel) if span1.label_ == span2.label_: pred_per_type[span1.label_].add(pred_rel) @@ -511,11 +539,10 @@ class Scorer: # Score for all labels score.score_set(pred_instances, gold_instances) # Assemble final result - final_scores = { + final_scores: Dict[str, Optional[float]] = { f"{attr}_p": None, f"{attr}_r": None, f"{attr}_f": None, - } if include_label: final_scores[f"{attr}_per_type"] = None