mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Move scores per type handling into util function (#8590)
This commit is contained in:
parent
5fd0b5207e
commit
327f83573a
|
@ -1,4 +1,4 @@
|
|||
from typing import Optional, List, Dict
|
||||
from typing import Optional, List, Dict, Any, Union
|
||||
from wasabi import Printer
|
||||
from pathlib import Path
|
||||
import re
|
||||
|
@ -60,8 +60,8 @@ def evaluate(
|
|||
displacy_path: Optional[Path] = None,
|
||||
displacy_limit: int = 25,
|
||||
silent: bool = True,
|
||||
spans_key="sc",
|
||||
) -> Scorer:
|
||||
spans_key: str = "sc",
|
||||
) -> Dict[str, Any]:
|
||||
msg = Printer(no_print=silent, pretty=not silent)
|
||||
fix_random_seed()
|
||||
setup_gpu(use_gpu)
|
||||
|
@ -112,7 +112,37 @@ def evaluate(
|
|||
data[re.sub(r"[\s/]", "_", key.lower())] = scores[key]
|
||||
|
||||
msg.table(results, title="Results")
|
||||
data = handle_scores_per_type(scores, data, spans_key=spans_key, silent=silent)
|
||||
|
||||
if displacy_path:
|
||||
factory_names = [nlp.get_pipe_meta(pipe).factory for pipe in nlp.pipe_names]
|
||||
docs = list(nlp.pipe(ex.reference.text for ex in dev_dataset[:displacy_limit]))
|
||||
render_deps = "parser" in factory_names
|
||||
render_ents = "ner" in factory_names
|
||||
render_parses(
|
||||
docs,
|
||||
displacy_path,
|
||||
model_name=model,
|
||||
limit=displacy_limit,
|
||||
deps=render_deps,
|
||||
ents=render_ents,
|
||||
)
|
||||
msg.good(f"Generated {displacy_limit} parses as HTML", displacy_path)
|
||||
|
||||
if output_path is not None:
|
||||
srsly.write_json(output_path, data)
|
||||
msg.good(f"Saved results to {output_path}")
|
||||
return data
|
||||
|
||||
|
||||
def handle_scores_per_type(
|
||||
scores: Union[Scorer, Dict[str, Any]],
|
||||
data: Dict[str, Any] = {},
|
||||
*,
|
||||
spans_key: str = "sc",
|
||||
silent: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
msg = Printer(no_print=silent, pretty=not silent)
|
||||
if "morph_per_feat" in scores:
|
||||
if scores["morph_per_feat"]:
|
||||
print_prf_per_type(msg, scores["morph_per_feat"], "MORPH", "feat")
|
||||
|
@ -139,26 +169,7 @@ def evaluate(
|
|||
if scores["cats_auc_per_type"]:
|
||||
print_textcats_auc_per_cat(msg, scores["cats_auc_per_type"])
|
||||
data["cats_auc_per_type"] = scores["cats_auc_per_type"]
|
||||
|
||||
if displacy_path:
|
||||
factory_names = [nlp.get_pipe_meta(pipe).factory for pipe in nlp.pipe_names]
|
||||
docs = list(nlp.pipe(ex.reference.text for ex in dev_dataset[:displacy_limit]))
|
||||
render_deps = "parser" in factory_names
|
||||
render_ents = "ner" in factory_names
|
||||
render_parses(
|
||||
docs,
|
||||
displacy_path,
|
||||
model_name=model,
|
||||
limit=displacy_limit,
|
||||
deps=render_deps,
|
||||
ents=render_ents,
|
||||
)
|
||||
msg.good(f"Generated {displacy_limit} parses as HTML", displacy_path)
|
||||
|
||||
if output_path is not None:
|
||||
srsly.write_json(output_path, data)
|
||||
msg.good(f"Saved results to {output_path}")
|
||||
return data
|
||||
return scores
|
||||
|
||||
|
||||
def render_parses(
|
||||
|
|
Loading…
Reference in New Issue
Block a user