mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-06 06:03:11 +03:00
357 lines
12 KiB
Python
357 lines
12 KiB
Python
from typing import Dict, Any, Union, List, Optional, Iterable
|
|
from typing import TYPE_CHECKING, overload
|
|
import sys
|
|
import shutil
|
|
from pathlib import Path
|
|
from wasabi import msg, Printer
|
|
import srsly
|
|
import hashlib
|
|
import typer
|
|
from click import NoSuchOption
|
|
from click.parser import split_arg_string
|
|
from typer.main import get_command
|
|
from contextlib import contextmanager
|
|
from thinc.api import ConfigValidationError, require_gpu
|
|
from thinc.util import gpu_is_available
|
|
from configparser import InterpolationError
|
|
import os
|
|
|
|
from ..compat import Literal
|
|
from ..util import import_file, registry, logger, ENV_VARS
|
|
|
|
if TYPE_CHECKING:
|
|
from pathy import FluidPath # noqa: F401
|
|
|
|
|
|
# File name suffixes for source distributions (sdist) and wheels.
SDIST_SUFFIX = ".tar.gz"
WHEEL_SUFFIX = "-py3-none-any.whl"

# File name referenced in the project CLI help text below.
PROJECT_FILE = "project.yml"
# Program name passed as prog_name when the CLI is invoked, and printed in
# suggested commands.
COMMAND = "python -m spacy"
# Name of the root Typer app.
NAME = "spacy"
# Help texts for the root app and its sub-apps.
HELP = """spaCy Command-line Interface

DOCS: https://spacy.io/api/cli
"""
PROJECT_HELP = f"""Command-line interface for spaCy projects and templates.
You'd typically start by cloning a project template to a local directory and
fetching its assets like datasets etc. See the project's {PROJECT_FILE} for the
available commands.
"""
DEBUG_HELP = """Suite of helpful commands for debugging and profiling. Includes
commands to check and validate your config files, training and evaluation data,
and custom model implementations.
"""
BENCHMARK_HELP = """Commands for benchmarking pipelines."""
INIT_HELP = """Commands for initializing configs and pipeline packages."""

# Wrappers for Typer's annotations. Initially created to set defaults and to
# keep the names short, but not needed at the moment.
Arg = typer.Argument
Opt = typer.Option

# The root app, plus one sub-app per command group. The sub-apps show their
# help when called without arguments.
app = typer.Typer(name=NAME, help=HELP)
benchmark_cli = typer.Typer(name="benchmark", help=BENCHMARK_HELP, no_args_is_help=True)
project_cli = typer.Typer(name="project", help=PROJECT_HELP, no_args_is_help=True)
debug_cli = typer.Typer(name="debug", help=DEBUG_HELP, no_args_is_help=True)
init_cli = typer.Typer(name="init", help=INIT_HELP, no_args_is_help=True)

# Attach the command groups to the root app.
app.add_typer(project_cli)
app.add_typer(debug_cli)
app.add_typer(benchmark_cli)
app.add_typer(init_cli)
|
|
|
|
|
|
def setup_cli() -> None:
    """Load the registered CLI entry points and run the command-line app."""
    # Fetching everything from the CLI registry makes sure the entry points
    # run, so that the modules providing commands get imported.
    registry.cli.get_all()
    # Call the command with an explicit prog_name so that the help messages
    # always display the correct prompt.
    get_command(app)(prog_name=COMMAND)
|
|
|
|
|
|
def parse_config_overrides(
    args: List[str], env_var: Optional[str] = ENV_VARS.CONFIG_OVERRIDES
) -> Dict[str, Any]:
    """Build a dictionary of config overrides from the extra arguments given
    on the CLI (e.g. --training.batch_size to override "training.batch_size")
    and, optionally, from an environment variable holding the same kind of
    arguments. Arguments without a "." are considered invalid, since the
    config only allows top-level sections to exist. If a setting is present
    in both sources, the value from the environment variable is returned.

    args (List[str]): The extra arguments provided on the CLI.
    env_var (Optional[str]): Optional environment variable to read from.
    RETURNS (Dict[str, Any]): The parsed dict, keyed by nested config setting.
    """
    raw_env = ""
    if env_var:
        raw_env = os.environ.get(env_var, "")
    from_env = _parse_overrides(split_arg_string(raw_env))
    from_cli = _parse_overrides(args, is_cli=True)
    if from_cli:
        # Only log the CLI settings that aren't replaced by the environment.
        cli_only = [name for name in from_cli if name not in from_env]
        logger.debug("Config overrides from CLI: %s", cli_only)
    if from_env:
        logger.debug("Config overrides from env variables: %s", list(from_env))
    # CLI values first, then the environment values on top of them.
    merged = dict(from_cli)
    merged.update(from_env)
    return merged
