From 5f049b2de65c1608f067a94d5b80ee276ab22054 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Thu, 9 Feb 2023 10:20:29 +0100 Subject: [PATCH] Tidy up and fix types --- spacy/cli/_util.py | 2 +- spacy/cli/assemble.py | 11 +++--- spacy/cli/convert.py | 33 +++++++---------- spacy/cli/find_threshold.py | 73 +++++++++++++++---------------------- spacy/cli/init_config.py | 39 +++++++++++--------- spacy/tests/test_cli.py | 18 ++++----- 6 files changed, 81 insertions(+), 95 deletions(-) diff --git a/spacy/cli/_util.py b/spacy/cli/_util.py index 2167558d4..4077d2545 100644 --- a/spacy/cli/_util.py +++ b/spacy/cli/_util.py @@ -94,7 +94,7 @@ def _parse_overrides(args: List[str], is_cli: bool = False) -> Dict[str, Any]: opt = opt.replace("--", "") if "." not in opt: if is_cli: - raise radicli.CliParseError(f"unrecognized argument: {orig_opt}") + raise radicli.CliParserError(f"unrecognized argument: {orig_opt}") else: msg.fail(f"{err}: can't override top-level sections", exits=1) if "=" in opt: # we have --opt=value diff --git a/spacy/cli/assemble.py b/spacy/cli/assemble.py index 679c5cc10..a3e6bd14c 100644 --- a/spacy/cli/assemble.py +++ b/spacy/cli/assemble.py @@ -53,8 +53,9 @@ def assemble_cli( with nlp.select_pipes(disable=[*sourced]): nlp.initialize() msg.good("Initialized pipeline") - msg.divider("Serializing to disk") - if output_path is not None and not output_path.exists(): - output_path.mkdir(parents=True) - msg.good(f"Created output directory: {output_path}") - nlp.to_disk(output_path) + if output_path is not None: + msg.divider("Serializing to disk") + if not output_path.exists(): + output_path.mkdir(parents=True) + msg.good(f"Created output directory: {output_path}") + nlp.to_disk(output_path) diff --git a/spacy/cli/convert.py b/spacy/cli/convert.py index cf70860bb..eb79d11d8 100644 --- a/spacy/cli/convert.py +++ b/spacy/cli/convert.py @@ -1,5 +1,4 @@ -from typing import Callable, Iterable, Mapping, Optional, Any, Union, Literal -from enum import Enum +from typing import Callable, Iterable, Mapping, Optional, Any, Union, Literal, cast from pathlib import Path from wasabi import Printer import srsly @@ -19,6 +18,8 @@ from ..training.converters import conllu_to_docs # matched by file extension and content. To add a converter, add a new # entry to this dict with the file extension mapped to the converter function # imported from /converters. +ConvertersType = Literal["auto", "conllubio", "conllu", "conll", "ner", "iob", "json"] +FileTypesType = Literal["json", "spacy"] CONVERTERS: Mapping[str, Callable[..., Iterable[Doc]]] = { "conllubio": conllu_to_docs, @@ -28,19 +29,11 @@ CONVERTERS: Mapping[str, Callable[..., Iterable[Doc]]] = { "iob": iob_to_docs, "json": json_to_docs, } -AUTO = "auto" -ConvertersType = Literal["auto", "conllubio", "conllu", "conll", "ner", "iob", "json"] - - +AUTO: ConvertersType = "auto" # File types that can be written to stdout FILE_TYPES_STDOUT = ("json",) -class FileTypes(str, Enum): - json = "json" - spacy = "spacy" - - @cli.command( "convert", # fmt: off @@ -61,7 +54,7 @@ class FileTypes(str, Enum): def convert_cli( input_path: ExistingPathOrDash, output_dir: ExistingDirPathOrDash = "-", - file_type: Literal["json", "spacy"] = "spacy", + file_type: FileTypesType = "spacy", n_sents: int = 1, seg_sents: bool = False, model: Optional[str] = None, @@ -110,7 +103,7 @@ def convert( input_path: Path, output_dir: Union[str, Path], *, - file_type: str = "json", + file_type: FileTypesType = "json", n_sents: int = 1, seg_sents: bool = False, model: Optional[str] = None, @@ -173,14 +166,16 @@ def convert( msg.good(f"Generated output file ({len_docs} documents): {output_file}") -def _print_docs_to_stdout(data: Any, output_type: str) -> None: +def _print_docs_to_stdout(data: Any, output_type: FileTypesType) -> None: if output_type == "json": srsly.write_json("-", data) else: sys.stdout.buffer.write(data) -def _write_docs_to_file(data: Any, output_file: Path, output_type: str) -> None: +def _write_docs_to_file( + data: Any, output_file: Path, output_type: FileTypesType +) -> None: if not output_file.parent.exists(): output_file.parent.mkdir(parents=True) if output_type == "json": @@ -190,7 +185,7 @@ def _write_docs_to_file(data: Any, output_file: Path, output_type: str) -> None: file_.write(data) -def autodetect_ner_format(input_data: str) -> Optional[str]: +def autodetect_ner_format(input_data: str) -> Optional[ConvertersType]: # guess format from the first 20 lines lines = input_data.split("\n")[:20] format_guesses = {"ner": 0, "iob": 0} @@ -213,7 +208,7 @@ def verify_cli_args( msg: Printer, input_path: Path, output_dir: Union[str, Path], - file_type: str, + file_type: FileTypesType, converter: str, ner_map: Optional[Path], ): @@ -236,7 +231,7 @@ def verify_cli_args( msg.fail(f"Can't find converter for {converter}", exits=1) -def _get_converter(msg, converter, input_path: Path): +def _get_converter(msg: Printer, converter: ConvertersType, input_path: Path) -> str: if input_path.is_dir(): if converter == AUTO: input_locs = walk_directory(input_path, suffix=None) @@ -265,4 +260,4 @@ def _get_converter(msg, converter, input_path: Path): "Conversion may not succeed. " "See https://spacy.io/api/cli#convert" ) - return converter + return cast(str, converter) diff --git a/spacy/cli/find_threshold.py b/spacy/cli/find_threshold.py index a7614f878..e3bdd4c1a 100644 --- a/spacy/cli/find_threshold.py +++ b/spacy/cli/find_threshold.py @@ -4,7 +4,7 @@ from pathlib import Path import logging from typing import Optional, Tuple, Any, Dict, List import numpy -import wasabi.tables +from wasabi import msg, row from radicli import Arg, ExistingPath, ExistingFilePath from ..pipeline import TextCategorizer, MultiLabel_TextCategorizer @@ -13,11 +13,9 @@ from ..training import Corpus from ._util import cli, import_code, setup_gpu from .. import util -_DEFAULTS = { - "n_trials": 11, - "use_gpu": -1, - "gold_preproc": False, -} +DEFAULT_N_TRIALS: int = 11 +DEFAULT_USE_GPU: int = -1 +DEFAULT_GOLD_PREPROC: bool = False @cli.command( @@ -41,10 +39,10 @@ def find_threshold_cli( pipe_name: str, threshold_key: str, scores_key: str, - n_trials: int = _DEFAULTS["n_trials"], + n_trials: int = DEFAULT_N_TRIALS, code_path: Optional[ExistingFilePath] = None, - use_gpu: int = _DEFAULTS["use_gpu"], - gold_preproc: bool = _DEFAULTS["gold_preproc"], + use_gpu: int = DEFAULT_USE_GPU, + gold_preproc: bool = DEFAULT_GOLD_PREPROC, verbose: bool = False, ): """ @@ -83,9 +81,9 @@ def find_threshold( threshold_key: str, scores_key: str, *, - n_trials: int = _DEFAULTS["n_trials"], # type: ignore - use_gpu: int = _DEFAULTS["use_gpu"], # type: ignore - gold_preproc: bool = _DEFAULTS["gold_preproc"], # type: ignore + n_trials: int = DEFAULT_N_TRIALS, + use_gpu: int = DEFAULT_USE_GPU, + gold_preproc: bool = DEFAULT_GOLD_PREPROC, silent: bool = True, ) -> Tuple[float, float, Dict[float, float]]: """ @@ -104,30 +102,25 @@ def find_threshold( RETURNS (Tuple[float, float, Dict[float, float]]): Best found threshold, the corresponding score, scores for all evaluated thresholds. """ - setup_gpu(use_gpu, silent=silent) data_path = util.ensure_path(data_path) if not data_path.exists(): - wasabi.msg.fail("Evaluation data not found", data_path, exits=1) + msg.fail("Evaluation data not found", data_path, exits=1) nlp = util.load_model(model) - if pipe_name not in nlp.component_names: - raise AttributeError( - Errors.E001.format(name=pipe_name, opts=nlp.component_names) - ) + err = Errors.E001.format(name=pipe_name, opts=nlp.component_names) + raise AttributeError(err) pipe = nlp.get_pipe(pipe_name) if not hasattr(pipe, "scorer"): raise AttributeError(Errors.E1045) if type(pipe) == TextCategorizer: - wasabi.msg.warn( + msg.warn( "The `textcat` component doesn't use a threshold as it's not applicable to the concept of " "exclusive classes. All thresholds will yield the same results." ) if not silent: - wasabi.msg.info( - title=f"Optimizing for {scores_key} for component '{pipe_name}' with {n_trials} " - f"trials." - ) + text = f"Optimizing for {scores_key} for component '{pipe_name}' with {n_trials} trials." + msg.info(text) # Load evaluation corpus. corpus = Corpus(data_path, gold_preproc=gold_preproc) dev_dataset = list(corpus(nlp)) @@ -155,9 +148,9 @@ def find_threshold( RETURNS (Dict[str, Any]): Filtered dictionary. """ if keys[0] not in config: - wasabi.msg.fail( - title=f"Failed to look up `{full_key}` in config: sub-key {[keys[0]]} not found.", - text=f"Make sure you specified {[keys[0]]} correctly. The following sub-keys are available instead: " + msg.fail( + f"Failed to look up `{full_key}` in config: sub-key {[keys[0]]} not found.", + f"Make sure you specified {[keys[0]]} correctly. The following sub-keys are available instead: " f"{list(config.keys())}", exits=1, ) @@ -172,7 +165,7 @@ def find_threshold( config_keys_full = ["components", pipe_name, *config_keys] table_col_widths = (10, 10) thresholds = numpy.linspace(0, 1, n_trials) - print(wasabi.tables.row(["Threshold", f"{scores_key}"], widths=table_col_widths)) + print(row(["Threshold", f"{scores_key}"], widths=table_col_widths)) for threshold in thresholds: # Reload pipeline with overrides specifying the new threshold. nlp = util.load_model( @@ -191,34 +184,28 @@ def find_threshold( "cfg", set_nested_item(getattr(pipe, "cfg"), config_keys, threshold), ) - eval_scores = nlp.evaluate(dev_dataset) if scores_key not in eval_scores: - wasabi.msg.fail( + msg.fail( title=f"Failed to look up score `{scores_key}` in evaluation results.", text=f"Make sure you specified the correct value for `scores_key`. The following scores are " f"available: {list(eval_scores.keys())}", exits=1, ) scores[threshold] = eval_scores[scores_key] - if not isinstance(scores[threshold], (float, int)): - wasabi.msg.fail( - f"Returned score for key '{scores_key}' is not numeric. Threshold optimization only works for numeric " - f"scores.", + msg.fail( + f"Returned score for key '{scores_key}' is not numeric. Threshold " + f"optimization only works for numeric scores.", exits=1, ) - print( - wasabi.row( - [round(threshold, 3), round(scores[threshold], 3)], - widths=table_col_widths, - ) - ) + data = [round(threshold, 3), round(scores[threshold], 3)] + print(row(data, widths=table_col_widths)) best_threshold = max(scores.keys(), key=(lambda key: scores[key])) # If all scores are identical, emit warning. if len(set(scores.values())) == 1: - wasabi.msg.warn( - title="All scores are identical. Verify that all settings are correct.", + msg.warn( + "All scores are identical. Verify that all settings are correct.", text="" if ( not isinstance(pipe, MultiLabel_TextCategorizer) @@ -229,7 +216,7 @@ def find_threshold( else: if not silent: print( - f"\nBest threshold: {round(best_threshold, ndigits=4)} with {scores_key} value of {scores[best_threshold]}." + f"\nBest threshold: {round(best_threshold, ndigits=4)} with " + f"{scores_key} value of {scores[best_threshold]}." ) - return best_threshold, scores[best_threshold], scores diff --git a/spacy/cli/init_config.py b/spacy/cli/init_config.py index b05e4d80a..f08c15792 100644 --- a/spacy/cli/init_config.py +++ b/spacy/cli/init_config.py @@ -1,6 +1,6 @@ -from typing import Optional, List, Tuple, Literal +from typing import Optional, List, Tuple, Literal, cast from pathlib import Path -from wasabi import Printer, diff_strings +from wasabi import Printer, msg, diff_strings from thinc.api import Config import srsly import re @@ -25,16 +25,17 @@ OptimizationsType = Literal["efficiency", "accuracy"] class InitValues: """ - Default values for initialization. Dedicated class to allow synchronized default values for init_config_cli() and - init_config(), i.e. initialization calls via CLI respectively Python. + Default values for initialization. Dedicated class to allow synchronized + default values for init_config_cli() and init_config(), i.e. initialization + calls via CLI respectively Python. """ - lang = "en" - pipeline = SimpleFrozenList(["tagger", "parser", "ner"]) - optimize = "efficiency" - gpu = False - pretraining = False - force_overwrite = False + lang: str = "en" + pipeline: List[str] = SimpleFrozenList(["tagger", "parser", "ner"]) + optimize: OptimizationsType = "efficiency" + gpu: bool = False + pretraining: bool = False + force_overwrite: bool = False @cli.subcommand( @@ -68,12 +69,13 @@ def init_config_cli( DOCS: https://spacy.io/api/cli#init-config """ is_stdout = output_file == "-" - if not is_stdout and output_file.exists() and not force_overwrite: - msg = Printer() - msg.fail( - "The provided output file already exists. To force overwriting the config file, set the --force or -F flag.", - exits=1, - ) + if not is_stdout: + if output_file.exists() and not force_overwrite: + msg.fail( + "The provided output file already exists. To force overwriting " + "the config file, set the --force or -F flag.", + exits=1, + ) config = init_config( lang=lang, pipeline=pipeline, @@ -239,7 +241,10 @@ def init_config( def save_config( - config: Config, output_file: Path, is_stdout: bool = False, silent: bool = False + config: Config, + output_file: Path, + is_stdout: bool = False, + silent: bool = False, ) -> None: no_print = is_stdout or silent msg = Printer(no_print=no_print) diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py index dc7ce46fe..7a45aa6e4 100644 --- a/spacy/tests/test_cli.py +++ b/spacy/tests/test_cli.py @@ -5,15 +5,13 @@ from typing import Tuple, List, Dict, Any import pkg_resources import time from pathlib import Path - -import spacy -import numpy import pytest import srsly from click import NoSuchOption from packaging.specifiers import SpecifierSet from thinc.api import Config, ConfigValidationError +import spacy from spacy import about from spacy.cli import info from spacy.cli._util import is_subpath_of, load_project_config, walk_directory @@ -745,13 +743,13 @@ def test_get_labels_from_model(factory_name, pipe_name): def test_permitted_package_names(): # https://www.python.org/dev/peps/pep-0426/#name - assert _is_permitted_package_name("Meine_Bäume") == False - assert _is_permitted_package_name("_package") == False - assert _is_permitted_package_name("package_") == False - assert _is_permitted_package_name(".package") == False - assert _is_permitted_package_name("package.") == False - assert _is_permitted_package_name("-package") == False - assert _is_permitted_package_name("package-") == False + assert _is_permitted_package_name("Meine_Bäume") is False + assert _is_permitted_package_name("_package") is False + assert _is_permitted_package_name("package_") is False + assert _is_permitted_package_name(".package") is False + assert _is_permitted_package_name("package.") is False + assert _is_permitted_package_name("-package") is False + assert _is_permitted_package_name("package-") is False def test_debug_data_compile_gold():