mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 10:46:29 +03:00
returning config in init_config
This commit is contained in:
parent
8921364579
commit
8f8a7f1733
|
@ -45,7 +45,7 @@ def init_config_cli(
|
||||||
if isinstance(optimize, Optimizations): # instance of enum from the CLI
|
if isinstance(optimize, Optimizations): # instance of enum from the CLI
|
||||||
optimize = optimize.value
|
optimize = optimize.value
|
||||||
pipeline = string_to_list(pipeline)
|
pipeline = string_to_list(pipeline)
|
||||||
init_config(
|
config = init_config(
|
||||||
output_file,
|
output_file,
|
||||||
lang=lang,
|
lang=lang,
|
||||||
pipeline=pipeline,
|
pipeline=pipeline,
|
||||||
|
@ -53,6 +53,8 @@ def init_config_cli(
|
||||||
cpu=cpu,
|
cpu=cpu,
|
||||||
pretraining=pretraining,
|
pretraining=pretraining,
|
||||||
)
|
)
|
||||||
|
is_stdout = str(output_file) == "-"
|
||||||
|
save_config(config, output_file, is_stdout=is_stdout)
|
||||||
|
|
||||||
|
|
||||||
@init_cli.command("fill-config")
|
@init_cli.command("fill-config")
|
||||||
|
@ -125,7 +127,7 @@ def init_config(
|
||||||
optimize: str,
|
optimize: str,
|
||||||
cpu: bool,
|
cpu: bool,
|
||||||
pretraining: bool = False,
|
pretraining: bool = False,
|
||||||
) -> None:
|
) -> Config:
|
||||||
is_stdout = str(output_file) == "-"
|
is_stdout = str(output_file) == "-"
|
||||||
msg = Printer(no_print=is_stdout)
|
msg = Printer(no_print=is_stdout)
|
||||||
with TEMPLATE_PATH.open("r") as f:
|
with TEMPLATE_PATH.open("r") as f:
|
||||||
|
@ -173,7 +175,7 @@ def init_config(
|
||||||
pretrain_config = util.load_config(DEFAULT_CONFIG_PRETRAIN_PATH)
|
pretrain_config = util.load_config(DEFAULT_CONFIG_PRETRAIN_PATH)
|
||||||
config = pretrain_config.merge(config)
|
config = pretrain_config.merge(config)
|
||||||
msg.good("Auto-filled config with all values")
|
msg.good("Auto-filled config with all values")
|
||||||
save_config(config, output_file, is_stdout=is_stdout)
|
return config
|
||||||
|
|
||||||
|
|
||||||
def save_config(
|
def save_config(
|
||||||
|
|
|
@ -8,7 +8,7 @@ from spacy.cli.init_config import init_config, RECOMMENDATIONS
|
||||||
from spacy.cli._util import validate_project_commands, parse_config_overrides
|
from spacy.cli._util import validate_project_commands, parse_config_overrides
|
||||||
from spacy.cli._util import load_project_config, substitute_project_variables
|
from spacy.cli._util import load_project_config, substitute_project_variables
|
||||||
from spacy.cli._util import string_to_list
|
from spacy.cli._util import string_to_list
|
||||||
from thinc.api import ConfigValidationError
|
from thinc.api import ConfigValidationError, Config
|
||||||
import srsly
|
import srsly
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
@ -368,7 +368,8 @@ def test_parse_cli_overrides():
|
||||||
@pytest.mark.parametrize("optimize", ["efficiency", "accuracy"])
|
@pytest.mark.parametrize("optimize", ["efficiency", "accuracy"])
|
||||||
def test_init_config(lang, pipeline, optimize):
|
def test_init_config(lang, pipeline, optimize):
|
||||||
# TODO: add more tests and also check for GPU with transformers
|
# TODO: add more tests and also check for GPU with transformers
|
||||||
init_config("-", lang=lang, pipeline=pipeline, optimize=optimize, cpu=True)
|
config = init_config("-", lang=lang, pipeline=pipeline, optimize=optimize, cpu=True)
|
||||||
|
assert isinstance(config, Config)
|
||||||
|
|
||||||
|
|
||||||
def test_model_recommendations():
|
def test_model_recommendations():
|
||||||
|
|
Loading…
Reference in New Issue
Block a user