|
|
|
|
|
|
def _parse_overrides(args: List[str], is_cli: bool = False) -> Dict[str, Any]:
|
|
result = {}
|
|
while args:
|
|
opt = args.pop(0)
|
|
err = f"Invalid config override '{opt}'"
|
|
if opt.startswith("--"): # new argument
|
|
orig_opt = opt
|
|
opt = opt.replace("--", "")
|
|
if "." not in opt:
|
|
if is_cli:
|
|
raise NoSuchOption(orig_opt)
|
|
else:
|
|
msg.fail(f"{err}: can't override top-level sections", exits=1)
|
|
if "=" in opt: # we have --opt=value
|
|
opt, value = opt.split("=", 1)
|
|
opt = opt.replace("-", "_")
|
|
else:
|
|
if not args or args[0].startswith("--"): # flag with no value
|
|
value = "true"
|
|
else:
|
|
value = args.pop(0)
|
|
result[opt] = _parse_override(value)
|
|
else:
|
|
msg.fail(f"{err}: name should start with --", exits=1)
|
|
return result
|
|
|
|
|
|
def _parse_override(value: Any) -> Any:
    """Parse a single override value as JSON, falling back to a string.

    value (Any): The raw value.
    RETURNS (Any): The result of srsly.json_loads, or str(value) if loading
        raised a ValueError.
    """
    # The values are loaded as JSON, just like in the config. Because they
    # come from the CLI, requiring strings to be marked with escaped quotes
    # would be unintuitive, so anything that fails to parse is kept as a
    # plain string instead.
    # TODO: improve logic to handle simple types like list of strings?
    try:
        parsed = srsly.json_loads(value)
    except ValueError:
        return str(value)
    return parsed
|
|
|
|
|
|
def get_hash(data, exclude: Iterable[str] = tuple()) -> str:
    """Compute the MD5 hex digest of a JSON-serializable object. The object
    is dumped to JSON with sorted keys and encoded as UTF-8 before hashing.

    data: The data to hash.
    exclude (Iterable[str]): Top-level keys to leave out if data is a dict.
    RETURNS (str): The hash.
    """
    if isinstance(data, dict):
        kept = {}
        for key, val in data.items():
            if key in exclude:
                continue
            kept[key] = val
        data = kept
    serialized = srsly.json_dumps(data, sort_keys=True)
    digest = hashlib.md5(serialized.encode("utf8"))
    return digest.hexdigest()
|
|
|
|
|
|
def get_checksum(path: Union[Path, str]) -> str:
    """Get the MD5 checksum for a file or directory given its file path. If a
    directory path is provided, the contents of all files found recursively
    in that directory are hashed together, in sorted path order.

    path (Union[Path, str]): The file or directory path.
    RETURNS (str): The checksum.
    """
    path = Path(path)
    if not (path.is_file() or path.is_dir()):
        msg.fail(f"Can't get checksum for {path}: not a file or directory", exits=1)
    if path.is_file():
        files = [path]
    else:
        files = sorted(fp for fp in path.rglob("*") if fp.is_file())
    chunk_size = 1024 * 1024
    checksum = hashlib.md5()
    for file_path in files:
        # Feed the hash in chunks rather than reading whole files into
        # memory. Updating incrementally gives the same digest as hashing
        # the concatenated contents.
        with file_path.open("rb") as file_:
            for chunk in iter(lambda: file_.read(chunk_size), b""):
                checksum.update(chunk)
    return checksum.hexdigest()
