def string_to_list(value, intify=False):
    """Parse a comma-separated string into a list of items.

    Accepts the loose formats users tend to type on the command line: plain
    comma-separated values ("a,b,c"), values wrapped in square brackets
    ("[a,b,c]"), and items wrapped in single or double quotes, with optional
    whitespace around the whole value, the brackets, the commas and the
    quoted content.

    value (str): The raw string to parse. Falsy values (None, "") as well as
        whitespace-only strings and an empty bracket pair ("[]") produce an
        empty list.
    intify (bool): Whether to convert every item to an int.
    RETURNS (list): The parsed items, as strings or as ints if intify=True.
    RAISES (ValueError): If intify=True and an item is not a valid integer.
    """
    if not value:
        return []
    # Strip first so that surrounding whitespace can't hide the brackets.
    value = value.strip()
    if value.startswith("[") and value.endswith("]"):
        value = value[1:-1]
    if not value.strip():
        # Nothing left to split: "".split(",") would yield [""], which is a
        # bogus empty item (and makes int() fail when intify=True).
        return []
    result = []
    for part in value.split(","):
        part = part.strip()
        # Remove one layer of matching quotes, single quotes first.
        if part.startswith("'") and part.endswith("'"):
            part = part[1:-1]
        if part.startswith('"') and part.endswith('"'):
            part = part[1:-1]
        # Quoted content may carry its own padding, e.g. "' tagger '".
        part = part.strip()
        result.append(int(part) if intify else part)
    return result
import util from ..language import DEFAULT_CONFIG_PRETRAIN_PATH from ..schemas import RecommendationSchema -from ._util import init_cli, Arg, Opt, show_validation_error, COMMAND +from ._util import init_cli, Arg, Opt, show_validation_error, COMMAND, string_to_list ROOT = Path(__file__).parent / "templates" @@ -42,7 +42,7 @@ def init_config_cli( """ if isinstance(optimize, Optimizations): # instance of enum from the CLI optimize = optimize.value - pipeline = [p.strip() for p in pipeline.split(",")] + pipeline = string_to_list(pipeline) init_config(output_file, lang=lang, pipeline=pipeline, optimize=optimize, cpu=cpu) diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py index e8c83cbad..0df707dc0 100644 --- a/spacy/tests/test_cli.py +++ b/spacy/tests/test_cli.py @@ -9,6 +9,7 @@ from spacy.cli.pretrain import make_docs from spacy.cli.init_config import init_config, RECOMMENDATIONS from spacy.cli._util import validate_project_commands, parse_config_overrides from spacy.cli._util import load_project_config, substitute_project_variables +from spacy.cli._util import string_to_list from thinc.config import ConfigValidationError import srsly @@ -372,17 +373,13 @@ def test_parse_config_overrides(args, expected): assert parse_config_overrides(args) == expected -@pytest.mark.parametrize( - "args", [["--foo"], ["--x.foo", "bar", "--baz"]], -) +@pytest.mark.parametrize("args", [["--foo"], ["--x.foo", "bar", "--baz"]]) def test_parse_config_overrides_invalid(args): with pytest.raises(NoSuchOption): parse_config_overrides(args) -@pytest.mark.parametrize( - "args", [["--x.foo", "bar", "baz"], ["x.foo"]], -) +@pytest.mark.parametrize("args", [["--x.foo", "bar", "baz"], ["x.foo"]]) def test_parse_config_overrides_invalid_2(args): with pytest.raises(SystemExit): parse_config_overrides(args) @@ -401,3 +398,44 @@ def test_init_config(lang, pipeline, optimize): def test_model_recommendations(): for lang, data in RECOMMENDATIONS.items(): assert RecommendationSchema(**data) + + 
# Different spellings of the same three pipeline names: with and without
# brackets, single or double quotes, and assorted padding.
_STRING_LIST_VALUES = [
    # fmt: off
    "parser,textcat,tagger",
    " parser, textcat ,tagger ",
    'parser,textcat,tagger',
    ' parser, textcat ,tagger ',
    ' "parser"," textcat " ,"tagger "',
    " 'parser',' textcat ' ,'tagger '",
    '[parser,textcat,tagger]',
    '["parser","textcat","tagger"]',
    '[" parser" ,"textcat ", " tagger " ]',
    "[parser,textcat,tagger]",
    "[ parser, textcat , tagger]",
    "['parser','textcat','tagger']",
    "[' parser' , 'textcat', ' tagger ' ]",
    # fmt: on
]

# Different spellings of the integers 1, 2 and 3.
_INT_LIST_VALUES = [
    # fmt: off
    "1,2,3",
    '[1,2,3]',
    '["1","2","3"]',
    '[" 1" ,"2 ", " 3 " ]',
    "[' 1' , '2', ' 3 ' ]",
    # fmt: on
]


@pytest.mark.parametrize("value", _STRING_LIST_VALUES)
def test_string_to_list(value):
    """Every supported spelling parses to the same list of names."""
    expected = ["parser", "textcat", "tagger"]
    parsed = string_to_list(value, intify=False)
    assert parsed == expected


@pytest.mark.parametrize("value", _INT_LIST_VALUES)
def test_string_to_list_intify(value):
    """Items stay strings by default and become ints with intify=True."""
    as_strings = string_to_list(value, intify=False)
    assert as_strings == ["1", "2", "3"]
    as_ints = string_to_list(value, intify=True)
    assert as_ints == [1, 2, 3]