mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-12 17:22:25 +03:00
Move scores per type handling into util function (#8590)
This commit is contained in:
parent
5fd0b5207e
commit
327f83573a
|
@@ -1,4 +1,4 @@
|
||||||
from typing import Optional, List, Dict
|
from typing import Optional, List, Dict, Any, Union
|
||||||
from wasabi import Printer
|
from wasabi import Printer
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import re
|
import re
|
||||||
|
@@ -60,8 +60,8 @@ def evaluate(
|
||||||
displacy_path: Optional[Path] = None,
|
displacy_path: Optional[Path] = None,
|
||||||
displacy_limit: int = 25,
|
displacy_limit: int = 25,
|
||||||
silent: bool = True,
|
silent: bool = True,
|
||||||
spans_key="sc",
|
spans_key: str = "sc",
|
||||||
) -> Scorer:
|
) -> Dict[str, Any]:
|
||||||
msg = Printer(no_print=silent, pretty=not silent)
|
msg = Printer(no_print=silent, pretty=not silent)
|
||||||
fix_random_seed()
|
fix_random_seed()
|
||||||
setup_gpu(use_gpu)
|
setup_gpu(use_gpu)
|
||||||
|
@@ -112,7 +112,37 @@ def evaluate(
|
||||||
data[re.sub(r"[\s/]", "_", key.lower())] = scores[key]
|
data[re.sub(r"[\s/]", "_", key.lower())] = scores[key]
|
||||||
|
|
||||||
msg.table(results, title="Results")
|
msg.table(results, title="Results")
|
||||||
|
data = handle_scores_per_type(scores, data, spans_key=spans_key, silent=silent)
|
||||||
|
|
||||||
|
if displacy_path:
|
||||||
|
factory_names = [nlp.get_pipe_meta(pipe).factory for pipe in nlp.pipe_names]
|
||||||
|
docs = list(nlp.pipe(ex.reference.text for ex in dev_dataset[:displacy_limit]))
|
||||||
|
render_deps = "parser" in factory_names
|
||||||
|
render_ents = "ner" in factory_names
|
||||||
|
render_parses(
|
||||||
|
docs,
|
||||||
|
displacy_path,
|
||||||
|
model_name=model,
|
||||||
|
limit=displacy_limit,
|
||||||
|
deps=render_deps,
|
||||||
|
ents=render_ents,
|
||||||
|
)
|
||||||
|
msg.good(f"Generated {displacy_limit} parses as HTML", displacy_path)
|
||||||
|
|
||||||
|
if output_path is not None:
|
||||||
|
srsly.write_json(output_path, data)
|
||||||
|
msg.good(f"Saved results to {output_path}")
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def handle_scores_per_type(
|
||||||
|
scores: Union[Scorer, Dict[str, Any]],
|
||||||
|
data: Dict[str, Any] = {},
|
||||||
|
*,
|
||||||
|
spans_key: str = "sc",
|
||||||
|
silent: bool = False,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
msg = Printer(no_print=silent, pretty=not silent)
|
||||||
if "morph_per_feat" in scores:
|
if "morph_per_feat" in scores:
|
||||||
if scores["morph_per_feat"]:
|
if scores["morph_per_feat"]:
|
||||||
print_prf_per_type(msg, scores["morph_per_feat"], "MORPH", "feat")
|
print_prf_per_type(msg, scores["morph_per_feat"], "MORPH", "feat")
|
||||||
|
@@ -139,26 +169,7 @@ def evaluate(
|
||||||
if scores["cats_auc_per_type"]:
|
if scores["cats_auc_per_type"]:
|
||||||
print_textcats_auc_per_cat(msg, scores["cats_auc_per_type"])
|
print_textcats_auc_per_cat(msg, scores["cats_auc_per_type"])
|
||||||
data["cats_auc_per_type"] = scores["cats_auc_per_type"]
|
data["cats_auc_per_type"] = scores["cats_auc_per_type"]
|
||||||
|
return scores
|
||||||
if displacy_path:
|
|
||||||
factory_names = [nlp.get_pipe_meta(pipe).factory for pipe in nlp.pipe_names]
|
|
||||||
docs = list(nlp.pipe(ex.reference.text for ex in dev_dataset[:displacy_limit]))
|
|
||||||
render_deps = "parser" in factory_names
|
|
||||||
render_ents = "ner" in factory_names
|
|
||||||
render_parses(
|
|
||||||
docs,
|
|
||||||
displacy_path,
|
|
||||||
model_name=model,
|
|
||||||
limit=displacy_limit,
|
|
||||||
deps=render_deps,
|
|
||||||
ents=render_ents,
|
|
||||||
)
|
|
||||||
msg.good(f"Generated {displacy_limit} parses as HTML", displacy_path)
|
|
||||||
|
|
||||||
if output_path is not None:
|
|
||||||
srsly.write_json(output_path, data)
|
|
||||||
msg.good(f"Saved results to {output_path}")
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def render_parses(
|
def render_parses(
|
||||||
|
|
Loading…
Reference in New Issue
Block a user