|
|
|
|
|
|
@contextmanager
def show_validation_error(
    file_path: Optional[Union[str, Path]] = None,
    *,
    title: Optional[str] = None,
    desc: str = "",
    show_config: Optional[bool] = None,
    hint_fill: bool = True,
):
    """Context manager that prints custom-formatted config validation errors
    on the CLI and exits with status 1.

    file_path (str / Path): Optional file path of config file, used in hints.
    title (str): Override title of custom formatted error.
    desc (str): Override description of custom formatted error.
    show_config (bool): Whether to output the config the error refers to.
    hint_fill (bool): Show hint about filling config.
    """
    try:
        yield
    except ConfigValidationError as error:
        heading = error.title if title is None else title
        details = desc
        if error.desc:
            # The error's own description goes first, followed by the
            # custom one if it was provided.
            details = f"{error.desc}\n\n{desc}" if desc else f"{error.desc}"
        # Build a new error object carrying the overrides. Its title is left
        # empty because the heading is printed separately.
        formatted = error.from_error(
            error, title="", desc=details, show_config=show_config
        )
        msg.fail(heading)
        print(formatted.text.strip())
        if hint_fill and "value_error.missing" in formatted.error_types:
            # Fall back to a generic file name if no path was given or if
            # the path is "-".
            if file_path is None or str(file_path) == "-":
                config_path = "config.cfg"
            else:
                config_path = file_path
            msg.text(
                "If your config contains missing values, you can run the 'init "
                "fill-config' command to fill in all the defaults, if possible:",
                spaced=True,
            )
            print(f"{COMMAND} init fill-config {config_path} {config_path} \n")
        sys.exit(1)
    except InterpolationError as error:
        msg.fail("Config validation error", error, exits=1)
|
|
|
|
|
|
def import_code(code_path: Optional[Union[Path, str]]) -> None:
    """Import the Python file provided in training commands / commands using
    the config, which makes custom registered functions available. Does
    nothing if no path is given.

    code_path (Optional[Union[Path, str]]): Path to the Python file, or None.
    """
    if code_path is None:
        return
    if not Path(code_path).exists():
        msg.fail("Path to Python code not found", code_path, exits=1)
    try:
        import_file("python_code", code_path)
    except Exception as error:
        # Whatever goes wrong while importing the user's code is reported
        # as a CLI failure.
        msg.fail(f"Couldn't load Python code: {code_path}", error, exits=1)
|
|
|
|
|
|
def upload_file(src: Path, dest: Union[str, "FluidPath"]) -> None:
    """Upload a file using smart_open.

    src (Path): The source path.
    dest (Union[str, FluidPath]): The destination path or URL to upload to.
    """
    import smart_open

    # Create parent directories for local paths. exist_ok avoids failing if
    # the directory appears between a check and the creation.
    if isinstance(dest, Path):
        dest.parent.mkdir(parents=True, exist_ok=True)

    dest = str(dest)
    # Open the source first, so an unreadable source doesn't leave behind a
    # newly created or truncated destination.
    with src.open(mode="rb") as input_file:
        with smart_open.open(dest, mode="wb") as output_file:
            # Copy in chunks instead of reading the whole file into memory.
            shutil.copyfileobj(input_file, output_file)
|
|
|
|
|
|
def download_file(
    src: Union[str, "FluidPath"], dest: Path, *, force: bool = False
) -> None:
    """Download a file using smart_open.

    src (Union[str, FluidPath]): The URL or path of the file to download.
    dest (Path): The destination path.
    force (bool): Whether to force download even if file exists.
        If False, the download will be skipped.
    """
    if dest.exists() and not force:
        return None

    # Only import smart_open once it's clear that a download is needed.
    import smart_open

    src = str(src)
    # Compression handling is explicitly disabled for the read.
    with smart_open.open(src, mode="rb", compression="disable") as input_file:
        with dest.open(mode="wb") as output_file:
            shutil.copyfileobj(input_file, output_file)
