mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
Sync overrides with CLI overrides
This commit is contained in:
parent
5497acf49a
commit
758ead8a47
|
@ -7,6 +7,7 @@ import srsly
|
|||
import hashlib
|
||||
import typer
|
||||
from click import NoSuchOption
|
||||
from click.parser import split_arg_string
|
||||
from typer.main import get_command
|
||||
from contextlib import contextmanager
|
||||
from thinc.config import Config, ConfigValidationError
|
||||
|
@ -38,6 +39,7 @@ commands to check and validate your config files, training and evaluation data,
|
|||
and custom model implementations.
|
||||
"""
|
||||
INIT_HELP = """Commands for initializing configs and pipeline packages."""
|
||||
OVERRIDES_ENV_VAR = "SPACY_CONFIG_OVERRIDES"
|
||||
|
||||
# Wrappers for Typer's annotations. Initially created to set defaults and to
|
||||
# keep the names short, but not needed at the moment.
|
||||
|
@ -62,46 +64,41 @@ def setup_cli() -> None:
|
|||
command(prog_name=COMMAND)
|
||||
|
||||
|
||||
def parse_config_env_overrides(
|
||||
*, prefix: str = "SPACY_CONFIG_", dot: str = "__"
|
||||
def parse_config_overrides(
|
||||
args: List[str], env_var: Optional[str] = OVERRIDES_ENV_VAR
|
||||
) -> Dict[str, Any]:
|
||||
"""Generate a dictionary of config overrides based on environment variables,
|
||||
e.g. SPACY_CONFIG_TRAINING__BATCH_SIZE=123 overrides the training.batch_size
|
||||
setting.
|
||||
|
||||
prefix (str): The env variable prefix for config overrides.
|
||||
dot (str): String used to represent the "dot", e.g. in training.batch_size.
|
||||
RETURNS (Dict[str, Any]): The parsed dict, keyed by nested config setting.
|
||||
"""
|
||||
result = {}
|
||||
for env_key, value in os.environ.items():
|
||||
if env_key.startswith(prefix):
|
||||
opt = env_key[len(prefix) :].lower().replace(dot, ".")
|
||||
if "." in opt:
|
||||
result[opt] = try_json_loads(value)
|
||||
return result
|
||||
|
||||
|
||||
def parse_config_overrides(args: List[str], env_vars: bool = True) -> Dict[str, Any]:
|
||||
"""Generate a dictionary of config overrides based on the extra arguments
|
||||
provided on the CLI, e.g. --training.batch_size to override
|
||||
"training.batch_size". Arguments without a "." are considered invalid,
|
||||
since the config only allows top-level sections to exist.
|
||||
|
||||
args (List[str]): The extra arguments from the command line.
|
||||
env_vars (bool): Include environment variables.
|
||||
env_vars (Optional[str]): Optional environment variable to read from.
|
||||
RETURNS (Dict[str, Any]): The parsed dict, keyed by nested config setting.
|
||||
"""
|
||||
env_overrides = parse_config_env_overrides() if env_vars else {}
|
||||
cli_overrides = {}
|
||||
env_string = os.environ.get(env_var, "") if env_var else ""
|
||||
env_overrides = _parse_overrides(split_arg_string(env_string))
|
||||
cli_overrides = _parse_overrides(args, is_cli=True)
|
||||
if cli_overrides:
|
||||
keys = [k for k in cli_overrides if k not in env_overrides]
|
||||
logger.debug(f"Config overrides from CLI: {keys}")
|
||||
if env_overrides:
|
||||
logger.debug(f"Config overrides from env variables: {list(env_overrides)}")
|
||||
return {**cli_overrides, **env_overrides}
|
||||
|
||||
|
||||
def _parse_overrides(args: List[str], is_cli: bool = False) -> Dict[str, Any]:
|
||||
result = {}
|
||||
while args:
|
||||
opt = args.pop(0)
|
||||
err = f"Invalid CLI argument '{opt}'"
|
||||
err = f"Invalid config override '{opt}'"
|
||||
if opt.startswith("--"): # new argument
|
||||
orig_opt = opt
|
||||
opt = opt.replace("--", "")
|
||||
if "." not in opt:
|
||||
raise NoSuchOption(orig_opt)
|
||||
if is_cli:
|
||||
raise NoSuchOption(orig_opt)
|
||||
else:
|
||||
msg.fail(f"{err}: can't override top-level sections", exits=1)
|
||||
if "=" in opt: # we have --opt=value
|
||||
opt, value = opt.split("=", 1)
|
||||
opt = opt.replace("-", "_")
|
||||
|
@ -110,27 +107,18 @@ def parse_config_overrides(args: List[str], env_vars: bool = True) -> Dict[str,
|
|||
value = "true"
|
||||
else:
|
||||
value = args.pop(0)
|
||||
if opt not in env_overrides:
|
||||
cli_overrides[opt] = try_json_loads(value)
|
||||
# Just like we do in the config, we're calling json.loads on the
|
||||
# values. But since they come from the CLI, it'd be unintuitive to
|
||||
# explicitly mark strings with escaped quotes. So we're working
|
||||
# around that here by falling back to a string if parsing fails.
|
||||
# TODO: improve logic to handle simple types like list of strings?
|
||||
try:
|
||||
result[opt] = srsly.json_loads(value)
|
||||
except ValueError:
|
||||
result[opt] = str(value)
|
||||
else:
|
||||
msg.fail(f"{err}: override option should start with --", exits=1)
|
||||
if cli_overrides:
|
||||
logger.debug(f"Config overrides from CLI: {list(cli_overrides)}")
|
||||
if env_overrides:
|
||||
logger.debug(f"Config overrides from env variables: {list(env_overrides)}")
|
||||
return {**cli_overrides, **env_overrides}
|
||||
|
||||
|
||||
def try_json_loads(value: Any) -> Any:
|
||||
# Just like we do in the config, we're calling json.loads on the
|
||||
# values. But since they come from the CLI, it'd be unintuitive to
|
||||
# explicitly mark strings with escaped quotes. So we're working
|
||||
# around that here by falling back to a string if parsing fails.
|
||||
# TODO: improve logic to handle simple types like list of strings?
|
||||
try:
|
||||
return srsly.json_loads(value)
|
||||
except ValueError:
|
||||
return str(value)
|
||||
msg.fail(f"{err}: name should start with --", exits=1)
|
||||
return result
|
||||
|
||||
|
||||
def load_project_config(path: Path, interpolate: bool = True) -> Dict[str, Any]:
|
||||
|
|
|
@ -6,7 +6,7 @@ from spacy.schemas import ProjectConfigSchema, RecommendationSchema, validate
|
|||
from spacy.cli.init_config import init_config, RECOMMENDATIONS
|
||||
from spacy.cli._util import validate_project_commands, parse_config_overrides
|
||||
from spacy.cli._util import load_project_config, substitute_project_variables
|
||||
from spacy.cli._util import string_to_list, parse_config_env_overrides
|
||||
from spacy.cli._util import string_to_list, OVERRIDES_ENV_VAR
|
||||
from thinc.config import ConfigValidationError
|
||||
import srsly
|
||||
import os
|
||||
|
@ -342,15 +342,21 @@ def test_parse_config_overrides_invalid_2(args):
|
|||
|
||||
|
||||
def test_parse_cli_overrides():
|
||||
prefix = "SPACY_CONFIG_"
|
||||
dot = "__"
|
||||
os.environ[f"{prefix}TRAINING{dot}BATCH_SIZE"] = "123"
|
||||
os.environ[f"{prefix}FOO{dot}BAR{dot}BAZ"] = "hello"
|
||||
os.environ[prefix] = "bad"
|
||||
result = parse_config_env_overrides(prefix=prefix, dot=dot)
|
||||
assert len(result) == 2
|
||||
assert result["training.batch_size"] == 123
|
||||
assert result["foo.bar.baz"] == "hello"
|
||||
os.environ[OVERRIDES_ENV_VAR] = "--x.foo bar --x.bar=12 --x.baz false --y.foo=hello"
|
||||
result = parse_config_overrides([])
|
||||
assert len(result) == 4
|
||||
assert result["x.foo"] == "bar"
|
||||
assert result["x.bar"] == 12
|
||||
assert result["x.baz"] is False
|
||||
assert result["y.foo"] == "hello"
|
||||
os.environ[OVERRIDES_ENV_VAR] = "--x"
|
||||
assert parse_config_overrides([], env_var=None) == {}
|
||||
with pytest.raises(SystemExit):
|
||||
parse_config_overrides([])
|
||||
os.environ[OVERRIDES_ENV_VAR] = "hello world"
|
||||
with pytest.raises(SystemExit):
|
||||
parse_config_overrides([])
|
||||
del os.environ[OVERRIDES_ENV_VAR]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("lang", ["en", "nl"])
|
||||
|
|
Loading…
Reference in New Issue
Block a user