mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-19 04:32:32 +03:00
fix types in scorer + black
This commit is contained in:
parent
015050f42c
commit
b8bdf998ad
|
@ -477,29 +477,57 @@ class Scorer:
|
||||||
score_per_type[label] = PRFScore()
|
score_per_type[label] = PRFScore()
|
||||||
# Find all instances, for all and per type
|
# Find all instances, for all and per type
|
||||||
gold_instances = set()
|
gold_instances = set()
|
||||||
gold_per_type = {label: set() for label in labels}
|
gold_per_type: Dict[str, Set] = {label: set() for label in labels}
|
||||||
for gold_cluster in gold_clusters:
|
for gold_cluster in gold_clusters:
|
||||||
for span1 in gold_cluster:
|
for span1 in gold_cluster:
|
||||||
for span2 in gold_cluster:
|
for span2 in gold_cluster:
|
||||||
# only record pairs where span1 comes before span2
|
# only record pairs where span1 comes before span2
|
||||||
if (span1.start < span2.start) or (span1.start == span2.start and span1.end < span2.end):
|
if (span1.start < span2.start) or (
|
||||||
|
span1.start == span2.start and span1.end < span2.end
|
||||||
|
):
|
||||||
if include_label:
|
if include_label:
|
||||||
gold_rel = (span1.label_, span1.start, span1.end - 1, span2.label_, span2.start, span2.end - 1)
|
gold_rel: Tuple = (
|
||||||
|
span1.label_,
|
||||||
|
span1.start,
|
||||||
|
span1.end - 1,
|
||||||
|
span2.label_,
|
||||||
|
span2.start,
|
||||||
|
span2.end - 1,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
gold_rel = (span1.start, span1.end - 1, span2.start, span2.end - 1)
|
gold_rel = (
|
||||||
|
span1.start,
|
||||||
|
span1.end - 1,
|
||||||
|
span2.start,
|
||||||
|
span2.end - 1,
|
||||||
|
)
|
||||||
gold_instances.add(gold_rel)
|
gold_instances.add(gold_rel)
|
||||||
if span1.label_ == span2.label_:
|
if span1.label_ == span2.label_:
|
||||||
gold_per_type[span1.label_].add(gold_rel)
|
gold_per_type[span1.label_].add(gold_rel)
|
||||||
pred_instances = set()
|
pred_instances = set()
|
||||||
pred_per_type = {label: set() for label in labels}
|
pred_per_type: Dict[str, Set] = {label: set() for label in labels}
|
||||||
for pred_cluster in pred_clusters:
|
for pred_cluster in pred_clusters:
|
||||||
for span1 in pred_cluster:
|
for span1 in pred_cluster:
|
||||||
for span2 in pred_cluster:
|
for span2 in pred_cluster:
|
||||||
if (span1.start < span2.start) or (span1.start == span2.start and span1.end < span2.end):
|
if (span1.start < span2.start) or (
|
||||||
|
span1.start == span2.start and span1.end < span2.end
|
||||||
|
):
|
||||||
if include_label:
|
if include_label:
|
||||||
pred_rel = (span1.label_, span1.start, span1.end - 1, span2.label_, span2.start, span2.end - 1)
|
pred_rel: Tuple = (
|
||||||
|
span1.label_,
|
||||||
|
span1.start,
|
||||||
|
span1.end - 1,
|
||||||
|
span2.label_,
|
||||||
|
span2.start,
|
||||||
|
span2.end - 1,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
pred_rel = (span1.start, span1.end - 1, span2.start, span2.end - 1)
|
pred_rel = (
|
||||||
|
span1.start,
|
||||||
|
span1.end - 1,
|
||||||
|
span2.start,
|
||||||
|
span2.end - 1,
|
||||||
|
)
|
||||||
pred_instances.add(pred_rel)
|
pred_instances.add(pred_rel)
|
||||||
if span1.label_ == span2.label_:
|
if span1.label_ == span2.label_:
|
||||||
pred_per_type[span1.label_].add(pred_rel)
|
pred_per_type[span1.label_].add(pred_rel)
|
||||||
|
@ -511,11 +539,10 @@ class Scorer:
|
||||||
# Score for all labels
|
# Score for all labels
|
||||||
score.score_set(pred_instances, gold_instances)
|
score.score_set(pred_instances, gold_instances)
|
||||||
# Assemble final result
|
# Assemble final result
|
||||||
final_scores = {
|
final_scores: Dict[str, Optional[float]] = {
|
||||||
f"{attr}_p": None,
|
f"{attr}_p": None,
|
||||||
f"{attr}_r": None,
|
f"{attr}_r": None,
|
||||||
f"{attr}_f": None,
|
f"{attr}_f": None,
|
||||||
|
|
||||||
}
|
}
|
||||||
if include_label:
|
if include_label:
|
||||||
final_scores[f"{attr}_per_type"] = None
|
final_scores[f"{attr}_per_type"] = None
|
||||||
|
|
Loading…
Reference in New Issue
Block a user