Tidy up and fix types

This commit is contained in:
Ines Montani 2023-02-09 10:20:29 +01:00
parent dbfc9f688f
commit 5f049b2de6
6 changed files with 81 additions and 95 deletions

View File

@ -94,7 +94,7 @@ def _parse_overrides(args: List[str], is_cli: bool = False) -> Dict[str, Any]:
opt = opt.replace("--", "") opt = opt.replace("--", "")
if "." not in opt: if "." not in opt:
if is_cli: if is_cli:
raise radicli.CliParseError(f"unrecognized argument: {orig_opt}") raise radicli.CliParserError(f"unrecognized argument: {orig_opt}")
else: else:
msg.fail(f"{err}: can't override top-level sections", exits=1) msg.fail(f"{err}: can't override top-level sections", exits=1)
if "=" in opt: # we have --opt=value if "=" in opt: # we have --opt=value

View File

@ -53,8 +53,9 @@ def assemble_cli(
with nlp.select_pipes(disable=[*sourced]): with nlp.select_pipes(disable=[*sourced]):
nlp.initialize() nlp.initialize()
msg.good("Initialized pipeline") msg.good("Initialized pipeline")
if output_path is not None:
msg.divider("Serializing to disk") msg.divider("Serializing to disk")
if output_path is not None and not output_path.exists(): if not output_path.exists():
output_path.mkdir(parents=True) output_path.mkdir(parents=True)
msg.good(f"Created output directory: {output_path}") msg.good(f"Created output directory: {output_path}")
nlp.to_disk(output_path) nlp.to_disk(output_path)

View File

@ -1,5 +1,4 @@
from typing import Callable, Iterable, Mapping, Optional, Any, Union, Literal from typing import Callable, Iterable, Mapping, Optional, Any, Union, Literal, cast
from enum import Enum
from pathlib import Path from pathlib import Path
from wasabi import Printer from wasabi import Printer
import srsly import srsly
@ -19,6 +18,8 @@ from ..training.converters import conllu_to_docs
# matched by file extension and content. To add a converter, add a new # matched by file extension and content. To add a converter, add a new
# entry to this dict with the file extension mapped to the converter function # entry to this dict with the file extension mapped to the converter function
# imported from /converters. # imported from /converters.
ConvertersType = Literal["auto", "conllubio", "conllu", "conll", "ner", "iob", "json"]
FileTypesType = Literal["json", "spacy"]
CONVERTERS: Mapping[str, Callable[..., Iterable[Doc]]] = { CONVERTERS: Mapping[str, Callable[..., Iterable[Doc]]] = {
"conllubio": conllu_to_docs, "conllubio": conllu_to_docs,
@ -28,19 +29,11 @@ CONVERTERS: Mapping[str, Callable[..., Iterable[Doc]]] = {
"iob": iob_to_docs, "iob": iob_to_docs,
"json": json_to_docs, "json": json_to_docs,
} }
AUTO = "auto" AUTO: ConvertersType = "auto"
ConvertersType = Literal["auto", "conllubio", "conllu", "conll", "ner", "iob", "json"]
# File types that can be written to stdout # File types that can be written to stdout
FILE_TYPES_STDOUT = ("json",) FILE_TYPES_STDOUT = ("json",)
class FileTypes(str, Enum):
json = "json"
spacy = "spacy"
@cli.command( @cli.command(
"convert", "convert",
# fmt: off # fmt: off
@ -61,7 +54,7 @@ class FileTypes(str, Enum):
def convert_cli( def convert_cli(
input_path: ExistingPathOrDash, input_path: ExistingPathOrDash,
output_dir: ExistingDirPathOrDash = "-", output_dir: ExistingDirPathOrDash = "-",
file_type: Literal["json", "spacy"] = "spacy", file_type: FileTypesType = "spacy",
n_sents: int = 1, n_sents: int = 1,
seg_sents: bool = False, seg_sents: bool = False,
model: Optional[str] = None, model: Optional[str] = None,
@ -110,7 +103,7 @@ def convert(
input_path: Path, input_path: Path,
output_dir: Union[str, Path], output_dir: Union[str, Path],
*, *,
file_type: str = "json", file_type: FileTypesType = "json",
n_sents: int = 1, n_sents: int = 1,
seg_sents: bool = False, seg_sents: bool = False,
model: Optional[str] = None, model: Optional[str] = None,
@ -173,14 +166,16 @@ def convert(
msg.good(f"Generated output file ({len_docs} documents): {output_file}") msg.good(f"Generated output file ({len_docs} documents): {output_file}")
def _print_docs_to_stdout(data: Any, output_type: str) -> None: def _print_docs_to_stdout(data: Any, output_type: FileTypesType) -> None:
if output_type == "json": if output_type == "json":
srsly.write_json("-", data) srsly.write_json("-", data)
else: else:
sys.stdout.buffer.write(data) sys.stdout.buffer.write(data)
def _write_docs_to_file(data: Any, output_file: Path, output_type: str) -> None: def _write_docs_to_file(
data: Any, output_file: Path, output_type: FileTypesType
) -> None:
if not output_file.parent.exists(): if not output_file.parent.exists():
output_file.parent.mkdir(parents=True) output_file.parent.mkdir(parents=True)
if output_type == "json": if output_type == "json":
@ -190,7 +185,7 @@ def _write_docs_to_file(data: Any, output_file: Path, output_type: str) -> None:
file_.write(data) file_.write(data)
def autodetect_ner_format(input_data: str) -> Optional[str]: def autodetect_ner_format(input_data: str) -> Optional[ConvertersType]:
# guess format from the first 20 lines # guess format from the first 20 lines
lines = input_data.split("\n")[:20] lines = input_data.split("\n")[:20]
format_guesses = {"ner": 0, "iob": 0} format_guesses = {"ner": 0, "iob": 0}
@ -213,7 +208,7 @@ def verify_cli_args(
msg: Printer, msg: Printer,
input_path: Path, input_path: Path,
output_dir: Union[str, Path], output_dir: Union[str, Path],
file_type: str, file_type: FileTypesType,
converter: str, converter: str,
ner_map: Optional[Path], ner_map: Optional[Path],
): ):
@ -236,7 +231,7 @@ def verify_cli_args(
msg.fail(f"Can't find converter for {converter}", exits=1) msg.fail(f"Can't find converter for {converter}", exits=1)
def _get_converter(msg, converter, input_path: Path): def _get_converter(msg: Printer, converter: ConvertersType, input_path: Path) -> str:
if input_path.is_dir(): if input_path.is_dir():
if converter == AUTO: if converter == AUTO:
input_locs = walk_directory(input_path, suffix=None) input_locs = walk_directory(input_path, suffix=None)
@ -265,4 +260,4 @@ def _get_converter(msg, converter, input_path: Path):
"Conversion may not succeed. " "Conversion may not succeed. "
"See https://spacy.io/api/cli#convert" "See https://spacy.io/api/cli#convert"
) )
return converter return cast(str, converter)

View File

@ -4,7 +4,7 @@ from pathlib import Path
import logging import logging
from typing import Optional, Tuple, Any, Dict, List from typing import Optional, Tuple, Any, Dict, List
import numpy import numpy
import wasabi.tables from wasabi import msg, row
from radicli import Arg, ExistingPath, ExistingFilePath from radicli import Arg, ExistingPath, ExistingFilePath
from ..pipeline import TextCategorizer, MultiLabel_TextCategorizer from ..pipeline import TextCategorizer, MultiLabel_TextCategorizer
@ -13,11 +13,9 @@ from ..training import Corpus
from ._util import cli, import_code, setup_gpu from ._util import cli, import_code, setup_gpu
from .. import util from .. import util
_DEFAULTS = { DEFAULT_N_TRIALS: int = 11
"n_trials": 11, DEFAULT_USE_GPU: int = -1
"use_gpu": -1, DEFAULT_GOLD_PREPROC: bool = False
"gold_preproc": False,
}
@cli.command( @cli.command(
@ -41,10 +39,10 @@ def find_threshold_cli(
pipe_name: str, pipe_name: str,
threshold_key: str, threshold_key: str,
scores_key: str, scores_key: str,
n_trials: int = _DEFAULTS["n_trials"], n_trials: int = DEFAULT_N_TRIALS,
code_path: Optional[ExistingFilePath] = None, code_path: Optional[ExistingFilePath] = None,
use_gpu: int = _DEFAULTS["use_gpu"], use_gpu: int = DEFAULT_USE_GPU,
gold_preproc: bool = _DEFAULTS["gold_preproc"], gold_preproc: bool = DEFAULT_GOLD_PREPROC,
verbose: bool = False, verbose: bool = False,
): ):
""" """
@ -83,9 +81,9 @@ def find_threshold(
threshold_key: str, threshold_key: str,
scores_key: str, scores_key: str,
*, *,
n_trials: int = _DEFAULTS["n_trials"], # type: ignore n_trials: int = DEFAULT_N_TRIALS,
use_gpu: int = _DEFAULTS["use_gpu"], # type: ignore use_gpu: int = DEFAULT_USE_GPU,
gold_preproc: bool = _DEFAULTS["gold_preproc"], # type: ignore gold_preproc: bool = DEFAULT_GOLD_PREPROC,
silent: bool = True, silent: bool = True,
) -> Tuple[float, float, Dict[float, float]]: ) -> Tuple[float, float, Dict[float, float]]:
""" """
@ -104,30 +102,25 @@ def find_threshold(
RETURNS (Tuple[float, float, Dict[float, float]]): Best found threshold, the corresponding score, scores for all RETURNS (Tuple[float, float, Dict[float, float]]): Best found threshold, the corresponding score, scores for all
evaluated thresholds. evaluated thresholds.
""" """
setup_gpu(use_gpu, silent=silent) setup_gpu(use_gpu, silent=silent)
data_path = util.ensure_path(data_path) data_path = util.ensure_path(data_path)
if not data_path.exists(): if not data_path.exists():
wasabi.msg.fail("Evaluation data not found", data_path, exits=1) msg.fail("Evaluation data not found", data_path, exits=1)
nlp = util.load_model(model) nlp = util.load_model(model)
if pipe_name not in nlp.component_names: if pipe_name not in nlp.component_names:
raise AttributeError( err = Errors.E001.format(name=pipe_name, opts=nlp.component_names)
Errors.E001.format(name=pipe_name, opts=nlp.component_names) raise AttributeError(err)
)
pipe = nlp.get_pipe(pipe_name) pipe = nlp.get_pipe(pipe_name)
if not hasattr(pipe, "scorer"): if not hasattr(pipe, "scorer"):
raise AttributeError(Errors.E1045) raise AttributeError(Errors.E1045)
if type(pipe) == TextCategorizer: if type(pipe) == TextCategorizer:
wasabi.msg.warn( msg.warn(
"The `textcat` component doesn't use a threshold as it's not applicable to the concept of " "The `textcat` component doesn't use a threshold as it's not applicable to the concept of "
"exclusive classes. All thresholds will yield the same results." "exclusive classes. All thresholds will yield the same results."
) )
if not silent: if not silent:
wasabi.msg.info( text = f"Optimizing for {scores_key} for component '{pipe_name}' with {n_trials} trials."
title=f"Optimizing for {scores_key} for component '{pipe_name}' with {n_trials} " msg.info(text)
f"trials."
)
# Load evaluation corpus. # Load evaluation corpus.
corpus = Corpus(data_path, gold_preproc=gold_preproc) corpus = Corpus(data_path, gold_preproc=gold_preproc)
dev_dataset = list(corpus(nlp)) dev_dataset = list(corpus(nlp))
@ -155,9 +148,9 @@ def find_threshold(
RETURNS (Dict[str, Any]): Filtered dictionary. RETURNS (Dict[str, Any]): Filtered dictionary.
""" """
if keys[0] not in config: if keys[0] not in config:
wasabi.msg.fail( msg.fail(
title=f"Failed to look up `{full_key}` in config: sub-key {[keys[0]]} not found.", f"Failed to look up `{full_key}` in config: sub-key {[keys[0]]} not found.",
text=f"Make sure you specified {[keys[0]]} correctly. The following sub-keys are available instead: " f"Make sure you specified {[keys[0]]} correctly. The following sub-keys are available instead: "
f"{list(config.keys())}", f"{list(config.keys())}",
exits=1, exits=1,
) )
@ -172,7 +165,7 @@ def find_threshold(
config_keys_full = ["components", pipe_name, *config_keys] config_keys_full = ["components", pipe_name, *config_keys]
table_col_widths = (10, 10) table_col_widths = (10, 10)
thresholds = numpy.linspace(0, 1, n_trials) thresholds = numpy.linspace(0, 1, n_trials)
print(wasabi.tables.row(["Threshold", f"{scores_key}"], widths=table_col_widths)) print(row(["Threshold", f"{scores_key}"], widths=table_col_widths))
for threshold in thresholds: for threshold in thresholds:
# Reload pipeline with overrides specifying the new threshold. # Reload pipeline with overrides specifying the new threshold.
nlp = util.load_model( nlp = util.load_model(
@ -191,34 +184,28 @@ def find_threshold(
"cfg", "cfg",
set_nested_item(getattr(pipe, "cfg"), config_keys, threshold), set_nested_item(getattr(pipe, "cfg"), config_keys, threshold),
) )
eval_scores = nlp.evaluate(dev_dataset) eval_scores = nlp.evaluate(dev_dataset)
if scores_key not in eval_scores: if scores_key not in eval_scores:
wasabi.msg.fail( msg.fail(
title=f"Failed to look up score `{scores_key}` in evaluation results.", title=f"Failed to look up score `{scores_key}` in evaluation results.",
text=f"Make sure you specified the correct value for `scores_key`. The following scores are " text=f"Make sure you specified the correct value for `scores_key`. The following scores are "
f"available: {list(eval_scores.keys())}", f"available: {list(eval_scores.keys())}",
exits=1, exits=1,
) )
scores[threshold] = eval_scores[scores_key] scores[threshold] = eval_scores[scores_key]
if not isinstance(scores[threshold], (float, int)): if not isinstance(scores[threshold], (float, int)):
wasabi.msg.fail( msg.fail(
f"Returned score for key '{scores_key}' is not numeric. Threshold optimization only works for numeric " f"Returned score for key '{scores_key}' is not numeric. Threshold "
f"scores.", f"optimization only works for numeric scores.",
exits=1, exits=1,
) )
print( data = [round(threshold, 3), round(scores[threshold], 3)]
wasabi.row( print(row(data, widths=table_col_widths))
[round(threshold, 3), round(scores[threshold], 3)],
widths=table_col_widths,
)
)
best_threshold = max(scores.keys(), key=(lambda key: scores[key])) best_threshold = max(scores.keys(), key=(lambda key: scores[key]))
# If all scores are identical, emit warning. # If all scores are identical, emit warning.
if len(set(scores.values())) == 1: if len(set(scores.values())) == 1:
wasabi.msg.warn( msg.warn(
title="All scores are identical. Verify that all settings are correct.", "All scores are identical. Verify that all settings are correct.",
text="" text=""
if ( if (
not isinstance(pipe, MultiLabel_TextCategorizer) not isinstance(pipe, MultiLabel_TextCategorizer)
@ -229,7 +216,7 @@ def find_threshold(
else: else:
if not silent: if not silent:
print( print(
f"\nBest threshold: {round(best_threshold, ndigits=4)} with {scores_key} value of {scores[best_threshold]}." f"\nBest threshold: {round(best_threshold, ndigits=4)} with "
f"{scores_key} value of {scores[best_threshold]}."
) )
return best_threshold, scores[best_threshold], scores return best_threshold, scores[best_threshold], scores

View File

@ -1,6 +1,6 @@
from typing import Optional, List, Tuple, Literal from typing import Optional, List, Tuple, Literal, cast
from pathlib import Path from pathlib import Path
from wasabi import Printer, diff_strings from wasabi import Printer, msg, diff_strings
from thinc.api import Config from thinc.api import Config
import srsly import srsly
import re import re
@ -25,16 +25,17 @@ OptimizationsType = Literal["efficiency", "accuracy"]
class InitValues: class InitValues:
""" """
Default values for initialization. Dedicated class to allow synchronized default values for init_config_cli() and Default values for initialization. Dedicated class to allow synchronized
init_config(), i.e. initialization calls via CLI respectively Python. default values for init_config_cli() and init_config(), i.e. initialization
calls via CLI respectively Python.
""" """
lang = "en" lang: str = "en"
pipeline = SimpleFrozenList(["tagger", "parser", "ner"]) pipeline: List[str] = SimpleFrozenList(["tagger", "parser", "ner"])
optimize = "efficiency" optimize: OptimizationsType = "efficiency"
gpu = False gpu: bool = False
pretraining = False pretraining: bool = False
force_overwrite = False force_overwrite: bool = False
@cli.subcommand( @cli.subcommand(
@ -68,10 +69,11 @@ def init_config_cli(
DOCS: https://spacy.io/api/cli#init-config DOCS: https://spacy.io/api/cli#init-config
""" """
is_stdout = output_file == "-" is_stdout = output_file == "-"
if not is_stdout and output_file.exists() and not force_overwrite: if not is_stdout:
msg = Printer() if output_file.exists() and not force_overwrite:
msg.fail( msg.fail(
"The provided output file already exists. To force overwriting the config file, set the --force or -F flag.", "The provided output file already exists. To force overwriting "
"the config file, set the --force or -F flag.",
exits=1, exits=1,
) )
config = init_config( config = init_config(
@ -239,7 +241,10 @@ def init_config(
def save_config( def save_config(
config: Config, output_file: Path, is_stdout: bool = False, silent: bool = False config: Config,
output_file: Path,
is_stdout: bool = False,
silent: bool = False,
) -> None: ) -> None:
no_print = is_stdout or silent no_print = is_stdout or silent
msg = Printer(no_print=no_print) msg = Printer(no_print=no_print)

View File

@ -5,15 +5,13 @@ from typing import Tuple, List, Dict, Any
import pkg_resources import pkg_resources
import time import time
from pathlib import Path from pathlib import Path
import spacy
import numpy
import pytest import pytest
import srsly import srsly
from click import NoSuchOption from click import NoSuchOption
from packaging.specifiers import SpecifierSet from packaging.specifiers import SpecifierSet
from thinc.api import Config, ConfigValidationError from thinc.api import Config, ConfigValidationError
import spacy
from spacy import about from spacy import about
from spacy.cli import info from spacy.cli import info
from spacy.cli._util import is_subpath_of, load_project_config, walk_directory from spacy.cli._util import is_subpath_of, load_project_config, walk_directory
@ -745,13 +743,13 @@ def test_get_labels_from_model(factory_name, pipe_name):
def test_permitted_package_names(): def test_permitted_package_names():
# https://www.python.org/dev/peps/pep-0426/#name # https://www.python.org/dev/peps/pep-0426/#name
assert _is_permitted_package_name("Meine_Bäume") == False assert _is_permitted_package_name("Meine_Bäume") is False
assert _is_permitted_package_name("_package") == False assert _is_permitted_package_name("_package") is False
assert _is_permitted_package_name("package_") == False assert _is_permitted_package_name("package_") is False
assert _is_permitted_package_name(".package") == False assert _is_permitted_package_name(".package") is False
assert _is_permitted_package_name("package.") == False assert _is_permitted_package_name("package.") is False
assert _is_permitted_package_name("-package") == False assert _is_permitted_package_name("-package") is False
assert _is_permitted_package_name("package-") == False assert _is_permitted_package_name("package-") is False
def test_debug_data_compile_gold(): def test_debug_data_compile_gold():