Support config overrides via environment variables

This commit is contained in:
Ines Montani 2020-09-21 11:25:10 +02:00
parent 1114219ae3
commit 5497acf49a
2 changed files with 59 additions and 15 deletions

View File

@ -11,9 +11,10 @@ from typer.main import get_command
from contextlib import contextmanager from contextlib import contextmanager
from thinc.config import Config, ConfigValidationError from thinc.config import Config, ConfigValidationError
from configparser import InterpolationError from configparser import InterpolationError
import os
from ..schemas import ProjectConfigSchema, validate from ..schemas import ProjectConfigSchema, validate
from ..util import import_file, run_command, make_tempdir, registry from ..util import import_file, run_command, make_tempdir, registry, logger
if TYPE_CHECKING: if TYPE_CHECKING:
from pathy import Pathy # noqa: F401 from pathy import Pathy # noqa: F401
@ -61,16 +62,38 @@ def setup_cli() -> None:
command(prog_name=COMMAND) command(prog_name=COMMAND)
def parse_config_overrides(args: List[str]) -> Dict[str, Any]: def parse_config_env_overrides(
*, prefix: str = "SPACY_CONFIG_", dot: str = "__"
) -> Dict[str, Any]:
"""Generate a dictionary of config overrides based on environment variables,
e.g. SPACY_CONFIG_TRAINING__BATCH_SIZE=123 overrides the training.batch_size
setting.
prefix (str): The env variable prefix for config overrides.
dot (str): String used to represent the "dot", e.g. in training.batch_size.
RETURNS (Dict[str, Any]): The parsed dict, keyed by nested config setting.
"""
result = {}
for env_key, value in os.environ.items():
if env_key.startswith(prefix):
opt = env_key[len(prefix) :].lower().replace(dot, ".")
if "." in opt:
result[opt] = try_json_loads(value)
return result
def parse_config_overrides(args: List[str], env_vars: bool = True) -> Dict[str, Any]:
"""Generate a dictionary of config overrides based on the extra arguments """Generate a dictionary of config overrides based on the extra arguments
provided on the CLI, e.g. --training.batch_size to override provided on the CLI, e.g. --training.batch_size to override
"training.batch_size". Arguments without a "." are considered invalid, "training.batch_size". Arguments without a "." are considered invalid,
since the config only allows top-level sections to exist. since the config only allows top-level sections to exist.
args (List[str]): The extra arguments from the command line. args (List[str]): The extra arguments from the command line.
env_vars (bool): Include environment variables.
RETURNS (Dict[str, Any]): The parsed dict, keyed by nested config setting. RETURNS (Dict[str, Any]): The parsed dict, keyed by nested config setting.
""" """
result = {} env_overrides = parse_config_env_overrides() if env_vars else {}
cli_overrides = {}
while args: while args:
opt = args.pop(0) opt = args.pop(0)
err = f"Invalid CLI argument '{opt}'" err = f"Invalid CLI argument '{opt}'"
@ -87,18 +110,27 @@ def parse_config_overrides(args: List[str]) -> Dict[str, Any]:
value = "true" value = "true"
else: else:
value = args.pop(0) value = args.pop(0)
# Just like we do in the config, we're calling json.loads on the if opt not in env_overrides:
# values. But since they come from the CLI, it'd be unintuitive to cli_overrides[opt] = try_json_loads(value)
# explicitly mark strings with escaped quotes. So we're working
# around that here by falling back to a string if parsing fails.
# TODO: improve logic to handle simple types like list of strings?
try:
result[opt] = srsly.json_loads(value)
except ValueError:
result[opt] = str(value)
else: else:
msg.fail(f"{err}: override option should start with --", exits=1) msg.fail(f"{err}: override option should start with --", exits=1)
return result if cli_overrides:
logger.debug(f"Config overrides from CLI: {list(cli_overrides)}")
if env_overrides:
logger.debug(f"Config overrides from env variables: {list(env_overrides)}")
return {**cli_overrides, **env_overrides}
def try_json_loads(value: Any) -> Any:
    """Parse an override value as JSON, falling back to a plain string.

    Just like values in the config, overrides are run through json.loads.
    Since they're typed by the user, it'd be unintuitive to require strings
    to be explicitly marked with escaped quotes, so anything that fails to
    parse is returned as a string instead.

    value (Any): The raw override value.
    RETURNS (Any): The parsed JSON value, or str(value) if parsing fails.
    """
    # TODO: improve logic to handle simple types like list of strings?
    try:
        parsed = srsly.json_loads(value)
    except ValueError:
        return str(value)
    return parsed
def load_project_config(path: Path, interpolate: bool = True) -> Dict[str, Any]: def load_project_config(path: Path, interpolate: bool = True) -> Dict[str, Any]:

View File

@ -1,15 +1,15 @@
import pytest import pytest
from click import NoSuchOption from click import NoSuchOption
from spacy.training import docs_to_json, biluo_tags_from_offsets from spacy.training import docs_to_json, biluo_tags_from_offsets
from spacy.training.converters import iob2docs, conll_ner2docs, conllu2docs from spacy.training.converters import iob2docs, conll_ner2docs, conllu2docs
from spacy.schemas import ProjectConfigSchema, RecommendationSchema, validate from spacy.schemas import ProjectConfigSchema, RecommendationSchema, validate
from spacy.cli.init_config import init_config, RECOMMENDATIONS from spacy.cli.init_config import init_config, RECOMMENDATIONS
from spacy.cli._util import validate_project_commands, parse_config_overrides from spacy.cli._util import validate_project_commands, parse_config_overrides
from spacy.cli._util import load_project_config, substitute_project_variables from spacy.cli._util import load_project_config, substitute_project_variables
from spacy.cli._util import string_to_list from spacy.cli._util import string_to_list, parse_config_env_overrides
from thinc.config import ConfigValidationError from thinc.config import ConfigValidationError
import srsly import srsly
import os
from .util import make_tempdir from .util import make_tempdir
@ -341,6 +341,18 @@ def test_parse_config_overrides_invalid_2(args):
parse_config_overrides(args) parse_config_overrides(args)
def test_parse_cli_overrides():
    """Check that config overrides are picked up from environment variables.

    Two well-formed variables (prefix + SECTION__KEY) should be parsed into
    dotted, lowercased config keys with JSON-decoded values. A variable that
    consists of the bare prefix only should be ignored.
    """
    prefix = "SPACY_CONFIG_"
    dot = "__"
    env_vars = {
        f"{prefix}TRAINING{dot}BATCH_SIZE": "123",
        f"{prefix}FOO{dot}BAR{dot}BAZ": "hello",
        # Nothing after the prefix, so there's no dotted key to override
        prefix: "bad",
    }
    # Remember the previous state so the variables don't leak into other
    # tests that parse overrides from the environment.
    previous = {key: os.environ.get(key) for key in env_vars}
    os.environ.update(env_vars)
    try:
        result = parse_config_env_overrides(prefix=prefix, dot=dot)
    finally:
        for key, old_value in previous.items():
            if old_value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = old_value
    assert len(result) == 2
    assert result["training.batch_size"] == 123
    assert result["foo.bar.baz"] == "hello"
@pytest.mark.parametrize("lang", ["en", "nl"]) @pytest.mark.parametrize("lang", ["en", "nl"])
@pytest.mark.parametrize( @pytest.mark.parametrize(
"pipeline", [["tagger", "parser", "ner"], [], ["ner", "textcat", "sentencizer"]] "pipeline", [["tagger", "parser", "ner"], [], ["ner", "textcat", "sentencizer"]]