diff --git a/spacy/cli/__init__.py b/spacy/cli/__init__.py index 0568b34de..eedd8961f 100644 --- a/spacy/cli/__init__.py +++ b/spacy/cli/__init__.py @@ -1,6 +1,6 @@ from wasabi import msg -from ._app import app, setup_cli # noqa: F401 +from ._util import app, setup_cli # noqa: F401 # These are the actual functions, NOT the wrapped CLI commands. The CLI commands # are registered automatically and won't have to be imported here. diff --git a/spacy/cli/_app.py b/spacy/cli/_app.py deleted file mode 100644 index e970c4dde..000000000 --- a/spacy/cli/_app.py +++ /dev/null @@ -1,31 +0,0 @@ -import typer -from typer.main import get_command - - -COMMAND = "python -m spacy" -NAME = "spacy" -HELP = """spaCy Command-line Interface - -DOCS: https://spacy.io/api/cli -""" -PROJECT_HELP = f"""Command-line interface for spaCy projects and working with -project templates. You'd typically start by cloning a project template to a local -directory and fetching its assets like datasets etc. See the project's -project.yml for the available commands. -""" - - -app = typer.Typer(name=NAME, help=HELP) -project_cli = typer.Typer(name="project", help=PROJECT_HELP, no_args_is_help=True) -app.add_typer(project_cli) - -# Wrappers for Typer's annotations. Initially created to set defaults and to -# keep the names short, but not needed at the moment. 
-Arg = typer.Argument -Opt = typer.Option - - -def setup_cli() -> None: - # Ensure that the help messages always display the correct prompt - command = get_command(app) - command(prog_name=COMMAND) diff --git a/spacy/cli/project/util.py b/spacy/cli/_util.py similarity index 57% rename from spacy/cli/project/util.py rename to spacy/cli/_util.py index 1111ddc2d..4289df856 100644 --- a/spacy/cli/project/util.py +++ b/spacy/cli/_util.py @@ -1,14 +1,76 @@ -from typing import Dict, Any, Union +from typing import Dict, Any, Union, List from pathlib import Path from wasabi import msg import srsly import hashlib +import typer +from typer.main import get_command -from ...schemas import ProjectConfigSchema, validate +from ..schemas import ProjectConfigSchema, validate PROJECT_FILE = "project.yml" PROJECT_LOCK = "project.lock" +COMMAND = "python -m spacy" +NAME = "spacy" +HELP = """spaCy Command-line Interface + +DOCS: https://spacy.io/api/cli +""" +PROJECT_HELP = f"""Command-line interface for spaCy projects and working with +project templates. You'd typically start by cloning a project template to a local +directory and fetching its assets like datasets etc. See the project's +{PROJECT_FILE} for the available commands. +""" + +# Wrappers for Typer's annotations. Initially created to set defaults and to +# keep the names short, but not needed at the moment. +Arg = typer.Argument +Opt = typer.Option + +app = typer.Typer(name=NAME, help=HELP) +project_cli = typer.Typer(name="project", help=PROJECT_HELP, no_args_is_help=True) +app.add_typer(project_cli) + + +def setup_cli() -> None: + # Ensure that the help messages always display the correct prompt + command = get_command(app) + command(prog_name=COMMAND) + + +def parse_config_overrides(args: List[str]) -> Dict[str, Any]: + """Generate a dictionary of config overrides based on the extra arguments + provided on the CLI, e.g. --training.batch_size to override + "training.batch_size". Arguments without a "." 
are considered invalid, + since the config only allows top-level sections to exist. + + args (List[str]): The extra arguments from the command line. + RETURNS (Dict[str, Any]): The parsed dict, keyed by nested config setting. + """ + result = {} + while args: + opt = args.pop(0) + err = f"Invalid config override '{opt}'" + if opt.startswith("--"): # new argument + opt = opt.replace("--", "") + if "." not in opt: + msg.fail(f"{err}: can't override top-level section", exits=1) + if not args or args[0].startswith("--"): # flag with no value + value = True + else: + value = args.pop(0) + # Just like we do in the config, we're calling json.loads on the + # values. But since they come from the CLI, it'd be unintuitive to + # explicitly mark strings with escaped quotes. So we're working + # around that here by falling back to a string if parsing fails. + try: + result[opt] = srsly.json_loads(value) + except ValueError: + result[opt] = str(value) + else: + msg.fail(f"{err}: options need to start with --", exits=1) + return result def load_project_config(path: Path) -> Dict[str, Any]: diff --git a/spacy/cli/convert.py b/spacy/cli/convert.py index c26b5ee75..3c04ca15c 100644 --- a/spacy/cli/convert.py +++ b/spacy/cli/convert.py @@ -6,7 +6,7 @@ import srsly import re import sys -from ._app import app, Arg, Opt +from ._util import app, Arg, Opt from ..gold import docs_to_json from ..tokens import DocBin from ..gold.converters import iob2docs, conll_ner2docs, json2docs, conllu2docs diff --git a/spacy/cli/debug_data.py b/spacy/cli/debug_data.py index 712bc7914..df3236511 100644 --- a/spacy/cli/debug_data.py +++ b/spacy/cli/debug_data.py @@ -5,7 +5,7 @@ import sys import srsly from wasabi import Printer, MESSAGES -from ._app import app, Arg, Opt +from ._util import app, Arg, Opt from ..gold import Corpus, Example from ..syntax import nonproj from ..language import Language diff --git a/spacy/cli/download.py b/spacy/cli/download.py index f192cb196..cdbd7514a 100644 --- 
a/spacy/cli/download.py +++ b/spacy/cli/download.py @@ -4,7 +4,7 @@ import sys from wasabi import msg import typer -from ._app import app, Arg, Opt +from ._util import app, Arg, Opt from .. import about from ..util import is_package, get_base_version, run_command diff --git a/spacy/cli/evaluate.py b/spacy/cli/evaluate.py index a5d4a3661..3467204b9 100644 --- a/spacy/cli/evaluate.py +++ b/spacy/cli/evaluate.py @@ -8,7 +8,7 @@ from thinc.api import require_gpu, fix_random_seed from ..gold import Corpus from ..tokens import Doc -from ._app import app, Arg, Opt +from ._util import app, Arg, Opt from ..scorer import Scorer from .. import util from .. import displacy diff --git a/spacy/cli/info.py b/spacy/cli/info.py index 9f1ec3855..98a1efeb8 100644 --- a/spacy/cli/info.py +++ b/spacy/cli/info.py @@ -4,7 +4,7 @@ from pathlib import Path from wasabi import Printer import srsly -from ._app import app, Arg, Opt +from ._util import app, Arg, Opt from .. import util from .. import about diff --git a/spacy/cli/init_model.py b/spacy/cli/init_model.py index 5cfde43e0..e8c17ae33 100644 --- a/spacy/cli/init_model.py +++ b/spacy/cli/init_model.py @@ -12,7 +12,7 @@ import srsly import warnings from wasabi import Printer -from ._app import app, Arg, Opt +from ._util import app, Arg, Opt from ..vectors import Vectors from ..errors import Errors, Warnings from ..language import Language diff --git a/spacy/cli/package.py b/spacy/cli/package.py index dbc485848..74f9b0c96 100644 --- a/spacy/cli/package.py +++ b/spacy/cli/package.py @@ -5,7 +5,7 @@ from wasabi import Printer, get_raw_input import srsly import sys -from ._app import app, Arg, Opt +from ._util import app, Arg, Opt from ..schemas import validate, ModelMetaSchema from .. import util from .. 
import about diff --git a/spacy/cli/pretrain.py b/spacy/cli/pretrain.py index 58e82028b..d1fa71fd7 100644 --- a/spacy/cli/pretrain.py +++ b/spacy/cli/pretrain.py @@ -12,7 +12,7 @@ from wasabi import msg import srsly from functools import partial -from ._app import app, Arg, Opt +from ._util import app, Arg, Opt from ..errors import Errors from ..ml.models.multi_task import build_cloze_multi_task_model from ..ml.models.multi_task import build_cloze_characters_multi_task_model diff --git a/spacy/cli/profile.py b/spacy/cli/profile.py index 3dc9f1027..2faf002d0 100644 --- a/spacy/cli/profile.py +++ b/spacy/cli/profile.py @@ -8,7 +8,7 @@ import sys import itertools from wasabi import msg, Printer -from ._app import app, Arg, Opt +from ._util import app, Arg, Opt from ..language import Language from ..util import load_model diff --git a/spacy/cli/project/assets.py b/spacy/cli/project/assets.py index 2270574ab..2b7dbaf66 100644 --- a/spacy/cli/project/assets.py +++ b/spacy/cli/project/assets.py @@ -7,8 +7,7 @@ import re import shutil from ...util import ensure_path, working_dir -from .._app import project_cli, Arg -from .util import PROJECT_FILE, load_project_config, get_checksum +from .._util import project_cli, Arg, PROJECT_FILE, load_project_config, get_checksum # TODO: find a solution for caches diff --git a/spacy/cli/project/clone.py b/spacy/cli/project/clone.py index 6ce2d309e..bb9ba74cb 100644 --- a/spacy/cli/project/clone.py +++ b/spacy/cli/project/clone.py @@ -7,8 +7,7 @@ import re from ... 
import about from ...util import ensure_path, run_command, make_tempdir -from .._app import project_cli, Arg, Opt, COMMAND -from .util import PROJECT_FILE +from .._util import project_cli, Arg, Opt, COMMAND, PROJECT_FILE @project_cli.command("clone") diff --git a/spacy/cli/project/dvc.py b/spacy/cli/project/dvc.py index c29618820..7386339d9 100644 --- a/spacy/cli/project/dvc.py +++ b/spacy/cli/project/dvc.py @@ -5,8 +5,8 @@ import subprocess from pathlib import Path from wasabi import msg -from .util import PROJECT_FILE, load_project_config, get_hash -from .._app import project_cli, Arg, Opt, NAME, COMMAND +from .._util import PROJECT_FILE, load_project_config, get_hash, project_cli +from .._util import Arg, Opt, NAME, COMMAND from ...util import working_dir, split_command, join_command, run_command diff --git a/spacy/cli/project/run.py b/spacy/cli/project/run.py index a8cc58c01..5c66095aa 100644 --- a/spacy/cli/project/run.py +++ b/spacy/cli/project/run.py @@ -5,9 +5,8 @@ import sys import srsly from ...util import working_dir, run_command, split_command, is_cwd, join_command -from .._app import project_cli, Arg, Opt, COMMAND -from .util import PROJECT_FILE, PROJECT_LOCK, load_project_config, get_hash -from .util import get_checksum +from .._util import PROJECT_FILE, PROJECT_LOCK, load_project_config, get_hash +from .._util import get_checksum, project_cli, Arg, Opt, COMMAND @project_cli.command("run") diff --git a/spacy/cli/train.py b/spacy/cli/train.py index 45520978b..32d1a456e 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -1,4 +1,4 @@ -from typing import Optional, Dict +from typing import Optional, Dict, Any from timeit import default_timer as timer import srsly import tqdm @@ -8,8 +8,9 @@ import thinc import thinc.schedules from thinc.api import use_pytorch_for_gpu_memory, require_gpu, fix_random_seed import random +import typer -from ._app import app, Arg, Opt +from ._util import app, Arg, Opt, parse_config_overrides from ..gold import Corpus, 
Example from ..lookups import Lookups from .. import util @@ -24,9 +25,12 @@ from ..ml import models # noqa: F401 registry = util.registry -@app.command("train") +@app.command( + "train", context_settings={"allow_extra_args": True, "ignore_unknown_options": True} +) def train_cli( # fmt: off + ctx: typer.Context, # This is only used to read additional arguments train_path: Path = Arg(..., help="Location of training data", exists=True), dev_path: Path = Arg(..., help="Location of development data", exists=True), config_path: Path = Arg(..., help="Path to config file", exists=True), @@ -36,17 +40,36 @@ def train_cli( # fmt: on ): """ - Train or update a spaCy model. Requires data to be formatted in spaCy's - JSON format. To convert data from other formats, use the `spacy convert` - command. + Train or update a spaCy model. Requires data in spaCy's binary format. To + convert data from other formats, use the `spacy convert` command. The + config file includes all settings and hyperparameters used during training. + To override settings in the config, e.g. settings that point to local + paths or that you want to experiment with, you can override them as + command line options. For instance, --training.batch_size 128 overrides + the value of "batch_size" in the block "[training]". The --code argument + lets you pass in a Python file that's imported before training. It can be + used to register custom functions and architectures that can then be + referenced in the config. 
""" util.set_env_log(verbose) - verify_cli_args(**locals()) - try: - util.import_file("python_code", code_path) - except Exception as e: - msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1) - train(config_path, {"train": train_path, "dev": dev_path}, output_path=output_path) + verify_cli_args( + train_path=train_path, + dev_path=dev_path, + config_path=config_path, + code_path=code_path, + ) + overrides = parse_config_overrides(ctx.args) + if code_path is not None: + try: + util.import_file("python_code", code_path) + except Exception as e: + msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1) + train( + config_path, + {"train": train_path, "dev": dev_path}, + output_path=output_path, + config_overrides=overrides, + ) def train( @@ -54,7 +77,7 @@ def train( data_paths: Dict[str, Path], raw_text: Optional[Path] = None, output_path: Optional[Path] = None, - weights_data: Optional[bytes] = None, + config_overrides: Dict[str, Any] = {}, ) -> None: msg.info(f"Loading config from: {config_path}") # Read the config first without creating objects, to get to the original nlp_config @@ -469,7 +492,6 @@ def verify_cli_args( config_path: Path, output_path: Optional[Path] = None, code_path: Optional[Path] = None, - verbose: bool = False, ): # Make sure all files and paths exists if they are needed if not config_path or not config_path.exists(): diff --git a/spacy/cli/validate.py b/spacy/cli/validate.py index 4271817f1..0580d34c5 100644 --- a/spacy/cli/validate.py +++ b/spacy/cli/validate.py @@ -4,7 +4,7 @@ import sys import requests from wasabi import msg, Printer -from ._app import app +from ._util import app from .. 
import about from ..util import get_package_version, get_installed_models, get_base_version from ..util import get_package_path, get_model_meta, is_compatible_version diff --git a/spacy/tests/test_projects.py b/spacy/tests/test_projects.py index c3477f463..65ac5739d 100644 --- a/spacy/tests/test_projects.py +++ b/spacy/tests/test_projects.py @@ -1,5 +1,5 @@ import pytest -from spacy.cli.project.util import validate_project_commands +from spacy.cli._util import validate_project_commands from spacy.schemas import ProjectConfigSchema, validate