Output more stats in evaluate

This commit is contained in:
Ines Montani 2020-06-28 15:34:28 +02:00
parent 90b7fa8fed
commit dbfa292ed3

View File

@ -1,4 +1,4 @@
from typing import Optional, List
from typing import Optional, List, Dict
from timeit import default_timer as timer
from wasabi import Printer
from pathlib import Path
@ -89,8 +89,20 @@ def evaluate(
"Sent R": f"{scorer.sent_r:.2f}",
"Sent F": f"{scorer.sent_f:.2f}",
}
data = {re.sub(r"[\s/]", "_", k.lower()): v for k, v in results.items()}
msg.table(results, title="Results")
if scorer.ents_per_type:
data["ents_per_type"] = scorer.ents_per_type
print_ents_per_type(msg, scorer.ents_per_type)
if scorer.textcats_f_per_cat:
data["textcats_f_per_cat"] = scorer.textcats_f_per_cat
print_textcats_f_per_cat(msg, scorer.textcats_f_per_cat)
if scorer.textcats_auc_per_cat:
data["textcats_auc_per_cat"] = scorer.textcats_auc_per_cat
print_textcats_auc_per_cat(msg, scorer.textcats_auc_per_cat)
if displacy_path:
docs = [ex.predicted for ex in dev_dataset]
render_deps = "parser" in nlp.meta.get("pipeline", [])
@ -105,7 +117,6 @@ def evaluate(
)
msg.good(f"Generated {displacy_limit} parses as HTML", displacy_path)
data = {re.sub(r"[\s/]", "_", k.lower()): v for k, v in results.items()}
if output_path is not None:
srsly.write_json(output_path, data)
msg.good(f"Saved results to {output_path}")
@ -131,3 +142,40 @@ def render_parses(
)
with (output_path / "parses.html").open("w", encoding="utf8") as file_:
file_.write(html)
def print_ents_per_type(msg: Printer, scores: Dict[str, Dict[str, float]]) -> None:
data = [
(k, f"{v['p']:.2f}", f"{v['r']:.2f}", f"{v['f']:.2f}")
for k, v in scores.items()
]
msg.table(
data,
header=("", "P", "R", "F"),
aligns=("l", "r", "r", "r"),
title="NER (per type)",
)
def print_textcats_f_per_cat(msg: Printer, scores: Dict[str, Dict[str, float]]) -> None:
data = [
(k, f"{v['p']:.2f}", f"{v['r']:.2f}", f"{v['f']:.2f}")
for k, v in scores.items()
]
msg.table(
data,
header=("", "P", "R", "F"),
aligns=("l", "r", "r", "r"),
title="Textcat F (per type)",
)
def print_textcats_auc_per_cat(
msg: Printer, scores: Dict[str, Dict[str, float]]
) -> None:
msg.table(
[(k, f"{v['roc_auc_score']:.2f}") for k, v in scores.items()],
header=("", "ROC AUC"),
aligns=("l", "r"),
title="Textcat ROC AUC (per label)",
)