diff --git a/spacy/cli/_util.py b/spacy/cli/_util.py index 0755ccb46..1db3a1d44 100644 --- a/spacy/cli/_util.py +++ b/spacy/cli/_util.py @@ -336,10 +336,12 @@ def git_sparse_checkout(repo: str, subpath: str, dest: Path, *, branch: str = "m # Now pass those missings into another bit of git internals missings = " ".join([x[1:] for x in ret.stdout.split() if x.startswith("?")]) if not missings: - err = f"Could not find any relevant files for '{subpath}'. " \ - f"Did you specify a correct and complete path within repo '{repo}' " \ - f"and branch {branch}?" - msg.fail(err, exits=1) + err = ( + f"Could not find any relevant files for '{subpath}'. " + f"Did you specify a correct and complete path within repo '{repo}' " + f"and branch {branch}?" + ) + msg.fail(err, exits=1) cmd = f"git -C {tmp_dir} fetch-pack {git_repo} {missings}" _attempt_run_command(cmd) # And finally, we can checkout our subpath @@ -348,6 +350,7 @@ def git_sparse_checkout(repo: str, subpath: str, dest: Path, *, branch: str = "m # We need Path(name) to make sure we also support subdirectories shutil.move(str(tmp_dir / Path(subpath)), str(dest)) + def _attempt_run_command(cmd): try: return run_command(cmd, capture=True) @@ -355,6 +358,7 @@ def _attempt_run_command(cmd): err = f"Could not run command: {cmd}." msg.fail(err, exits=1) + def _from_http_to_git(repo): if repo.startswith("http://"): repo = repo.replace(r"http://", r"https://") @@ -364,3 +368,23 @@ def _from_http_to_git(repo): repo = repo[:-1] repo = f"{repo}.git" return repo + + +def string_to_list(value, intify=False): + """Parse a comma-separated string to a list""" + if not value: + return [] + if value.startswith("[") and value.endswith("]"): + value = value[1:-1] + result = [] + for p in value.split(","): + p = p.strip() + if p.startswith("'") and p.endswith("'"): + p = p[1:-1] + if p.startswith('"') and p.endswith('"'): + p = p[1:-1] + p = p.strip() + if intify: + p = int(p) + result.append(p) + return result diff --git a/spacy/cli/debug_model.py b/spacy/cli/debug_model.py index f4d93071e..1a250e43e 100644 --- a/spacy/cli/debug_model.py +++ b/spacy/cli/debug_model.py @@ -5,7 +5,7 @@ from thinc.api import require_gpu, fix_random_seed, set_dropout_rate, Adam from thinc.api import Model, data_validation import typer -from ._util import Arg, Opt, debug_cli, show_validation_error, parse_config_overrides +from ._util import Arg, Opt, debug_cli, show_validation_error, parse_config_overrides, string_to_list from .. import util @@ -38,12 +38,13 @@ def debug_model_cli( require_gpu(use_gpu) else: msg.info("Using CPU") + layers = string_to_list(layers, intify=True) print_settings = { "dimensions": dimensions, "parameters": parameters, "gradients": gradients, "attributes": attributes, - "layers": [int(x.strip()) for x in layers.split(",")] if layers else [], + "layers": layers, "print_before_training": P0, "print_after_init": P1, "print_after_training": P2, diff --git a/spacy/cli/init_config.py b/spacy/cli/init_config.py index a23b15d53..ec65b0e0a 100644 --- a/spacy/cli/init_config.py +++ b/spacy/cli/init_config.py @@ -9,7 +9,7 @@ import re from .. import util from ..language import DEFAULT_CONFIG_PRETRAIN_PATH from ..schemas import RecommendationSchema -from ._util import init_cli, Arg, Opt, show_validation_error, COMMAND +from ._util import init_cli, Arg, Opt, show_validation_error, COMMAND, string_to_list ROOT = Path(__file__).parent / "templates" @@ -42,9 +42,7 @@ def init_config_cli( """ if isinstance(optimize, Optimizations): # instance of enum from the CLI optimize = optimize.value - if pipeline.startswith("[") and pipeline.endswith("]"): - pipeline = pipeline[1:-1] - pipeline = [p.strip() for p in pipeline.split(",")] + pipeline = string_to_list(pipeline) init_config(output_file, lang=lang, pipeline=pipeline, optimize=optimize, cpu=cpu) diff --git a/spacy/tests/test_cli.py b/spacy/tests/test_cli.py index e8c83cbad..0df707dc0 100644 --- a/spacy/tests/test_cli.py +++ b/spacy/tests/test_cli.py @@ -9,6 +9,7 @@ from spacy.cli.pretrain import make_docs from spacy.cli.init_config import init_config, RECOMMENDATIONS from spacy.cli._util import validate_project_commands, parse_config_overrides from spacy.cli._util import load_project_config, substitute_project_variables +from spacy.cli._util import string_to_list from thinc.config import ConfigValidationError import srsly @@ -372,17 +373,13 @@ def test_parse_config_overrides(args, expected): assert parse_config_overrides(args) == expected -@pytest.mark.parametrize( - "args", [["--foo"], ["--x.foo", "bar", "--baz"]], -) +@pytest.mark.parametrize("args", [["--foo"], ["--x.foo", "bar", "--baz"]]) def test_parse_config_overrides_invalid(args): with pytest.raises(NoSuchOption): parse_config_overrides(args) -@pytest.mark.parametrize( - "args", [["--x.foo", "bar", "baz"], ["x.foo"]], -) +@pytest.mark.parametrize("args", [["--x.foo", "bar", "baz"], ["x.foo"]]) def test_parse_config_overrides_invalid_2(args): with pytest.raises(SystemExit): parse_config_overrides(args) @@ -401,3 +398,44 @@ def test_init_config(lang, pipeline, optimize): def test_model_recommendations(): for lang, data in RECOMMENDATIONS.items(): assert RecommendationSchema(**data) + + +@pytest.mark.parametrize( + "value", + [ + # fmt: off + "parser,textcat,tagger", + " parser, textcat ,tagger ", + 'parser,textcat,tagger', + ' parser, textcat ,tagger ', + ' "parser"," textcat " ,"tagger "', + " 'parser',' textcat ' ,'tagger '", + '[parser,textcat,tagger]', + '["parser","textcat","tagger"]', + '[" parser" ,"textcat ", " tagger " ]', + "[parser,textcat,tagger]", + "[ parser, textcat , tagger]", + "['parser','textcat','tagger']", + "[' parser' , 'textcat', ' tagger ' ]", + # fmt: on + ], +) +def test_string_to_list(value): + assert string_to_list(value, intify=False) == ["parser", "textcat", "tagger"] + + +@pytest.mark.parametrize( + "value", + [ + # fmt: off + "1,2,3", + '[1,2,3]', + '["1","2","3"]', + '[" 1" ,"2 ", " 3 " ]', + "[' 1' , '2', ' 3 ' ]", + # fmt: on + ], +) +def test_string_to_list_intify(value): + assert string_to_list(value, intify=False) == ["1", "2", "3"] + assert string_to_list(value, intify=True) == [1, 2, 3]