diff --git a/spacy/cli/init_config.py b/spacy/cli/init_config.py index ff11f97f6..81ac6c9fc 100644 --- a/spacy/cli/init_config.py +++ b/spacy/cli/init_config.py @@ -45,14 +45,16 @@ def init_config_cli( if isinstance(optimize, Optimizations): # instance of enum from the CLI optimize = optimize.value pipeline = string_to_list(pipeline) - init_config( - output_file, + is_stdout = str(output_file) == "-" + config = init_config( lang=lang, pipeline=pipeline, optimize=optimize, cpu=cpu, pretraining=pretraining, + silent=is_stdout, ) + save_config(config, output_file, is_stdout=is_stdout) @init_cli.command("fill-config") @@ -118,16 +120,15 @@ def fill_config( def init_config( - output_file: Path, *, lang: str, pipeline: List[str], optimize: str, cpu: bool, pretraining: bool = False, -) -> None: - is_stdout = str(output_file) == "-" - msg = Printer(no_print=is_stdout) + silent: bool = True, +) -> Config: + msg = Printer(no_print=silent) with TEMPLATE_PATH.open("r") as f: template = Template(f.read()) # Filter out duplicates since tok2vec and transformer are added by template @@ -173,7 +174,7 @@ def init_config( pretrain_config = util.load_config(DEFAULT_CONFIG_PRETRAIN_PATH) config = pretrain_config.merge(config) msg.good("Auto-filled config with all values") - save_config(config, output_file, is_stdout=is_stdout) + return config def save_config( diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py index 62584d0ce..06c7a3a90 100644 --- a/spacy/tests/test_cli.py +++ b/spacy/tests/test_cli.py @@ -8,7 +8,7 @@ from spacy.cli.init_config import init_config, RECOMMENDATIONS from spacy.cli._util import validate_project_commands, parse_config_overrides from spacy.cli._util import load_project_config, substitute_project_variables from spacy.cli._util import string_to_list -from thinc.api import ConfigValidationError +from thinc.api import ConfigValidationError, Config import srsly import os @@ -368,7 +368,8 @@ def test_parse_cli_overrides(): @pytest.mark.parametrize("optimize", ["efficiency", "accuracy"]) def test_init_config(lang, pipeline, optimize): # TODO: add more tests and also check for GPU with transformers - init_config("-", lang=lang, pipeline=pipeline, optimize=optimize, cpu=True) + config = init_config(lang=lang, pipeline=pipeline, optimize=optimize, cpu=True) + assert isinstance(config, Config) def test_model_recommendations():