mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-28 06:31:12 +03:00 
			
		
		
		
	Output more stats in evaluate
This commit is contained in:
		
							parent
							
								
									90b7fa8fed
								
							
						
					
					
						commit
						dbfa292ed3
					
				|  | @ -1,4 +1,4 @@ | |||
| from typing import Optional, List | ||||
| from typing import Optional, List, Dict | ||||
| from timeit import default_timer as timer | ||||
| from wasabi import Printer | ||||
| from pathlib import Path | ||||
|  | @ -89,8 +89,20 @@ def evaluate( | |||
|         "Sent R": f"{scorer.sent_r:.2f}", | ||||
|         "Sent F": f"{scorer.sent_f:.2f}", | ||||
|     } | ||||
|     data = {re.sub(r"[\s/]", "_", k.lower()): v for k, v in results.items()} | ||||
| 
 | ||||
|     msg.table(results, title="Results") | ||||
| 
 | ||||
|     if scorer.ents_per_type: | ||||
|         data["ents_per_type"] = scorer.ents_per_type | ||||
|         print_ents_per_type(msg, scorer.ents_per_type) | ||||
|     if scorer.textcats_f_per_cat: | ||||
|         data["textcats_f_per_cat"] = scorer.textcats_f_per_cat | ||||
|         print_textcats_f_per_cat(msg, scorer.textcats_f_per_cat) | ||||
|     if scorer.textcats_auc_per_cat: | ||||
|         data["textcats_auc_per_cat"] = scorer.textcats_auc_per_cat | ||||
|         print_textcats_auc_per_cat(msg, scorer.textcats_auc_per_cat) | ||||
| 
 | ||||
|     if displacy_path: | ||||
|         docs = [ex.predicted for ex in dev_dataset] | ||||
|         render_deps = "parser" in nlp.meta.get("pipeline", []) | ||||
|  | @ -105,7 +117,6 @@ def evaluate( | |||
|         ) | ||||
|         msg.good(f"Generated {displacy_limit} parses as HTML", displacy_path) | ||||
| 
 | ||||
|     data = {re.sub(r"[\s/]", "_", k.lower()): v for k, v in results.items()} | ||||
|     if output_path is not None: | ||||
|         srsly.write_json(output_path, data) | ||||
|         msg.good(f"Saved results to {output_path}") | ||||
|  | @ -131,3 +142,40 @@ def render_parses( | |||
|         ) | ||||
|         with (output_path / "parses.html").open("w", encoding="utf8") as file_: | ||||
|             file_.write(html) | ||||
| 
 | ||||
| 
 | ||||
| def print_ents_per_type(msg: Printer, scores: Dict[str, Dict[str, float]]) -> None: | ||||
|     data = [ | ||||
|         (k, f"{v['p']:.2f}", f"{v['r']:.2f}", f"{v['f']:.2f}") | ||||
|         for k, v in scores.items() | ||||
|     ] | ||||
|     msg.table( | ||||
|         data, | ||||
|         header=("", "P", "R", "F"), | ||||
|         aligns=("l", "r", "r", "r"), | ||||
|         title="NER (per type)", | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| def print_textcats_f_per_cat(msg: Printer, scores: Dict[str, Dict[str, float]]) -> None: | ||||
|     data = [ | ||||
|         (k, f"{v['p']:.2f}", f"{v['r']:.2f}", f"{v['f']:.2f}") | ||||
|         for k, v in scores.items() | ||||
|     ] | ||||
|     msg.table( | ||||
|         data, | ||||
|         header=("", "P", "R", "F"), | ||||
|         aligns=("l", "r", "r", "r"), | ||||
|         title="Textcat F (per type)", | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| def print_textcats_auc_per_cat( | ||||
|     msg: Printer, scores: Dict[str, Dict[str, float]] | ||||
| ) -> None: | ||||
|     msg.table( | ||||
|         [(k, f"{v['roc_auc_score']:.2f}") for k, v in scores.items()], | ||||
|         header=("", "ROC AUC"), | ||||
|         aligns=("l", "r"), | ||||
|         title="Textcat ROC AUC (per label)", | ||||
|     ) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user