|
|
|
|
|
|
def ensure_pathy(path):
    """Temporary helper that wraps a path with Pathy.fluid. Pathy is imported
    inside the function to prevent importing it globally (which can cause
    slow and annoying Google Cloud warning).

    path: The path to convert.
    RETURNS: The result of Pathy.fluid(path).
    """
    from pathy import Pathy  # noqa: F811

    fluid_path = Pathy.fluid(path)
    return fluid_path
|
|
|
|
|
|
@overload
def string_to_list(value: str, intify: Literal[False] = ...) -> List[str]:
    ...


@overload
def string_to_list(value: str, intify: Literal[True]) -> List[int]:
    ...


def string_to_list(value: str, intify: bool = False) -> Union[List[str], List[int]]:
    """Parse a comma-separated string to a list and account for various
    formatting options: optional surrounding square brackets, whitespace
    around the items and single or double quotes around the items. Mostly
    used to handle CLI arguments that take a list of comma-separated values.

    value (str): The value to parse.
    intify (bool): Whether to convert values to ints.
    RETURNS (Union[List[str], List[int]]): A list of strings or ints. Empty
        input, including "[]" and whitespace-only input, gives an empty list.
    """
    if not value:
        return []  # type: ignore[return-value]
    if value.startswith("[") and value.endswith("]"):
        value = value[1:-1]
    # Without this check, "[]" would produce [""] and fail in int() when
    # intify is set.
    if not value.strip():
        return []  # type: ignore[return-value]
    result = []
    for p in value.split(","):
        p = p.strip()
        if p.startswith("'") and p.endswith("'"):
            p = p[1:-1]
        if p.startswith('"') and p.endswith('"'):
            p = p[1:-1]
        p = p.strip()
        if intify:
            p = int(p)  # type: ignore[assignment]
        result.append(p)
    return result
|
|
|
|
|
|
def setup_gpu(use_gpu: int, silent=None) -> None:
    """Configure the GPU and log info.

    use_gpu (int): The GPU ID to require. A negative value means CPU.
    silent: If None, a default Printer is used. Otherwise the value is
        passed to the Printer as no_print, and its negation as pretty.
    """
    if silent is None:
        printer = Printer()
    else:
        printer = Printer(no_print=silent, pretty=not silent)
    if use_gpu >= 0:
        printer.info(f"Using GPU: {use_gpu}")
        require_gpu(use_gpu)
        return
    printer.info("Using CPU")
    # Point out the GPU option if a GPU is available but isn't being used.
    if gpu_is_available():
        printer.info("To switch to GPU 0, use the option: --gpu-id 0")
|
|
|
|
|
|
def walk_directory(path: Path, suffix: Optional[str] = None) -> List[Path]:
    """Given a directory and a suffix, recursively find all files matching the
    suffix. Directories or files with names beginning with a . are ignored,
    but hidden flags on filesystems are not checked. When provided with a
    suffix `None`, there is no suffix-based filtering.

    path (Path): The directory to walk. If it's not a directory, it is
        returned as the only item, without any filtering.
    suffix (Optional[str]): The file name ending to filter by, if any.
    RETURNS (List[Path]): The sorted list of matching file paths.
    """
    if not path.is_dir():
        return [path]
    pending = [path]
    locs = []
    seen = set()
    while pending:
        current = pending.pop()
        key = str(current)
        if key in seen:
            continue
        seen.add(key)
        # Use the name rather than parts[-1]: Path(".") has no parts, so
        # indexing its parts raises an IndexError.
        name = current.name
        if name.startswith("."):
            continue
        if current.is_dir():
            pending.extend(current.iterdir())
        elif suffix is None or name.endswith(suffix):
            locs.append(current)
    # It's good to sort these, in case the ordering messes up cache.
    locs.sort()
    return locs
|
|
|
|
|
|
def _format_number(number: Union[int, float], ndigits: int = 2) -> str:
|
|
"""Formats a number (float or int) rounding to `ndigits`, without truncating trailing 0s,
|
|
as happens with `round(number, ndigits)`"""
|
|
if isinstance(number, float):
|
|
return f"{number:.{ndigits}f}"
|
|
else:
|
|
return str(number)
|