Merge pull request #6051 from svlandeg/feature/cli-config

This commit is contained in:
Ines Montani 2020-09-12 17:12:35 +02:00 committed by GitHub
commit b41be87213
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 69 additions and 10 deletions

View File

@ -385,3 +385,23 @@ def _from_http_to_git(repo: str) -> str:
repo = repo[:-1]
repo = f"{repo}.git"
return repo
def string_to_list(value, intify=False):
"""Parse a comma-separated string to a list"""
if not value:
return []
if value.startswith("[") and value.endswith("]"):
value = value[1:-1]
result = []
for p in value.split(","):
p = p.strip()
if p.startswith("'") and p.endswith("'"):
p = p[1:-1]
if p.startswith('"') and p.endswith('"'):
p = p[1:-1]
p = p.strip()
if intify:
p = int(p)
result.append(p)
return result

View File

@ -5,7 +5,7 @@ from thinc.api import require_gpu, fix_random_seed, set_dropout_rate, Adam
from thinc.api import Model, data_validation
import typer
from ._util import Arg, Opt, debug_cli, show_validation_error, parse_config_overrides
from ._util import Arg, Opt, debug_cli, show_validation_error, parse_config_overrides, string_to_list
from .. import util
@ -38,12 +38,13 @@ def debug_model_cli(
require_gpu(use_gpu)
else:
msg.info("Using CPU")
layers = string_to_list(layers, intify=True)
print_settings = {
"dimensions": dimensions,
"parameters": parameters,
"gradients": gradients,
"attributes": attributes,
"layers": [int(x.strip()) for x in layers.split(",")] if layers else [],
"layers": layers,
"print_before_training": P0,
"print_after_init": P1,
"print_after_training": P2,

View File

@ -9,7 +9,7 @@ import re
from .. import util
from ..language import DEFAULT_CONFIG_PRETRAIN_PATH
from ..schemas import RecommendationSchema
from ._util import init_cli, Arg, Opt, show_validation_error, COMMAND
from ._util import init_cli, Arg, Opt, show_validation_error, COMMAND, string_to_list
ROOT = Path(__file__).parent / "templates"
@ -42,7 +42,7 @@ def init_config_cli(
"""
if isinstance(optimize, Optimizations): # instance of enum from the CLI
optimize = optimize.value
pipeline = [p.strip() for p in pipeline.split(",")]
pipeline = string_to_list(pipeline)
init_config(output_file, lang=lang, pipeline=pipeline, optimize=optimize, cpu=cpu)

View File

@ -9,6 +9,7 @@ from spacy.cli.pretrain import make_docs
from spacy.cli.init_config import init_config, RECOMMENDATIONS
from spacy.cli._util import validate_project_commands, parse_config_overrides
from spacy.cli._util import load_project_config, substitute_project_variables
from spacy.cli._util import string_to_list
from thinc.config import ConfigValidationError
import srsly
@ -372,17 +373,13 @@ def test_parse_config_overrides(args, expected):
assert parse_config_overrides(args) == expected
@pytest.mark.parametrize(
"args", [["--foo"], ["--x.foo", "bar", "--baz"]],
)
@pytest.mark.parametrize("args", [["--foo"], ["--x.foo", "bar", "--baz"]])
def test_parse_config_overrides_invalid(args):
with pytest.raises(NoSuchOption):
parse_config_overrides(args)
@pytest.mark.parametrize(
"args", [["--x.foo", "bar", "baz"], ["x.foo"]],
)
@pytest.mark.parametrize("args", [["--x.foo", "bar", "baz"], ["x.foo"]])
def test_parse_config_overrides_invalid_2(args):
with pytest.raises(SystemExit):
parse_config_overrides(args)
@ -401,3 +398,44 @@ def test_init_config(lang, pipeline, optimize):
def test_model_recommendations():
for lang, data in RECOMMENDATIONS.items():
assert RecommendationSchema(**data)
@pytest.mark.parametrize(
"value",
[
# fmt: off
"parser,textcat,tagger",
" parser, textcat ,tagger ",
'parser,textcat,tagger',
' parser, textcat ,tagger ',
' "parser"," textcat " ,"tagger "',
" 'parser',' textcat ' ,'tagger '",
'[parser,textcat,tagger]',
'["parser","textcat","tagger"]',
'[" parser" ,"textcat ", " tagger " ]',
"[parser,textcat,tagger]",
"[ parser, textcat , tagger]",
"['parser','textcat','tagger']",
"[' parser' , 'textcat', ' tagger ' ]",
# fmt: on
],
)
def test_string_to_list(value):
assert string_to_list(value, intify=False) == ["parser", "textcat", "tagger"]
@pytest.mark.parametrize(
"value",
[
# fmt: off
"1,2,3",
'[1,2,3]',
'["1","2","3"]',
'[" 1" ,"2 ", " 3 " ]',
"[' 1' , '2', ' 3 ' ]",
# fmt: on
],
)
def test_string_to_list_intify(value):
assert string_to_list(value, intify=False) == ["1", "2", "3"]
assert string_to_list(value, intify=True) == [1, 2, 3]