Move scores per type handling into util function (#8590)

This commit is contained in:
Ines Montani 2021-07-06 21:02:37 +10:00 committed by GitHub
parent 5fd0b5207e
commit 327f83573a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,4 +1,4 @@
from typing import Optional, List, Dict from typing import Optional, List, Dict, Any, Union
from wasabi import Printer from wasabi import Printer
from pathlib import Path from pathlib import Path
import re import re
@ -60,8 +60,8 @@ def evaluate(
displacy_path: Optional[Path] = None, displacy_path: Optional[Path] = None,
displacy_limit: int = 25, displacy_limit: int = 25,
silent: bool = True, silent: bool = True,
spans_key="sc", spans_key: str = "sc",
) -> Scorer: ) -> Dict[str, Any]:
msg = Printer(no_print=silent, pretty=not silent) msg = Printer(no_print=silent, pretty=not silent)
fix_random_seed() fix_random_seed()
setup_gpu(use_gpu) setup_gpu(use_gpu)
@ -112,7 +112,37 @@ def evaluate(
data[re.sub(r"[\s/]", "_", key.lower())] = scores[key] data[re.sub(r"[\s/]", "_", key.lower())] = scores[key]
msg.table(results, title="Results") msg.table(results, title="Results")
data = handle_scores_per_type(scores, data, spans_key=spans_key, silent=silent)
if displacy_path:
factory_names = [nlp.get_pipe_meta(pipe).factory for pipe in nlp.pipe_names]
docs = list(nlp.pipe(ex.reference.text for ex in dev_dataset[:displacy_limit]))
render_deps = "parser" in factory_names
render_ents = "ner" in factory_names
render_parses(
docs,
displacy_path,
model_name=model,
limit=displacy_limit,
deps=render_deps,
ents=render_ents,
)
msg.good(f"Generated {displacy_limit} parses as HTML", displacy_path)
if output_path is not None:
srsly.write_json(output_path, data)
msg.good(f"Saved results to {output_path}")
return data
def handle_scores_per_type(
scores: Union[Scorer, Dict[str, Any]],
data: Dict[str, Any] = {},
*,
spans_key: str = "sc",
silent: bool = False,
) -> Dict[str, Any]:
msg = Printer(no_print=silent, pretty=not silent)
if "morph_per_feat" in scores: if "morph_per_feat" in scores:
if scores["morph_per_feat"]: if scores["morph_per_feat"]:
print_prf_per_type(msg, scores["morph_per_feat"], "MORPH", "feat") print_prf_per_type(msg, scores["morph_per_feat"], "MORPH", "feat")
@ -139,26 +169,7 @@ def evaluate(
if scores["cats_auc_per_type"]: if scores["cats_auc_per_type"]:
print_textcats_auc_per_cat(msg, scores["cats_auc_per_type"]) print_textcats_auc_per_cat(msg, scores["cats_auc_per_type"])
data["cats_auc_per_type"] = scores["cats_auc_per_type"] data["cats_auc_per_type"] = scores["cats_auc_per_type"]
return scores
if displacy_path:
factory_names = [nlp.get_pipe_meta(pipe).factory for pipe in nlp.pipe_names]
docs = list(nlp.pipe(ex.reference.text for ex in dev_dataset[:displacy_limit]))
render_deps = "parser" in factory_names
render_ents = "ner" in factory_names
render_parses(
docs,
displacy_path,
model_name=model,
limit=displacy_limit,
deps=render_deps,
ents=render_ents,
)
msg.good(f"Generated {displacy_limit} parses as HTML", displacy_path)
if output_path is not None:
srsly.write_json(output_path, data)
msg.good(f"Saved results to {output_path}")
return data
def render_parses( def render_parses(