From 327f83573ac2ba8dc8e4d594c4f66019089610ea Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Tue, 6 Jul 2021 21:02:37 +1000 Subject: [PATCH] Move scores per type handling into util function (#8590) --- spacy/cli/evaluate.py | 57 ++++++++++++++++++++++++++----------------- 1 file changed, 34 insertions(+), 23 deletions(-) diff --git a/spacy/cli/evaluate.py b/spacy/cli/evaluate.py index c563f24d3..35915096e 100644 --- a/spacy/cli/evaluate.py +++ b/spacy/cli/evaluate.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Dict +from typing import Optional, List, Dict, Any, Union from wasabi import Printer from pathlib import Path import re @@ -60,8 +60,8 @@ def evaluate( displacy_path: Optional[Path] = None, displacy_limit: int = 25, silent: bool = True, - spans_key="sc", -) -> Scorer: + spans_key: str = "sc", +) -> Dict[str, Any]: msg = Printer(no_print=silent, pretty=not silent) fix_random_seed() setup_gpu(use_gpu) @@ -112,7 +112,37 @@ def evaluate( data[re.sub(r"[\s/]", "_", key.lower())] = scores[key] msg.table(results, title="Results") + data = handle_scores_per_type(scores, data, spans_key=spans_key, silent=silent) + if displacy_path: + factory_names = [nlp.get_pipe_meta(pipe).factory for pipe in nlp.pipe_names] + docs = list(nlp.pipe(ex.reference.text for ex in dev_dataset[:displacy_limit])) + render_deps = "parser" in factory_names + render_ents = "ner" in factory_names + render_parses( + docs, + displacy_path, + model_name=model, + limit=displacy_limit, + deps=render_deps, + ents=render_ents, + ) + msg.good(f"Generated {displacy_limit} parses as HTML", displacy_path) + + if output_path is not None: + srsly.write_json(output_path, data) + msg.good(f"Saved results to {output_path}") + return data + + +def handle_scores_per_type( + scores: Union[Scorer, Dict[str, Any]], + data: Dict[str, Any] = {}, + *, + spans_key: str = "sc", + silent: bool = False, +) -> Dict[str, Any]: + msg = Printer(no_print=silent, pretty=not silent) if "morph_per_feat" in scores: if scores["morph_per_feat"]: print_prf_per_type(msg, scores["morph_per_feat"], "MORPH", "feat") @@ -139,26 +169,7 @@ def evaluate( if scores["cats_auc_per_type"]: print_textcats_auc_per_cat(msg, scores["cats_auc_per_type"]) data["cats_auc_per_type"] = scores["cats_auc_per_type"] - - if displacy_path: - factory_names = [nlp.get_pipe_meta(pipe).factory for pipe in nlp.pipe_names] - docs = list(nlp.pipe(ex.reference.text for ex in dev_dataset[:displacy_limit])) - render_deps = "parser" in factory_names - render_ents = "ner" in factory_names - render_parses( - docs, - displacy_path, - model_name=model, - limit=displacy_limit, - deps=render_deps, - ents=render_ents, - ) - msg.good(f"Generated {displacy_limit} parses as HTML", displacy_path) - - if output_path is not None: - srsly.write_json(output_path, data) - msg.good(f"Saved results to {output_path}") - return data + return scores def render_parses(