Update CLI commands to use one shared util file

This commit is contained in:
Ines Montani 2020-07-10 17:57:40 +02:00
parent 240e0a62ca
commit 73332ddb67
19 changed files with 118 additions and 68 deletions

View File

@ -1,6 +1,6 @@
from wasabi import msg
from ._app import app, setup_cli # noqa: F401
from ._util import app, setup_cli # noqa: F401
# These are the actual functions, NOT the wrapped CLI commands. The CLI commands
# are registered automatically and won't have to be imported here.

View File

@ -1,31 +0,0 @@
import typer
from typer.main import get_command
# Program name passed as prog_name in setup_cli() so that help messages show
# the command the user actually typed.
COMMAND = "python -m spacy"
# Name and help text handed to the top-level Typer app below.
NAME = "spacy"
HELP = """spaCy Command-line Interface
DOCS: https://spacy.io/api/cli
"""
# Help text for the "project" sub-app.
# NOTE(review): this is an f-string without any placeholders.
PROJECT_HELP = f"""Command-line interface for spaCy projects and working with
project templates. You'd typically start by cloning a project template to a local
directory and fetching its assets like datasets etc. See the project's
project.yml for the available commands.
"""
# Top-level CLI app, plus the "project" sub-app that is registered on it.
app = typer.Typer(name=NAME, help=HELP)
project_cli = typer.Typer(name="project", help=PROJECT_HELP, no_args_is_help=True)
app.add_typer(project_cli)
# Wrappers for Typer's annotations. Initially created to set defaults and to
# keep the names short, but not needed at the moment.
Arg = typer.Argument
Opt = typer.Option
def setup_cli() -> None:
    """Build the command for the Typer app and invoke it.

    COMMAND is passed as the program name to ensure that the help messages
    always display the correct prompt.
    """
    get_command(app)(prog_name=COMMAND)

View File

@ -1,14 +1,76 @@
from typing import Dict, Any, Union
from typing import Dict, Any, Union, List
from pathlib import Path
from wasabi import msg
import srsly
import hashlib
import typer
from typer.main import get_command
from ...schemas import ProjectConfigSchema, validate
from ..schemas import ProjectConfigSchema, validate
# File name of the project config; interpolated into PROJECT_HELP below.
PROJECT_FILE = "project.yml"
# Presumably the file name of the project's lock file -- confirm against the
# project commands that import it.
PROJECT_LOCK = "project.lock"
# Program name passed as prog_name in setup_cli() so that help messages show
# the command the user actually typed.
COMMAND = "python -m spacy"
# Name and help text handed to the top-level Typer app below.
NAME = "spacy"
HELP = """spaCy Command-line Interface
DOCS: https://spacy.io/api/cli
"""
# Help text for the "project" sub-app; mentions the config file by name via
# PROJECT_FILE.
PROJECT_HELP = f"""Command-line interface for spaCy projects and working with
project templates. You'd typically start by cloning a project template to a local
directory and fetching its assets like datasets etc. See the project's
{PROJECT_FILE} for the available commands.
"""
# Wrappers for Typer's annotations. Initially created to set defaults and to
# keep the names short, but not needed at the moment.
Arg = typer.Argument
Opt = typer.Option
# Top-level CLI app, plus the "project" sub-app that is registered on it.
app = typer.Typer(name=NAME, help=HELP)
project_cli = typer.Typer(name="project", help=PROJECT_HELP, no_args_is_help=True)
app.add_typer(project_cli)
def setup_cli() -> None:
    """Build the command for the Typer app and invoke it.

    COMMAND is passed as the program name to ensure that the help messages
    always display the correct prompt.
    """
    get_command(app)(prog_name=COMMAND)
