Tidy up and fix types

Ines Montani 2023-02-09 10:20:29 +01:00
parent dbfc9f688f
commit 5f049b2de6
6 changed files with 81 additions and 95 deletions

View File

@@ -94,7 +94,7 @@ def _parse_overrides(args: List[str], is_cli: bool = False) -> Dict[str, Any]:
             opt = opt.replace("--", "")
             if "." not in opt:
                 if is_cli:
-                    raise radicli.CliParseError(f"unrecognized argument: {orig_opt}")
+                    raise radicli.CliParserError(f"unrecognized argument: {orig_opt}")
                 else:
                     msg.fail(f"{err}: can't override top-level sections", exits=1)
             if "=" in opt:  # we have --opt=value

View File

@@ -53,8 +53,9 @@ def assemble_cli(
     with nlp.select_pipes(disable=[*sourced]):
         nlp.initialize()
     msg.good("Initialized pipeline")
-    msg.divider("Serializing to disk")
-    if output_path is not None and not output_path.exists():
-        output_path.mkdir(parents=True)
-        msg.good(f"Created output directory: {output_path}")
-    nlp.to_disk(output_path)
+    if output_path is not None:
+        msg.divider("Serializing to disk")
+        if not output_path.exists():
+            output_path.mkdir(parents=True)
+            msg.good(f"Created output directory: {output_path}")
+        nlp.to_disk(output_path)
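
The restructuring above puts the divider, the mkdir, and nlp.to_disk all under a single output_path is not None guard, where previously nlp.to_disk(output_path) was reachable with a missing path. The guard pattern in isolation (a sketch, not the spaCy source):

from pathlib import Path
from typing import Optional


def save_if_requested(data: str, output_path: Optional[Path]) -> None:
    # Only touch the filesystem when an output path was actually given.
    if output_path is not None:
        if not output_path.exists():
            output_path.mkdir(parents=True)
        (output_path / "data.txt").write_text(data)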

View File

@@ -1,5 +1,4 @@
-from typing import Callable, Iterable, Mapping, Optional, Any, Union, Literal
-from enum import Enum
+from typing import Callable, Iterable, Mapping, Optional, Any, Union, Literal, cast
 from pathlib import Path
 from wasabi import Printer
 import srsly
@@ -19,6 +18,8 @@ from ..training.converters import conllu_to_docs
 # matched by file extension and content. To add a converter, add a new
 # entry to this dict with the file extension mapped to the converter function
 # imported from /converters.
+ConvertersType = Literal["auto", "conllubio", "conllu", "conll", "ner", "iob", "json"]
+FileTypesType = Literal["json", "spacy"]
 CONVERTERS: Mapping[str, Callable[..., Iterable[Doc]]] = {
     "conllubio": conllu_to_docs,
@@ -28,19 +29,11 @@ CONVERTERS: Mapping[str, Callable[..., Iterable[Doc]]] = {
     "iob": iob_to_docs,
     "json": json_to_docs,
 }
 
-AUTO = "auto"
-ConvertersType = Literal["auto", "conllubio", "conllu", "conll", "ner", "iob", "json"]
+AUTO: ConvertersType = "auto"
 
 # File types that can be written to stdout
 FILE_TYPES_STDOUT = ("json",)
 
 
-class FileTypes(str, Enum):
-    json = "json"
-    spacy = "spacy"
-
-
 @cli.command(
     "convert",
     # fmt: off
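
Swapping the FileTypes(str, Enum) class for a Literal alias keeps the CLI values as plain strings while still letting a type checker reject anything outside the allowed set, and AUTO gets the narrowed ConvertersType annotation for the same reason. A sketch of the difference:

from typing import Literal

FileTypesType = Literal["json", "spacy"]


def describe(file_type: FileTypesType) -> str:
    # Plain strings: no .value unwrapping as with an Enum member.
    return f"writing {file_type} output"


describe("spacy")   # OK
# describe("yaml")  # rejected by a type checker such as mypy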
@@ -61,7 +54,7 @@ class FileTypes(str, Enum):
 def convert_cli(
     input_path: ExistingPathOrDash,
     output_dir: ExistingDirPathOrDash = "-",
-    file_type: Literal["json", "spacy"] = "spacy",
+    file_type: FileTypesType = "spacy",
     n_sents: int = 1,
     seg_sents: bool = False,
     model: Optional[str] = None,
@@ -110,7 +103,7 @@ def convert(
     input_path: Path,
     output_dir: Union[str, Path],
     *,
-    file_type: str = "json",
+    file_type: FileTypesType = "json",
     n_sents: int = 1,
     seg_sents: bool = False,
     model: Optional[str] = None,
@@ -173,14 +166,16 @@ def convert(
         msg.good(f"Generated output file ({len_docs} documents): {output_file}")
 
 
-def _print_docs_to_stdout(data: Any, output_type: str) -> None:
+def _print_docs_to_stdout(data: Any, output_type: FileTypesType) -> None:
     if output_type == "json":
         srsly.write_json("-", data)
     else:
         sys.stdout.buffer.write(data)
 
 
-def _write_docs_to_file(data: Any, output_file: Path, output_type: str) -> None:
+def _write_docs_to_file(
+    data: Any, output_file: Path, output_type: FileTypesType
+) -> None:
     if not output_file.parent.exists():
         output_file.parent.mkdir(parents=True)
     if output_type == "json":
@@ -190,7 +185,7 @@ def _write_docs_to_file(data: Any, output_file: Path, output_type: str) -> None:
         file_.write(data)
 
 
-def autodetect_ner_format(input_data: str) -> Optional[str]:
+def autodetect_ner_format(input_data: str) -> Optional[ConvertersType]:
     # guess format from the first 20 lines
     lines = input_data.split("\n")[:20]
     format_guesses = {"ner": 0, "iob": 0}
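
autodetect_ner_format samples the first 20 lines and tallies votes per candidate format; typing the return as Optional[ConvertersType] makes the no-confident-guess case (None) explicit. The voting shape, reduced to a sketch with made-up detection rules:

from typing import Optional


def guess_format_sketch(input_data: str) -> Optional[str]:
    lines = input_data.split("\n")[:20]
    votes = {"ner": 0, "iob": 0}
    for line in lines:
        # Hypothetical heuristics, not the real spaCy checks:
        if "|" in line:
            votes["iob"] += 1
        elif "\t" in line:
            votes["ner"] += 1
    best = max(votes, key=lambda k: votes[k])
    return best if votes[best] > 0 else None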
@@ -213,7 +208,7 @@ def verify_cli_args(
     msg: Printer,
     input_path: Path,
     output_dir: Union[str, Path],
-    file_type: str,
+    file_type: FileTypesType,
     converter: str,
     ner_map: Optional[Path],
 ):
@@ -236,7 +231,7 @@
         msg.fail(f"Can't find converter for {converter}", exits=1)
 
 
-def _get_converter(msg, converter, input_path: Path):
+def _get_converter(msg: Printer, converter: ConvertersType, input_path: Path) -> str:
     if input_path.is_dir():
         if converter == AUTO:
             input_locs = walk_directory(input_path, suffix=None)
@@ -265,4 +260,4 @@ def _get_converter(msg, converter, input_path: Path):
             "Conversion may not succeed. "
             "See https://spacy.io/api/cli#convert"
         )
-    return converter
+    return cast(str, converter)
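
cast has no effect at runtime; it only records for the type checker that by this point converter is treated as a plain string rather than the wider ConvertersType, which still includes "auto". In miniature:

from typing import Literal, cast

ConvertersType = Literal["auto", "conllu", "ner", "iob", "json"]


def resolve(converter: ConvertersType) -> str:
    if converter == "auto":
        converter = "conllu"  # sketch: pretend detection picked CoNLL-U
    # No-op at runtime; only informs the type checker.
    return cast(str, converter)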

View File

@@ -4,7 +4,7 @@ from pathlib import Path
 import logging
 from typing import Optional, Tuple, Any, Dict, List
 import numpy
-import wasabi.tables
+from wasabi import msg, row
 from radicli import Arg, ExistingPath, ExistingFilePath
 
 from ..pipeline import TextCategorizer, MultiLabel_TextCategorizer
@@ -13,11 +13,9 @@ from ..training import Corpus
 from ._util import cli, import_code, setup_gpu
 from .. import util
 
-_DEFAULTS = {
-    "n_trials": 11,
-    "use_gpu": -1,
-    "gold_preproc": False,
-}
+DEFAULT_N_TRIALS: int = 11
+DEFAULT_USE_GPU: int = -1
+DEFAULT_GOLD_PREPROC: bool = False
 
 
 @cli.command(
@@ -41,10 +39,10 @@ def find_threshold_cli(
     pipe_name: str,
     threshold_key: str,
     scores_key: str,
-    n_trials: int = _DEFAULTS["n_trials"],
+    n_trials: int = DEFAULT_N_TRIALS,
     code_path: Optional[ExistingFilePath] = None,
-    use_gpu: int = _DEFAULTS["use_gpu"],
-    gold_preproc: bool = _DEFAULTS["gold_preproc"],
+    use_gpu: int = DEFAULT_USE_GPU,
+    gold_preproc: bool = DEFAULT_GOLD_PREPROC,
     verbose: bool = False,
 ):
     """
@@ -83,9 +81,9 @@ def find_threshold(
     threshold_key: str,
     scores_key: str,
     *,
-    n_trials: int = _DEFAULTS["n_trials"],  # type: ignore
-    use_gpu: int = _DEFAULTS["use_gpu"],  # type: ignore
-    gold_preproc: bool = _DEFAULTS["gold_preproc"],  # type: ignore
+    n_trials: int = DEFAULT_N_TRIALS,
+    use_gpu: int = DEFAULT_USE_GPU,
+    gold_preproc: bool = DEFAULT_GOLD_PREPROC,
     silent: bool = True,
 ) -> Tuple[float, float, Dict[float, float]]:
     """
@@ -104,30 +102,25 @@ def find_threshold(
     RETURNS (Tuple[float, float, Dict[float, float]]): Best found threshold, the corresponding score, scores for all
         evaluated thresholds.
     """
     setup_gpu(use_gpu, silent=silent)
     data_path = util.ensure_path(data_path)
     if not data_path.exists():
-        wasabi.msg.fail("Evaluation data not found", data_path, exits=1)
+        msg.fail("Evaluation data not found", data_path, exits=1)
     nlp = util.load_model(model)
     if pipe_name not in nlp.component_names:
-        raise AttributeError(
-            Errors.E001.format(name=pipe_name, opts=nlp.component_names)
-        )
+        err = Errors.E001.format(name=pipe_name, opts=nlp.component_names)
+        raise AttributeError(err)
     pipe = nlp.get_pipe(pipe_name)
     if not hasattr(pipe, "scorer"):
         raise AttributeError(Errors.E1045)
     if type(pipe) == TextCategorizer:
-        wasabi.msg.warn(
+        msg.warn(
             "The `textcat` component doesn't use a threshold as it's not applicable to the concept of "
             "exclusive classes. All thresholds will yield the same results."
         )
     if not silent:
-        wasabi.msg.info(
-            title=f"Optimizing for {scores_key} for component '{pipe_name}' with {n_trials} "
-            f"trials."
-        )
+        text = f"Optimizing for {scores_key} for component '{pipe_name}' with {n_trials} trials."
+        msg.info(text)
 
     # Load evaluation corpus.
     corpus = Corpus(data_path, gold_preproc=gold_preproc)
     dev_dataset = list(corpus(nlp))
@@ -155,9 +148,9 @@ def find_threshold(
         RETURNS (Dict[str, Any]): Filtered dictionary.
         """
         if keys[0] not in config:
-            wasabi.msg.fail(
-                title=f"Failed to look up `{full_key}` in config: sub-key {[keys[0]]} not found.",
-                text=f"Make sure you specified {[keys[0]]} correctly. The following sub-keys are available instead: "
+            msg.fail(
+                f"Failed to look up `{full_key}` in config: sub-key {[keys[0]]} not found.",
+                f"Make sure you specified {[keys[0]]} correctly. The following sub-keys are available instead: "
                 f"{list(config.keys())}",
                 exits=1,
             )
@@ -172,7 +165,7 @@ def find_threshold(
     config_keys_full = ["components", pipe_name, *config_keys]
     table_col_widths = (10, 10)
     thresholds = numpy.linspace(0, 1, n_trials)
-    print(wasabi.tables.row(["Threshold", f"{scores_key}"], widths=table_col_widths))
+    print(row(["Threshold", f"{scores_key}"], widths=table_col_widths))
     for threshold in thresholds:
         # Reload pipeline with overrides specifying the new threshold.
         nlp = util.load_model(
@@ -191,34 +184,28 @@ def find_threshold(
             "cfg",
             set_nested_item(getattr(pipe, "cfg"), config_keys, threshold),
         )
         eval_scores = nlp.evaluate(dev_dataset)
         if scores_key not in eval_scores:
-            wasabi.msg.fail(
+            msg.fail(
                 title=f"Failed to look up score `{scores_key}` in evaluation results.",
                 text=f"Make sure you specified the correct value for `scores_key`. The following scores are "
                 f"available: {list(eval_scores.keys())}",
                 exits=1,
             )
         scores[threshold] = eval_scores[scores_key]
         if not isinstance(scores[threshold], (float, int)):
-            wasabi.msg.fail(
-                f"Returned score for key '{scores_key}' is not numeric. Threshold optimization only works for numeric "
-                f"scores.",
+            msg.fail(
+                f"Returned score for key '{scores_key}' is not numeric. Threshold "
+                f"optimization only works for numeric scores.",
                 exits=1,
             )
-        print(
-            wasabi.row(
-                [round(threshold, 3), round(scores[threshold], 3)],
-                widths=table_col_widths,
-            )
-        )
+        data = [round(threshold, 3), round(scores[threshold], 3)]
+        print(row(data, widths=table_col_widths))
     best_threshold = max(scores.keys(), key=(lambda key: scores[key]))
     # If all scores are identical, emit warning.
     if len(set(scores.values())) == 1:
-        wasabi.msg.warn(
-            title="All scores are identical. Verify that all settings are correct.",
+        msg.warn(
+            "All scores are identical. Verify that all settings are correct.",
             text=""
             if (
                 not isinstance(pipe, MultiLabel_TextCategorizer)
@@ -229,7 +216,7 @@ def find_threshold(
     else:
         if not silent:
             print(
-                f"\nBest threshold: {round(best_threshold, ndigits=4)} with {scores_key} value of {scores[best_threshold]}."
+                f"\nBest threshold: {round(best_threshold, ndigits=4)} with "
+                f"{scores_key} value of {scores[best_threshold]}."
             )
     return best_threshold, scores[best_threshold], scores
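
Stripped of the pipeline machinery, find_threshold sweeps n_trials evenly spaced thresholds over [0, 1], records one score per threshold, and returns the argmax, which is exactly the max(scores.keys(), key=...) call above. A self-contained sketch with a made-up scoring function:

from typing import Dict

import numpy


def sweep(n_trials: int = 11) -> float:
    scores: Dict[float, float] = {}
    for threshold in numpy.linspace(0, 1, n_trials):
        # Hypothetical stand-in for nlp.evaluate(); peaks at 0.5.
        scores[threshold] = 1.0 - abs(threshold - 0.5)
    return max(scores.keys(), key=lambda key: scores[key])


print(sweep())  # 0.5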

View File

@@ -1,6 +1,6 @@
-from typing import Optional, List, Tuple, Literal
+from typing import Optional, List, Tuple, Literal, cast
 from pathlib import Path
-from wasabi import Printer, diff_strings
+from wasabi import Printer, msg, diff_strings
 from thinc.api import Config
 import srsly
 import re
@@ -25,16 +25,17 @@ OptimizationsType = Literal["efficiency", "accuracy"]
 class InitValues:
     """
-    Default values for initialization. Dedicated class to allow synchronized default values for init_config_cli() and
-    init_config(), i.e. initialization calls via CLI respectively Python.
+    Default values for initialization. Dedicated class to allow synchronized
+    default values for init_config_cli() and init_config(), i.e. initialization
+    calls via CLI respectively Python.
     """
 
-    lang = "en"
-    pipeline = SimpleFrozenList(["tagger", "parser", "ner"])
-    optimize = "efficiency"
-    gpu = False
-    pretraining = False
-    force_overwrite = False
+    lang: str = "en"
+    pipeline: List[str] = SimpleFrozenList(["tagger", "parser", "ner"])
+    optimize: OptimizationsType = "efficiency"
+    gpu: bool = False
+    pretraining: bool = False
+    force_overwrite: bool = False
 
 
 @cli.subcommand(
@@ -68,10 +69,11 @@ def init_config_cli(
     DOCS: https://spacy.io/api/cli#init-config
     """
     is_stdout = output_file == "-"
-    if not is_stdout and output_file.exists() and not force_overwrite:
-        msg = Printer()
-        msg.fail(
-            "The provided output file already exists. To force overwriting the config file, set the --force or -F flag.",
-            exits=1,
-        )
+    if not is_stdout:
+        if output_file.exists() and not force_overwrite:
+            msg.fail(
+                "The provided output file already exists. To force overwriting "
+                "the config file, set the --force or -F flag.",
+                exits=1,
+            )
     config = init_config(
@@ -239,7 +241,10 @@ def init_config(
 
 
 def save_config(
-    config: Config, output_file: Path, is_stdout: bool = False, silent: bool = False
+    config: Config,
+    output_file: Path,
+    is_stdout: bool = False,
+    silent: bool = False,
 ) -> None:
     no_print = is_stdout or silent
     msg = Printer(no_print=no_print)
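
The InitValues class gives init_config_cli() and init_config() one shared, typed source of defaults, so the CLI and the Python API cannot drift apart; SimpleFrozenList additionally guards the shared pipeline default against accidental mutation. The pattern schematically (hypothetical names, with a plain tuple standing in for SimpleFrozenList):

from typing import Tuple


class Defaults:
    lang: str = "en"
    pipeline: Tuple[str, ...] = ("tagger", "parser", "ner")


def from_cli(lang: str = Defaults.lang) -> str:
    # The CLI wrapper just forwards to the Python function.
    return from_python(lang=lang)


def from_python(*, lang: str = Defaults.lang) -> str:
    return lang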

View File

@@ -5,15 +5,13 @@ from typing import Tuple, List, Dict, Any
 import pkg_resources
 import time
 from pathlib import Path
-import spacy
 import numpy
 import pytest
 import srsly
-from click import NoSuchOption
 from packaging.specifiers import SpecifierSet
 from thinc.api import Config, ConfigValidationError
 
+import spacy
 from spacy import about
 from spacy.cli import info
 from spacy.cli._util import is_subpath_of, load_project_config, walk_directory
@@ -745,13 +743,13 @@ def test_get_labels_from_model(factory_name, pipe_name):
 def test_permitted_package_names():
     # https://www.python.org/dev/peps/pep-0426/#name
-    assert _is_permitted_package_name("Meine_Bäume") == False
-    assert _is_permitted_package_name("_package") == False
-    assert _is_permitted_package_name("package_") == False
-    assert _is_permitted_package_name(".package") == False
-    assert _is_permitted_package_name("package.") == False
-    assert _is_permitted_package_name("-package") == False
-    assert _is_permitted_package_name("package-") == False
+    assert _is_permitted_package_name("Meine_Bäume") is False
+    assert _is_permitted_package_name("_package") is False
+    assert _is_permitted_package_name("package_") is False
+    assert _is_permitted_package_name(".package") is False
+    assert _is_permitted_package_name("package.") is False
+    assert _is_permitted_package_name("-package") is False
+    assert _is_permitted_package_name("package-") is False
 
 
 def test_debug_data_compile_gold():
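
The assert changes above swap == False for is False: _is_permitted_package_name returns a bool, and identity with the False singleton is the stricter check, since == is also truthy for values like 0. For example:

def is_exactly_false(value: object) -> bool:
    return value is False


assert is_exactly_false(False) is True
assert is_exactly_false(0) is False   # 0 == False, but it is not the False object
assert is_exactly_false(None) is False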