string_to_list to parse comma-separated string into a list

This commit is contained in:
svlandeg 2020-09-12 14:43:22 +02:00
parent 5b94aeece9
commit 115147804a
4 changed files with 77 additions and 16 deletions

View File

@ -336,9 +336,11 @@ def git_sparse_checkout(repo: str, subpath: str, dest: Path, *, branch: str = "m
# Now pass those missings into another bit of git internals # Now pass those missings into another bit of git internals
missings = " ".join([x[1:] for x in ret.stdout.split() if x.startswith("?")]) missings = " ".join([x[1:] for x in ret.stdout.split() if x.startswith("?")])
if not missings: if not missings:
err = f"Could not find any relevant files for '{subpath}'. " \ err = (
f"Did you specify a correct and complete path within repo '{repo}' " \ f"Could not find any relevant files for '{subpath}'. "
f"Did you specify a correct and complete path within repo '{repo}' "
f"and branch {branch}?" f"and branch {branch}?"
)
msg.fail(err, exits=1) msg.fail(err, exits=1)
cmd = f"git -C {tmp_dir} fetch-pack {git_repo} {missings}" cmd = f"git -C {tmp_dir} fetch-pack {git_repo} {missings}"
_attempt_run_command(cmd) _attempt_run_command(cmd)
@ -348,6 +350,7 @@ def git_sparse_checkout(repo: str, subpath: str, dest: Path, *, branch: str = "m
# We need Path(name) to make sure we also support subdirectories # We need Path(name) to make sure we also support subdirectories
shutil.move(str(tmp_dir / Path(subpath)), str(dest)) shutil.move(str(tmp_dir / Path(subpath)), str(dest))
def _attempt_run_command(cmd): def _attempt_run_command(cmd):
try: try:
return run_command(cmd, capture=True) return run_command(cmd, capture=True)
@ -355,6 +358,7 @@ def _attempt_run_command(cmd):
err = f"Could not run command: {cmd}." err = f"Could not run command: {cmd}."
msg.fail(err, exits=1) msg.fail(err, exits=1)
def _from_http_to_git(repo): def _from_http_to_git(repo):
if repo.startswith("http://"): if repo.startswith("http://"):
repo = repo.replace(r"http://", r"https://") repo = repo.replace(r"http://", r"https://")
@ -364,3 +368,23 @@ def _from_http_to_git(repo):
repo = repo[:-1] repo = repo[:-1]
repo = f"{repo}.git" repo = f"{repo}.git"
return repo return repo
def string_to_list(value, intify=False):
"""Parse a comma-separated string to a list"""
if not value:
return []
if value.startswith("[") and value.endswith("]"):
value = value[1:-1]
result = []
for p in value.split(","):
p = p.strip()
if p.startswith("'") and p.endswith("'"):
p = p[1:-1]
if p.startswith('"') and p.endswith('"'):
p = p[1:-1]
p = p.strip()
if intify:
p = int(p)
result.append(p)
return result

View File

@ -5,7 +5,7 @@ from thinc.api import require_gpu, fix_random_seed, set_dropout_rate, Adam
from thinc.api import Model, data_validation from thinc.api import Model, data_validation
import typer import typer
from ._util import Arg, Opt, debug_cli, show_validation_error, parse_config_overrides from ._util import Arg, Opt, debug_cli, show_validation_error, parse_config_overrides, string_to_list
from .. import util from .. import util
@ -38,12 +38,13 @@ def debug_model_cli(
require_gpu(use_gpu) require_gpu(use_gpu)
else: else:
msg.info("Using CPU") msg.info("Using CPU")
layers = string_to_list(layers, intify=True)
print_settings = { print_settings = {
"dimensions": dimensions, "dimensions": dimensions,
"parameters": parameters, "parameters": parameters,
"gradients": gradients, "gradients": gradients,
"attributes": attributes, "attributes": attributes,
"layers": [int(x.strip()) for x in layers.split(",")] if layers else [], "layers": layers,
"print_before_training": P0, "print_before_training": P0,
"print_after_init": P1, "print_after_init": P1,
"print_after_training": P2, "print_after_training": P2,

View File

@ -9,7 +9,7 @@ import re
from .. import util from .. import util
from ..language import DEFAULT_CONFIG_PRETRAIN_PATH from ..language import DEFAULT_CONFIG_PRETRAIN_PATH
from ..schemas import RecommendationSchema from ..schemas import RecommendationSchema
from ._util import init_cli, Arg, Opt, show_validation_error, COMMAND from ._util import init_cli, Arg, Opt, show_validation_error, COMMAND, string_to_list
ROOT = Path(__file__).parent / "templates" ROOT = Path(__file__).parent / "templates"
@ -42,9 +42,7 @@ def init_config_cli(
""" """
if isinstance(optimize, Optimizations): # instance of enum from the CLI if isinstance(optimize, Optimizations): # instance of enum from the CLI
optimize = optimize.value optimize = optimize.value
if pipeline.startswith("[") and pipeline.endswith("]"): pipeline = string_to_list(pipeline)
pipeline = pipeline[1:-1]
pipeline = [p.strip() for p in pipeline.split(",")]
init_config(output_file, lang=lang, pipeline=pipeline, optimize=optimize, cpu=cpu) init_config(output_file, lang=lang, pipeline=pipeline, optimize=optimize, cpu=cpu)

View File

@ -9,6 +9,7 @@ from spacy.cli.pretrain import make_docs
from spacy.cli.init_config import init_config, RECOMMENDATIONS from spacy.cli.init_config import init_config, RECOMMENDATIONS
from spacy.cli._util import validate_project_commands, parse_config_overrides from spacy.cli._util import validate_project_commands, parse_config_overrides
from spacy.cli._util import load_project_config, substitute_project_variables from spacy.cli._util import load_project_config, substitute_project_variables
from spacy.cli._util import string_to_list
from thinc.config import ConfigValidationError from thinc.config import ConfigValidationError
import srsly import srsly
@ -372,17 +373,13 @@ def test_parse_config_overrides(args, expected):
assert parse_config_overrides(args) == expected assert parse_config_overrides(args) == expected
@pytest.mark.parametrize( @pytest.mark.parametrize("args", [["--foo"], ["--x.foo", "bar", "--baz"]])
"args", [["--foo"], ["--x.foo", "bar", "--baz"]],
)
def test_parse_config_overrides_invalid(args): def test_parse_config_overrides_invalid(args):
with pytest.raises(NoSuchOption): with pytest.raises(NoSuchOption):
parse_config_overrides(args) parse_config_overrides(args)
@pytest.mark.parametrize( @pytest.mark.parametrize("args", [["--x.foo", "bar", "baz"], ["x.foo"]])
"args", [["--x.foo", "bar", "baz"], ["x.foo"]],
)
def test_parse_config_overrides_invalid_2(args): def test_parse_config_overrides_invalid_2(args):
with pytest.raises(SystemExit): with pytest.raises(SystemExit):
parse_config_overrides(args) parse_config_overrides(args)
@ -401,3 +398,44 @@ def test_init_config(lang, pipeline, optimize):
def test_model_recommendations(): def test_model_recommendations():
for lang, data in RECOMMENDATIONS.items(): for lang, data in RECOMMENDATIONS.items():
assert RecommendationSchema(**data) assert RecommendationSchema(**data)
@pytest.mark.parametrize(
"value",
[
# fmt: off
"parser,textcat,tagger",
" parser, textcat ,tagger ",
'parser,textcat,tagger',
' parser, textcat ,tagger ',
' "parser"," textcat " ,"tagger "',
" 'parser',' textcat ' ,'tagger '",
'[parser,textcat,tagger]',
'["parser","textcat","tagger"]',
'[" parser" ,"textcat ", " tagger " ]',
"[parser,textcat,tagger]",
"[ parser, textcat , tagger]",
"['parser','textcat','tagger']",
"[' parser' , 'textcat', ' tagger ' ]",
# fmt: on
],
)
def test_string_to_list(value):
assert string_to_list(value, intify=False) == ["parser", "textcat", "tagger"]
@pytest.mark.parametrize(
"value",
[
# fmt: off
"1,2,3",
'[1,2,3]',
'["1","2","3"]',
'[" 1" ,"2 ", " 3 " ]',
"[' 1' , '2', ' 3 ' ]",
# fmt: on
],
)
def test_string_to_list_intify(value):
assert string_to_list(value, intify=False) == ["1", "2", "3"]
assert string_to_list(value, intify=True) == [1, 2, 3]