def parse_config_overrides(args: List[str]) -> Dict[str, Any]:
    """Generate a dictionary of config overrides based on the extra arguments
    provided on the CLI, e.g. --training.batch_size to override
    "training.batch_size". Arguments without a "." are considered invalid,
    since the config only allows top-level sections to exist. An option that
    is not followed by a value is treated as a flag and set to True. The
    input list is not modified.

    args (List[str]): The extra arguments from the command line.
    RETURNS (Dict[str, Any]): The parsed dict, keyed by nested config setting.
    """
    result = {}
    total = len(args)
    # Walk the list by index instead of popping from the front, so the
    # caller's list stays intact and we avoid repeated O(n) shifts.
    index = 0
    while index < total:
        opt = args[index]
        index += 1
        err = f"Invalid config override '{opt}'"
        if opt.startswith("--"):  # new argument
            # Only strip the leading "--" so that a "--" occurring later in
            # the option name isn't silently removed as well.
            opt = opt[2:]
            if "." not in opt:
                msg.fail(f"{err}: can't override top-level section", exits=1)
            if index >= total or args[index].startswith("--"):
                # Flag with no value. Store the boolean directly: it isn't a
                # JSON string, so it must not be sent through the JSON parser.
                result[opt] = True
                continue
            value = args[index]
            index += 1
            # Just like we do in the config, we're calling json.loads on the
            # values. But since they come from the CLI, it'd be unintuitive to
            # explicitly mark strings with escaped quotes. So we're working
            # around that here by falling back to a string if parsing fails.
            try:
                result[opt] = srsly.json_loads(value)
            except ValueError:
                result[opt] = str(value)
        else:
            msg.fail(f"{err}: options need to start with --", exits=1)
    return result
def load_project_config(path: Path) -> Dict[str, Any]:

View File

@ -6,7 +6,7 @@ import srsly
import re
import sys
from ._app import app, Arg, Opt
from ._util import app, Arg, Opt
from ..gold import docs_to_json
from ..tokens import DocBin
from ..gold.converters import iob2docs, conll_ner2docs, json2docs, conllu2docs

View File

@ -5,7 +5,7 @@ import sys
import srsly
from wasabi import Printer, MESSAGES
from ._app import app, Arg, Opt
from ._util import app, Arg, Opt
from ..gold import Corpus, Example
from ..syntax import nonproj
from ..language import Language

View File

@ -4,7 +4,7 @@ import sys
from wasabi import msg
import typer
from ._app import app, Arg, Opt
from ._util import app, Arg, Opt
from .. import about
from ..util import is_package, get_base_version, run_command

View File

@ -8,7 +8,7 @@ from thinc.api import require_gpu, fix_random_seed
from ..gold import Corpus
from ..tokens import Doc
from ._app import app, Arg, Opt
from ._util import app, Arg, Opt
from ..scorer import Scorer
from .. import util
from .. import displacy

View File

@ -4,7 +4,7 @@ from pathlib import Path
from wasabi import Printer
import srsly
from ._app import app, Arg, Opt
from ._util import app, Arg, Opt
from .. import util
from .. import about

View File

@ -12,7 +12,7 @@ import srsly
import warnings
from wasabi import Printer
from ._app import app, Arg, Opt
from ._util import app, Arg, Opt
from ..vectors import Vectors
from ..errors import Errors, Warnings
from ..language import Language

View File

@ -5,7 +5,7 @@ from wasabi import Printer, get_raw_input
import srsly
import sys
from ._app import app, Arg, Opt
from ._util import app, Arg, Opt
from ..schemas import validate, ModelMetaSchema
from .. import util
from .. import about

View File

@ -12,7 +12,7 @@ from wasabi import msg
import srsly
from functools import partial
from ._app import app, Arg, Opt
from ._util import app, Arg, Opt
from ..errors import Errors
from ..ml.models.multi_task import build_cloze_multi_task_model
from ..ml.models.multi_task import build_cloze_characters_multi_task_model

View File

@ -8,7 +8,7 @@ import sys
import itertools
from wasabi import msg, Printer
from ._app import app, Arg, Opt
from ._util import app, Arg, Opt
from ..language import Language
from ..util import load_model

View File

@ -7,8 +7,7 @@ import re
import shutil
from ...util import ensure_path, working_dir
from .._app import project_cli, Arg
from .util import PROJECT_FILE, load_project_config, get_checksum
from .._util import project_cli, Arg, PROJECT_FILE, load_project_config, get_checksum
# TODO: find a solution for caches

View File

@ -7,8 +7,7 @@ import re
from ... import about
from ...util import ensure_path, run_command, make_tempdir
from .._app import project_cli, Arg, Opt, COMMAND
from .util import PROJECT_FILE
from .._util import project_cli, Arg, Opt, COMMAND, PROJECT_FILE
@project_cli.command("clone")

View File

