mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-10 16:22:29 +03:00
Tidy up and fix types
This commit is contained in:
parent
dbfc9f688f
commit
5f049b2de6
|
@ -94,7 +94,7 @@ def _parse_overrides(args: List[str], is_cli: bool = False) -> Dict[str, Any]:
|
|||
opt = opt.replace("--", "")
|
||||
if "." not in opt:
|
||||
if is_cli:
|
||||
raise radicli.CliParseError(f"unrecognized argument: {orig_opt}")
|
||||
raise radicli.CliParserError(f"unrecognized argument: {orig_opt}")
|
||||
else:
|
||||
msg.fail(f"{err}: can't override top-level sections", exits=1)
|
||||
if "=" in opt: # we have --opt=value
|
||||
|
|
|
@ -53,8 +53,9 @@ def assemble_cli(
|
|||
with nlp.select_pipes(disable=[*sourced]):
|
||||
nlp.initialize()
|
||||
msg.good("Initialized pipeline")
|
||||
msg.divider("Serializing to disk")
|
||||
if output_path is not None and not output_path.exists():
|
||||
output_path.mkdir(parents=True)
|
||||
msg.good(f"Created output directory: {output_path}")
|
||||
nlp.to_disk(output_path)
|
||||
if output_path is not None:
|
||||
msg.divider("Serializing to disk")
|
||||
if not output_path.exists():
|
||||
output_path.mkdir(parents=True)
|
||||
msg.good(f"Created output directory: {output_path}")
|
||||
nlp.to_disk(output_path)
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
from typing import Callable, Iterable, Mapping, Optional, Any, Union, Literal
|
||||
from enum import Enum
|
||||
from typing import Callable, Iterable, Mapping, Optional, Any, Union, Literal, cast
|
||||
from pathlib import Path
|
||||
from wasabi import Printer
|
||||
import srsly
|
||||
|
@ -19,6 +18,8 @@ from ..training.converters import conllu_to_docs
|
|||
# matched by file extension and content. To add a converter, add a new
|
||||
# entry to this dict with the file extension mapped to the converter function
|
||||
# imported from /converters.
|
||||
ConvertersType = Literal["auto", "conllubio", "conllu", "conll", "ner", "iob", "json"]
|
||||
FileTypesType = Literal["json", "spacy"]
|
||||
|
||||
CONVERTERS: Mapping[str, Callable[..., Iterable[Doc]]] = {
|
||||
"conllubio": conllu_to_docs,
|
||||
|
@ -28,19 +29,11 @@ CONVERTERS: Mapping[str, Callable[..., Iterable[Doc]]] = {
|
|||
"iob": iob_to_docs,
|
||||
"json": json_to_docs,
|
||||
}
|
||||
AUTO = "auto"
|
||||
ConvertersType = Literal["auto", "conllubio", "conllu", "conll", "ner", "iob", "json"]
|
||||
|
||||
|
||||
AUTO: ConvertersType = "auto"
|
||||
# File types that can be written to stdout
|
||||
FILE_TYPES_STDOUT = ("json",)
|
||||
|
||||
|
||||
class FileTypes(str, Enum):
|
||||
json = "json"
|
||||
spacy = "spacy"
|
||||
|
||||
|
||||
@cli.command(
|
||||
"convert",
|
||||
# fmt: off
|
||||
|
@ -61,7 +54,7 @@ class FileTypes(str, Enum):
|
|||
def convert_cli(
|
||||
input_path: ExistingPathOrDash,
|
||||
output_dir: ExistingDirPathOrDash = "-",
|
||||
file_type: Literal["json", "spacy"] = "spacy",
|
||||
file_type: FileTypesType = "spacy",
|
||||
n_sents: int = 1,
|
||||
seg_sents: bool = False,
|
||||
model: Optional[str] = None,
|
||||
|
@ -110,7 +103,7 @@ def convert(
|
|||
input_path: Path,
|
||||
output_dir: Union[str, Path],
|
||||
*,
|
||||
file_type: str = "json",
|
||||
file_type: FileTypesType = "json",
|
||||
n_sents: int = 1,
|
||||
seg_sents: bool = False,
|
||||
model: Optional[str] = None,
|
||||
|
@ -173,14 +166,16 @@ def convert(
|
|||
msg.good(f"Generated output file ({len_docs} documents): {output_file}")
|
||||
|
||||
|
||||
def _print_docs_to_stdout(data: Any, output_type: str) -> None:
|
||||
def _print_docs_to_stdout(data: Any, output_type: FileTypesType) -> None:
|
||||
if output_type == "json":
|
||||
srsly.write_json("-", data)
|
||||
else:
|
||||
sys.stdout.buffer.write(data)
|
||||
|
||||
|
||||
def _write_docs_to_file(data: Any, output_file: Path, output_type: str) -> None:
|
||||
def _write_docs_to_file(
|
||||
data: Any, output_file: Path, output_type: FileTypesType
|
||||
) -> None:
|
||||
if not output_file.parent.exists():
|
||||
output_file.parent.mkdir(parents=True)
|
||||
if output_type == "json":
|
||||
|
@ -190,7 +185,7 @@ def _write_docs_to_file(data: Any, output_file: Path, output_type: str) -> None:
|
|||
file_.write(data)
|
||||
|
||||
|
||||
def autodetect_ner_format(input_data: str) -> Optional[str]:
|
||||
def autodetect_ner_format(input_data: str) -> Optional[ConvertersType]:
|
||||
# guess format from the first 20 lines
|
||||
lines = input_data.split("\n")[:20]
|
||||
format_guesses = {"ner": 0, "iob": 0}
|
||||
|
@ -213,7 +208,7 @@ def verify_cli_args(
|
|||
msg: Printer,
|
||||
input_path: Path,
|
||||
output_dir: Union[str, Path],
|
||||
file_type: str,
|
||||
file_type: FileTypesType,
|
||||
converter: str,
|
||||
ner_map: Optional[Path],
|
||||
):
|
||||
|
@ -236,7 +231,7 @@ def verify_cli_args(
|
|||
msg.fail(f"Can't find converter for {converter}", exits=1)
|
||||
|
||||
|
||||
def _get_converter(msg, converter, input_path: Path):
|
||||
def _get_converter(msg: Printer, converter: ConvertersType, input_path: Path) -> str:
|
||||
if input_path.is_dir():
|
||||
if converter == AUTO:
|
||||
input_locs = walk_directory(input_path, suffix=None)
|
||||
|
@ -265,4 +260,4 @@ def _get_converter(msg, converter, input_path: Path):
|
|||
"Conversion may not succeed. "
|
||||
"See https://spacy.io/api/cli#convert"
|
||||
)
|
||||
return converter
|
||||
return cast(str, converter)
|
||||
|
|
|
@ -4,7 +4,7 @@ from pathlib import Path
|
|||
import logging
|
||||
from typing import Optional, Tuple, Any, Dict, List
|
||||
import numpy
|
||||
import wasabi.tables
|
||||
from wasabi import msg, row
|
||||
from radicli import Arg, ExistingPath, ExistingFilePath
|
||||
|
||||
from ..pipeline import TextCategorizer, MultiLabel_TextCategorizer
|
||||
|
@ -13,11 +13,9 @@ from ..training import Corpus
|
|||
from ._util import cli, import_code, setup_gpu
|
||||
from .. import util
|
||||
|
||||
_DEFAULTS = {
|
||||
"n_trials": 11,
|
||||
"use_gpu": -1,
|
||||
"gold_preproc": False,
|
||||
}
|
||||
DEFAULT_N_TRIALS: int = 11
|
||||
DEFAULT_USE_GPU: int = -1
|
||||
DEFAULT_GOLD_PREPROC: bool = False
|
||||
|
||||
|
||||
@cli.command(
|
||||
|
@ -41,10 +39,10 @@ def find_threshold_cli(
|
|||
pipe_name: str,
|
||||
threshold_key: str,
|
||||
scores_key: str,
|
||||
n_trials: int = _DEFAULTS["n_trials"],
|
||||
n_trials: int = DEFAULT_N_TRIALS,
|
||||
code_path: Optional[ExistingFilePath] = None,
|
||||
use_gpu: int = _DEFAULTS["use_gpu"],
|
||||
gold_preproc: bool = _DEFAULTS["gold_preproc"],
|
||||
use_gpu: int = DEFAULT_USE_GPU,
|
||||
gold_preproc: bool = DEFAULT_GOLD_PREPROC,
|
||||
verbose: bool = False,
|
||||
):
|
||||
"""
|
||||
|
@ -83,9 +81,9 @@ def find_threshold(
|
|||
threshold_key: str,
|
||||
scores_key: str,
|
||||
*,
|
||||
n_trials: int = _DEFAULTS["n_trials"], # type: ignore
|
||||
use_gpu: int = _DEFAULTS["use_gpu"], # type: ignore
|
||||
gold_preproc: bool = _DEFAULTS["gold_preproc"], # type: ignore
|
||||
n_trials: int = DEFAULT_N_TRIALS,
|
||||
use_gpu: int = DEFAULT_USE_GPU,
|
||||
gold_preproc: bool = DEFAULT_GOLD_PREPROC,
|
||||
silent: bool = True,
|
||||
) -> Tuple[float, float, Dict[float, float]]:
|
||||
"""
|
||||
|
@ -104,30 +102,25 @@ def find_threshold(
|
|||
RETURNS (Tuple[float, float, Dict[float, float]]): Best found threshold, the corresponding score, scores for all
|
||||
evaluated thresholds.
|
||||
"""
|
||||
|
||||
setup_gpu(use_gpu, silent=silent)
|
||||
data_path = util.ensure_path(data_path)
|
||||
if not data_path.exists():
|
||||
wasabi.msg.fail("Evaluation data not found", data_path, exits=1)
|
||||
msg.fail("Evaluation data not found", data_path, exits=1)
|
||||
nlp = util.load_model(model)
|
||||
|
||||
if pipe_name not in nlp.component_names:
|
||||
raise AttributeError(
|
||||
Errors.E001.format(name=pipe_name, opts=nlp.component_names)
|
||||
)
|
||||
err = Errors.E001.format(name=pipe_name, opts=nlp.component_names)
|
||||
raise AttributeError(err)
|
||||
pipe = nlp.get_pipe(pipe_name)
|
||||
if not hasattr(pipe, "scorer"):
|
||||
raise AttributeError(Errors.E1045)
|
||||
if type(pipe) == TextCategorizer:
|
||||
wasabi.msg.warn(
|
||||
msg.warn(
|
||||
"The `textcat` component doesn't use a threshold as it's not applicable to the concept of "
|
||||
"exclusive classes. All thresholds will yield the same results."
|
||||
)
|
||||
if not silent:
|
||||
wasabi.msg.info(
|
||||
title=f"Optimizing for {scores_key} for component '{pipe_name}' with {n_trials} "
|
||||
f"trials."
|
||||
)
|
||||
text = f"Optimizing for {scores_key} for component '{pipe_name}' with {n_trials} trials."
|
||||
msg.info(text)
|
||||
# Load evaluation corpus.
|
||||
corpus = Corpus(data_path, gold_preproc=gold_preproc)
|
||||
dev_dataset = list(corpus(nlp))
|
||||
|
@ -155,9 +148,9 @@ def find_threshold(
|
|||
RETURNS (Dict[str, Any]): Filtered dictionary.
|
||||
"""
|
||||
if keys[0] not in config:
|
||||
wasabi.msg.fail(
|
||||
title=f"Failed to look up `{full_key}` in config: sub-key {[keys[0]]} not found.",
|
||||
text=f"Make sure you specified {[keys[0]]} correctly. The following sub-keys are available instead: "
|
||||
msg.fail(
|
||||
f"Failed to look up `{full_key}` in config: sub-key {[keys[0]]} not found.",
|
||||
f"Make sure you specified {[keys[0]]} correctly. The following sub-keys are available instead: "
|
||||
f"{list(config.keys())}",
|
||||
exits=1,
|
||||
)
|
||||
|
@ -172,7 +165,7 @@ def find_threshold(
|
|||
config_keys_full = ["components", pipe_name, *config_keys]
|
||||
table_col_widths = (10, 10)
|
||||
thresholds = numpy.linspace(0, 1, n_trials)
|
||||
print(wasabi.tables.row(["Threshold", f"{scores_key}"], widths=table_col_widths))
|
||||
print(row(["Threshold", f"{scores_key}"], widths=table_col_widths))
|
||||
for threshold in thresholds:
|
||||
# Reload pipeline with overrides specifying the new threshold.
|
||||
nlp = util.load_model(
|
||||
|
@ -191,34 +184,28 @@ def find_threshold(
|
|||
"cfg",
|
||||
set_nested_item(getattr(pipe, "cfg"), config_keys, threshold),
|
||||
)
|
||||
|
||||
eval_scores = nlp.evaluate(dev_dataset)
|
||||
if scores_key not in eval_scores:
|
||||
wasabi.msg.fail(
|
||||
msg.fail(
|
||||
title=f"Failed to look up score `{scores_key}` in evaluation results.",
|
||||
text=f"Make sure you specified the correct value for `scores_key`. The following scores are "
|
||||
f"available: {list(eval_scores.keys())}",
|
||||
exits=1,
|
||||
)
|
||||
scores[threshold] = eval_scores[scores_key]
|
||||
|
||||
if not isinstance(scores[threshold], (float, int)):
|
||||
wasabi.msg.fail(
|
||||
f"Returned score for key '{scores_key}' is not numeric. Threshold optimization only works for numeric "
|
||||
f"scores.",
|
||||
msg.fail(
|
||||
f"Returned score for key '{scores_key}' is not numeric. Threshold "
|
||||
f"optimization only works for numeric scores.",
|
||||
exits=1,
|
||||
)
|
||||
print(
|
||||
wasabi.row(
|
||||
[round(threshold, 3), round(scores[threshold], 3)],
|
||||
widths=table_col_widths,
|
||||
)
|
||||
)
|
||||
data = [round(threshold, 3), round(scores[threshold], 3)]
|
||||
print(row(data, widths=table_col_widths))
|
||||
best_threshold = max(scores.keys(), key=(lambda key: scores[key]))
|
||||
# If all scores are identical, emit warning.
|
||||
if len(set(scores.values())) == 1:
|
||||
wasabi.msg.warn(
|
||||
title="All scores are identical. Verify that all settings are correct.",
|
||||
msg.warn(
|
||||
"All scores are identical. Verify that all settings are correct.",
|
||||
text=""
|
||||
if (
|
||||
not isinstance(pipe, MultiLabel_TextCategorizer)
|
||||
|
@ -229,7 +216,7 @@ def find_threshold(
|
|||
else:
|
||||
if not silent:
|
||||
print(
|
||||
f"\nBest threshold: {round(best_threshold, ndigits=4)} with {scores_key} value of {scores[best_threshold]}."
|
||||
f"\nBest threshold: {round(best_threshold, ndigits=4)} with "
|
||||
f"{scores_key} value of {scores[best_threshold]}."
|
||||
)
|
||||
|
||||
return best_threshold, scores[best_threshold], scores
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Optional, List, Tuple, Literal
|
||||
from typing import Optional, List, Tuple, Literal, cast
|
||||
from pathlib import Path
|
||||
from wasabi import Printer, diff_strings
|
||||
from wasabi import Printer, msg, diff_strings
|
||||
from thinc.api import Config
|
||||
import srsly
|
||||
import re
|
||||
|
@ -25,16 +25,17 @@ OptimizationsType = Literal["efficiency", "accuracy"]
|
|||
|
||||
class InitValues:
|
||||
"""
|
||||
Default values for initialization. Dedicated class to allow synchronized default values for init_config_cli() and
|
||||
init_config(), i.e. initialization calls via CLI respectively Python.
|
||||
Default values for initialization. Dedicated class to allow synchronized
|
||||
default values for init_config_cli() and init_config(), i.e. initialization
|
||||
calls via CLI respectively Python.
|
||||
"""
|
||||
|
||||
lang = "en"
|
||||
pipeline = SimpleFrozenList(["tagger", "parser", "ner"])
|
||||
optimize = "efficiency"
|
||||
gpu = False
|
||||
pretraining = False
|
||||
force_overwrite = False
|
||||
lang: str = "en"
|
||||
pipeline: List[str] = SimpleFrozenList(["tagger", "parser", "ner"])
|
||||
optimize: OptimizationsType = "efficiency"
|
||||
gpu: bool = False
|
||||
pretraining: bool = False
|
||||
force_overwrite: bool = False
|
||||
|
||||
|
||||
@cli.subcommand(
|
||||
|
@ -68,12 +69,13 @@ def init_config_cli(
|
|||
DOCS: https://spacy.io/api/cli#init-config
|
||||
"""
|
||||
is_stdout = output_file == "-"
|
||||
if not is_stdout and output_file.exists() and not force_overwrite:
|
||||
msg = Printer()
|
||||
msg.fail(
|
||||
"The provided output file already exists. To force overwriting the config file, set the --force or -F flag.",
|
||||
exits=1,
|
||||
)
|
||||
if not is_stdout:
|
||||
if output_file.exists() and not force_overwrite:
|
||||
msg.fail(
|
||||
"The provided output file already exists. To force overwriting "
|
||||
"the config file, set the --force or -F flag.",
|
||||
exits=1,
|
||||
)
|
||||
config = init_config(
|
||||
lang=lang,
|
||||
pipeline=pipeline,
|
||||
|
@ -239,7 +241,10 @@ def init_config(
|
|||
|
||||
|
||||
def save_config(
|
||||
config: Config, output_file: Path, is_stdout: bool = False, silent: bool = False
|
||||
config: Config,
|
||||
output_file: Path,
|
||||
is_stdout: bool = False,
|
||||
silent: bool = False,
|
||||
) -> None:
|
||||
no_print = is_stdout or silent
|
||||
msg = Printer(no_print=no_print)
|
||||
|
|
|
@ -5,15 +5,13 @@ from typing import Tuple, List, Dict, Any
|
|||
import pkg_resources
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import spacy
|
||||
import numpy
|
||||
import pytest
|
||||
import srsly
|
||||
from click import NoSuchOption
|
||||
from packaging.specifiers import SpecifierSet
|
||||
from thinc.api import Config, ConfigValidationError
|
||||
|
||||
import spacy
|
||||
from spacy import about
|
||||
from spacy.cli import info
|
||||
from spacy.cli._util import is_subpath_of, load_project_config, walk_directory
|
||||
|
@ -745,13 +743,13 @@ def test_get_labels_from_model(factory_name, pipe_name):
|
|||
|
||||
def test_permitted_package_names():
|
||||
# https://www.python.org/dev/peps/pep-0426/#name
|
||||
assert _is_permitted_package_name("Meine_Bäume") == False
|
||||
assert _is_permitted_package_name("_package") == False
|
||||
assert _is_permitted_package_name("package_") == False
|
||||
assert _is_permitted_package_name(".package") == False
|
||||
assert _is_permitted_package_name("package.") == False
|
||||
assert _is_permitted_package_name("-package") == False
|
||||
assert _is_permitted_package_name("package-") == False
|
||||
assert _is_permitted_package_name("Meine_Bäume") is False
|
||||
assert _is_permitted_package_name("_package") is False
|
||||
assert _is_permitted_package_name("package_") is False
|
||||
assert _is_permitted_package_name(".package") is False
|
||||
assert _is_permitted_package_name("package.") is False
|
||||
assert _is_permitted_package_name("-package") is False
|
||||
assert _is_permitted_package_name("package-") is False
|
||||
|
||||
|
||||
def test_debug_data_compile_gold():
|
||||
|
|
Loading…
Reference in New Issue
Block a user