Support config overrides via environment variables

This commit is contained in:
Ines Montani 2020-09-21 11:25:10 +02:00
parent 1114219ae3
commit 5497acf49a
2 changed files with 59 additions and 15 deletions

View File

@ -11,9 +11,10 @@ from typer.main import get_command
from contextlib import contextmanager
from thinc.config import Config, ConfigValidationError
from configparser import InterpolationError
import os
from ..schemas import ProjectConfigSchema, validate
from ..util import import_file, run_command, make_tempdir, registry
from ..util import import_file, run_command, make_tempdir, registry, logger
if TYPE_CHECKING:
from pathy import Pathy # noqa: F401
@ -61,16 +62,38 @@ def setup_cli() -> None:
command(prog_name=COMMAND)
def parse_config_overrides(args: List[str]) -> Dict[str, Any]:
def parse_config_env_overrides(
*, prefix: str = "SPACY_CONFIG_", dot: str = "__"
) -> Dict[str, Any]:
"""Generate a dictionary of config overrides based on environment variables,
e.g. SPACY_CONFIG_TRAINING__BATCH_SIZE=123 overrides the training.batch_size
setting.
prefix (str): The env variable prefix for config overrides.
dot (str): String used to represent the "dot", e.g. in training.batch_size.
RETURNS (Dict[str, Any]): The parsed dict, keyed by nested config setting.
"""
result = {}
for env_key, value in os.environ.items():
if env_key.startswith(prefix):
opt = env_key[len(prefix) :].lower().replace(dot, ".")
if "." in opt:
result[opt] = try_json_loads(value)
return result
def parse_config_overrides(args: List[str], env_vars: bool = True) -> Dict[str, Any]:
"""Generate a dictionary of config overrides based on the extra arguments
provided on the CLI, e.g. --training.batch_size to override
"training.batch_size". Arguments without a "." are considered invalid,
since the config only allows top-level sections to exist.
args (List[str]): The extra arguments from the command line.
env_vars (bool): Include environment variables.
RETURNS (Dict[str, Any]): The parsed dict, keyed by nested config setting.
"""
result = {}
env_overrides = parse_config_env_overrides() if env_vars else {}
cli_overrides = {}
while args:
opt = args.pop(0)
err = f"Invalid CLI argument '{opt}'"
@ -87,18 +110,27 @@ def parse_config_overrides(args: List[str]) -> Dict[str, Any]:
value = "true"
else:
value = args.pop(0)
if opt not in env_overrides:
cli_overrides[opt] = try_json_loads(value)
else:
msg.fail(f"{err}: override option should start with --", exits=1)
if cli_overrides:
logger.debug(f"Config overrides from CLI: {list(cli_overrides)}")
if env_overrides:
logger.debug(f"Config overrides from env variables: {list(env_overrides)}")
return {**cli_overrides, **env_overrides}
def try_json_loads(value: Any) -> Any:
# Just like we do in the config, we're calling json.loads on the
# values. But since they come from the CLI, it'd be unintuitive to
# explicitly mark strings with escaped quotes. So we're working
# around that here by falling back to a string if parsing fails.
# TODO: improve logic to handle simple types like list of strings?
try:
result[opt] = srsly.json_loads(value)
return srsly.json_loads(value)
except ValueError:
result[opt] = str(value)
else:
msg.fail(f"{err}: override option should start with --", exits=1)
return result
return str(value)
def load_project_config(path: Path, interpolate: bool = True) -> Dict[str, Any]:

View File

@ -1,15 +1,15 @@
import pytest
from click import NoSuchOption
from spacy.training import docs_to_json, biluo_tags_from_offsets
from spacy.training.converters import iob2docs, conll_ner2docs, conllu2docs
from spacy.schemas import ProjectConfigSchema, RecommendationSchema, validate
from spacy.cli.init_config import init_config, RECOMMENDATIONS
from spacy.cli._util import validate_project_commands, parse_config_overrides
from spacy.cli._util import load_project_config, substitute_project_variables
from spacy.cli._util import string_to_list
from spacy.cli._util import string_to_list, parse_config_env_overrides
from thinc.config import ConfigValidationError
import srsly
import os
from .util import make_tempdir
@ -341,6 +341,18 @@ def test_parse_config_overrides_invalid_2(args):
parse_config_overrides(args)
def test_parse_cli_overrides():
prefix = "SPACY_CONFIG_"
dot = "__"
os.environ[f"{prefix}TRAINING{dot}BATCH_SIZE"] = "123"
os.environ[f"{prefix}FOO{dot}BAR{dot}BAZ"] = "hello"
os.environ[prefix] = "bad"
result = parse_config_env_overrides(prefix=prefix, dot=dot)
assert len(result) == 2
assert result["training.batch_size"] == 123
assert result["foo.bar.baz"] == "hello"
@pytest.mark.parametrize("lang", ["en", "nl"])
@pytest.mark.parametrize(
"pipeline", [["tagger", "parser", "ner"], [], ["ner", "textcat", "sentencizer"]]