@ -5,8 +5,8 @@ import subprocess
from pathlib import Path
from wasabi import msg
from .util import PROJECT_FILE, load_project_config, get_hash
from .._app import project_cli, Arg, Opt, NAME, COMMAND
from .._util import PROJECT_FILE, load_project_config, get_hash, project_cli
from .._util import Arg, Opt, NAME, COMMAND
from ...util import working_dir, split_command, join_command, run_command

View File

@ -5,9 +5,8 @@ import sys
import srsly
from ...util import working_dir, run_command, split_command, is_cwd, join_command
from .._app import project_cli, Arg, Opt, COMMAND
from .util import PROJECT_FILE, PROJECT_LOCK, load_project_config, get_hash
from .util import get_checksum
from .._util import PROJECT_FILE, PROJECT_LOCK, load_project_config, get_hash
from .._util import get_checksum, project_cli, Arg, Opt, COMMAND
@project_cli.command("run")

View File

@ -1,4 +1,4 @@
from typing import Optional, Dict
from typing import Optional, Dict, Any
from timeit import default_timer as timer
import srsly
import tqdm
@ -8,8 +8,9 @@ import thinc
import thinc.schedules
from thinc.api import use_pytorch_for_gpu_memory, require_gpu, fix_random_seed
import random
import typer
from ._app import app, Arg, Opt
from ._util import app, Arg, Opt, parse_config_overrides
from ..gold import Corpus, Example
from ..lookups import Lookups
from .. import util
@ -24,9 +25,12 @@ from ..ml import models # noqa: F401
registry = util.registry
@app.command("train")
@app.command(
"train", context_settings={"allow_extra_args": True, "ignore_unknown_options": True}
)
def train_cli(
# fmt: off
ctx: typer.Context, # This is only used to read additional arguments
train_path: Path = Arg(..., help="Location of training data", exists=True),
dev_path: Path = Arg(..., help="Location of development data", exists=True),
config_path: Path = Arg(..., help="Path to config file", exists=True),
@ -36,17 +40,36 @@ def train_cli(
# fmt: on
):
"""
Train or update a spaCy model. Requires data to be formatted in spaCy's
JSON format. To convert data from other formats, use the `spacy convert`
command.
Train or update a spaCy model. Requires data in spaCy's binary format. To
convert data from other formats, use the `spacy convert` command. The
config file includes all settings and hyperparameters used during training.
To override settings in the config, e.g. settings that point to local
paths or that you want to experiment with, you can override them as
command line options. For instance, --training.batch_size 128 overrides
the value of "batch_size" in the block "[training]". The --code argument
lets you pass in a Python file that's imported before training. It can be
used to register custom functions and architectures that can then be
referenced in the config.
"""
util.set_env_log(verbose)
verify_cli_args(**locals())
try:
util.import_file("python_code", code_path)
except Exception as e:
msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1)
train(config_path, {"train": train_path, "dev": dev_path}, output_path=output_path)
verify_cli_args(
train_path=train_path,
dev_path=dev_path,
config_path=config_path,
code_path=code_path,
)
overrides = parse_config_overrides(ctx.args)
if code_path is not None:
try:
util.import_file("python_code", code_path)
except Exception as e:
msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1)
train(
config_path,
{"train": train_path, "dev": dev_path},
output_path=output_path,
config_overrides=overrides,
)
def train(
@ -54,7 +77,7 @@ def train(
data_paths: Dict[str, Path],
raw_text: Optional[Path] = None,
output_path: Optional[Path] = None,
weights_data: Optional[bytes] = None,
config_overrides: Dict[str, Any] = {},
) -> None:
msg.info(f"Loading config from: {config_path}")
# Read the config first without creating objects, to get to the original nlp_config
@ -469,7 +492,6 @@ def verify_cli_args(
config_path: Path,
output_path: Optional[Path] = None,
code_path: Optional[Path] = None,
verbose: bool = False,
):
# Make sure all files and paths exists if they are needed
if not config_path or not config_path.exists():

View File

@ -4,7 +4,7 @@ import sys
import requests
from wasabi import msg, Printer
from ._app import app
from ._util import app
from .. import about
from ..util import get_package_version, get_installed_models, get_base_version
from ..util import get_package_path, get_model_meta, is_compatible_version

View File

@ -1,5 +1,5 @@
import pytest
from spacy.cli.project.util import validate_project_commands
from spacy.cli._util import validate_project_commands
from spacy.schemas import ProjectConfigSchema, validate