mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
Output more stats in evaluate
This commit is contained in:
parent
90b7fa8fed
commit
dbfa292ed3
|
@ -1,4 +1,4 @@
|
|||
from typing import Optional, List
|
||||
from typing import Optional, List, Dict
|
||||
from timeit import default_timer as timer
|
||||
from wasabi import Printer
|
||||
from pathlib import Path
|
||||
|
@ -89,8 +89,20 @@ def evaluate(
|
|||
"Sent R": f"{scorer.sent_r:.2f}",
|
||||
"Sent F": f"{scorer.sent_f:.2f}",
|
||||
}
|
||||
data = {re.sub(r"[\s/]", "_", k.lower()): v for k, v in results.items()}
|
||||
|
||||
msg.table(results, title="Results")
|
||||
|
||||
if scorer.ents_per_type:
|
||||
data["ents_per_type"] = scorer.ents_per_type
|
||||
print_ents_per_type(msg, scorer.ents_per_type)
|
||||
if scorer.textcats_f_per_cat:
|
||||
data["textcats_f_per_cat"] = scorer.textcats_f_per_cat
|
||||
print_textcats_f_per_cat(msg, scorer.textcats_f_per_cat)
|
||||
if scorer.textcats_auc_per_cat:
|
||||
data["textcats_auc_per_cat"] = scorer.textcats_auc_per_cat
|
||||
print_textcats_auc_per_cat(msg, scorer.textcats_auc_per_cat)
|
||||
|
||||
if displacy_path:
|
||||
docs = [ex.predicted for ex in dev_dataset]
|
||||
render_deps = "parser" in nlp.meta.get("pipeline", [])
|
||||
|
@ -105,7 +117,6 @@ def evaluate(
|
|||
)
|
||||
msg.good(f"Generated {displacy_limit} parses as HTML", displacy_path)
|
||||
|
||||
data = {re.sub(r"[\s/]", "_", k.lower()): v for k, v in results.items()}
|
||||
if output_path is not None:
|
||||
srsly.write_json(output_path, data)
|
||||
msg.good(f"Saved results to {output_path}")
|
||||
|
@ -131,3 +142,40 @@ def render_parses(
|
|||
)
|
||||
with (output_path / "parses.html").open("w", encoding="utf8") as file_:
|
||||
file_.write(html)
|
||||
|
||||
|
||||
def print_ents_per_type(msg: Printer, scores: Dict[str, Dict[str, float]]) -> None:
|
||||
data = [
|
||||
(k, f"{v['p']:.2f}", f"{v['r']:.2f}", f"{v['f']:.2f}")
|
||||
for k, v in scores.items()
|
||||
]
|
||||
msg.table(
|
||||
data,
|
||||
header=("", "P", "R", "F"),
|
||||
aligns=("l", "r", "r", "r"),
|
||||
title="NER (per type)",
|
||||
)
|
||||
|
||||
|
||||
def print_textcats_f_per_cat(msg: Printer, scores: Dict[str, Dict[str, float]]) -> None:
|
||||
data = [
|
||||
(k, f"{v['p']:.2f}", f"{v['r']:.2f}", f"{v['f']:.2f}")
|
||||
for k, v in scores.items()
|
||||
]
|
||||
msg.table(
|
||||
data,
|
||||
header=("", "P", "R", "F"),
|
||||
aligns=("l", "r", "r", "r"),
|
||||
title="Textcat F (per type)",
|
||||
)
|
||||
|
||||
|
||||
def print_textcats_auc_per_cat(
|
||||
msg: Printer, scores: Dict[str, Dict[str, float]]
|
||||
) -> None:
|
||||
msg.table(
|
||||
[(k, f"{v['roc_auc_score']:.2f}") for k, v in scores.items()],
|
||||
header=("", "ROC AUC"),
|
||||
aligns=("l", "r"),
|
||||
title="Textcat ROC AUC (per label)",
|
||||
)
|
||||
|
|
Loading…
Reference in New Issue
Block a user