mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-19 04:32:32 +03:00
fix types in scorer + black
This commit is contained in:
parent
015050f42c
commit
b8bdf998ad
|
@ -477,29 +477,57 @@ class Scorer:
|
|||
score_per_type[label] = PRFScore()
|
||||
# Find all instances, for all and per type
|
||||
gold_instances = set()
|
||||
gold_per_type = {label: set() for label in labels}
|
||||
gold_per_type: Dict[str, Set] = {label: set() for label in labels}
|
||||
for gold_cluster in gold_clusters:
|
||||
for span1 in gold_cluster:
|
||||
for span2 in gold_cluster:
|
||||
# only record pairs where span1 comes before span2
|
||||
if (span1.start < span2.start) or (span1.start == span2.start and span1.end < span2.end):
|
||||
if (span1.start < span2.start) or (
|
||||
span1.start == span2.start and span1.end < span2.end
|
||||
):
|
||||
if include_label:
|
||||
gold_rel = (span1.label_, span1.start, span1.end - 1, span2.label_, span2.start, span2.end - 1)
|
||||
gold_rel: Tuple = (
|
||||
span1.label_,
|
||||
span1.start,
|
||||
span1.end - 1,
|
||||
span2.label_,
|
||||
span2.start,
|
||||
span2.end - 1,
|
||||
)
|
||||
else:
|
||||
gold_rel = (span1.start, span1.end - 1, span2.start, span2.end - 1)
|
||||
gold_rel = (
|
||||
span1.start,
|
||||
span1.end - 1,
|
||||
span2.start,
|
||||
span2.end - 1,
|
||||
)
|
||||
gold_instances.add(gold_rel)
|
||||
if span1.label_ == span2.label_:
|
||||
gold_per_type[span1.label_].add(gold_rel)
|
||||
pred_instances = set()
|
||||
pred_per_type = {label: set() for label in labels}
|
||||
pred_per_type: Dict[str, Set] = {label: set() for label in labels}
|
||||
for pred_cluster in pred_clusters:
|
||||
for span1 in pred_cluster:
|
||||
for span2 in pred_cluster:
|
||||
if (span1.start < span2.start) or (span1.start == span2.start and span1.end < span2.end):
|
||||
if (span1.start < span2.start) or (
|
||||
span1.start == span2.start and span1.end < span2.end
|
||||
):
|
||||
if include_label:
|
||||
pred_rel = (span1.label_, span1.start, span1.end - 1, span2.label_, span2.start, span2.end - 1)
|
||||
pred_rel: Tuple = (
|
||||
span1.label_,
|
||||
span1.start,
|
||||
span1.end - 1,
|
||||
span2.label_,
|
||||
span2.start,
|
||||
span2.end - 1,
|
||||
)
|
||||
else:
|
||||
pred_rel = (span1.start, span1.end - 1, span2.start, span2.end - 1)
|
||||
pred_rel = (
|
||||
span1.start,
|
||||
span1.end - 1,
|
||||
span2.start,
|
||||
span2.end - 1,
|
||||
)
|
||||
pred_instances.add(pred_rel)
|
||||
if span1.label_ == span2.label_:
|
||||
pred_per_type[span1.label_].add(pred_rel)
|
||||
|
@ -511,11 +539,10 @@ class Scorer:
|
|||
# Score for all labels
|
||||
score.score_set(pred_instances, gold_instances)
|
||||
# Assemble final result
|
||||
final_scores = {
|
||||
final_scores: Dict[str, Optional[float]] = {
|
||||
f"{attr}_p": None,
|
||||
f"{attr}_r": None,
|
||||
f"{attr}_f": None,
|
||||
|
||||
}
|
||||
if include_label:
|
||||
final_scores[f"{attr}_per_type"] = None
|
||||
|
|
Loading…
Reference in New Issue
Block a user