mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 17:24:41 +03:00
Merge pull request #6051 from svlandeg/feature/cli-config
This commit is contained in:
commit
b41be87213
|
@ -385,3 +385,23 @@ def _from_http_to_git(repo: str) -> str:
|
||||||
repo = repo[:-1]
|
repo = repo[:-1]
|
||||||
repo = f"{repo}.git"
|
repo = f"{repo}.git"
|
||||||
return repo
|
return repo
|
||||||
|
|
||||||
|
|
||||||
|
def string_to_list(value, intify=False):
|
||||||
|
"""Parse a comma-separated string to a list"""
|
||||||
|
if not value:
|
||||||
|
return []
|
||||||
|
if value.startswith("[") and value.endswith("]"):
|
||||||
|
value = value[1:-1]
|
||||||
|
result = []
|
||||||
|
for p in value.split(","):
|
||||||
|
p = p.strip()
|
||||||
|
if p.startswith("'") and p.endswith("'"):
|
||||||
|
p = p[1:-1]
|
||||||
|
if p.startswith('"') and p.endswith('"'):
|
||||||
|
p = p[1:-1]
|
||||||
|
p = p.strip()
|
||||||
|
if intify:
|
||||||
|
p = int(p)
|
||||||
|
result.append(p)
|
||||||
|
return result
|
||||||
|
|
|
@ -5,7 +5,7 @@ from thinc.api import require_gpu, fix_random_seed, set_dropout_rate, Adam
|
||||||
from thinc.api import Model, data_validation
|
from thinc.api import Model, data_validation
|
||||||
import typer
|
import typer
|
||||||
|
|
||||||
from ._util import Arg, Opt, debug_cli, show_validation_error, parse_config_overrides
|
from ._util import Arg, Opt, debug_cli, show_validation_error, parse_config_overrides, string_to_list
|
||||||
from .. import util
|
from .. import util
|
||||||
|
|
||||||
|
|
||||||
|
@ -38,12 +38,13 @@ def debug_model_cli(
|
||||||
require_gpu(use_gpu)
|
require_gpu(use_gpu)
|
||||||
else:
|
else:
|
||||||
msg.info("Using CPU")
|
msg.info("Using CPU")
|
||||||
|
layers = string_to_list(layers, intify=True)
|
||||||
print_settings = {
|
print_settings = {
|
||||||
"dimensions": dimensions,
|
"dimensions": dimensions,
|
||||||
"parameters": parameters,
|
"parameters": parameters,
|
||||||
"gradients": gradients,
|
"gradients": gradients,
|
||||||
"attributes": attributes,
|
"attributes": attributes,
|
||||||
"layers": [int(x.strip()) for x in layers.split(",")] if layers else [],
|
"layers": layers,
|
||||||
"print_before_training": P0,
|
"print_before_training": P0,
|
||||||
"print_after_init": P1,
|
"print_after_init": P1,
|
||||||
"print_after_training": P2,
|
"print_after_training": P2,
|
||||||
|
|
|
@ -9,7 +9,7 @@ import re
|
||||||
from .. import util
|
from .. import util
|
||||||
from ..language import DEFAULT_CONFIG_PRETRAIN_PATH
|
from ..language import DEFAULT_CONFIG_PRETRAIN_PATH
|
||||||
from ..schemas import RecommendationSchema
|
from ..schemas import RecommendationSchema
|
||||||
from ._util import init_cli, Arg, Opt, show_validation_error, COMMAND
|
from ._util import init_cli, Arg, Opt, show_validation_error, COMMAND, string_to_list
|
||||||
|
|
||||||
|
|
||||||
ROOT = Path(__file__).parent / "templates"
|
ROOT = Path(__file__).parent / "templates"
|
||||||
|
@ -42,7 +42,7 @@ def init_config_cli(
|
||||||
"""
|
"""
|
||||||
if isinstance(optimize, Optimizations): # instance of enum from the CLI
|
if isinstance(optimize, Optimizations): # instance of enum from the CLI
|
||||||
optimize = optimize.value
|
optimize = optimize.value
|
||||||
pipeline = [p.strip() for p in pipeline.split(",")]
|
pipeline = string_to_list(pipeline)
|
||||||
init_config(output_file, lang=lang, pipeline=pipeline, optimize=optimize, cpu=cpu)
|
init_config(output_file, lang=lang, pipeline=pipeline, optimize=optimize, cpu=cpu)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,7 @@ from spacy.cli.pretrain import make_docs
|
||||||
from spacy.cli.init_config import init_config, RECOMMENDATIONS
|
from spacy.cli.init_config import init_config, RECOMMENDATIONS
|
||||||
from spacy.cli._util import validate_project_commands, parse_config_overrides
|
from spacy.cli._util import validate_project_commands, parse_config_overrides
|
||||||
from spacy.cli._util import load_project_config, substitute_project_variables
|
from spacy.cli._util import load_project_config, substitute_project_variables
|
||||||
|
from spacy.cli._util import string_to_list
|
||||||
from thinc.config import ConfigValidationError
|
from thinc.config import ConfigValidationError
|
||||||
import srsly
|
import srsly
|
||||||
|
|
||||||
|
@ -372,17 +373,13 @@ def test_parse_config_overrides(args, expected):
|
||||||
assert parse_config_overrides(args) == expected
|
assert parse_config_overrides(args) == expected
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("args", [["--foo"], ["--x.foo", "bar", "--baz"]])
|
||||||
"args", [["--foo"], ["--x.foo", "bar", "--baz"]],
|
|
||||||
)
|
|
||||||
def test_parse_config_overrides_invalid(args):
|
def test_parse_config_overrides_invalid(args):
|
||||||
with pytest.raises(NoSuchOption):
|
with pytest.raises(NoSuchOption):
|
||||||
parse_config_overrides(args)
|
parse_config_overrides(args)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize("args", [["--x.foo", "bar", "baz"], ["x.foo"]])
|
||||||
"args", [["--x.foo", "bar", "baz"], ["x.foo"]],
|
|
||||||
)
|
|
||||||
def test_parse_config_overrides_invalid_2(args):
|
def test_parse_config_overrides_invalid_2(args):
|
||||||
with pytest.raises(SystemExit):
|
with pytest.raises(SystemExit):
|
||||||
parse_config_overrides(args)
|
parse_config_overrides(args)
|
||||||
|
@ -401,3 +398,44 @@ def test_init_config(lang, pipeline, optimize):
|
||||||
def test_model_recommendations():
|
def test_model_recommendations():
|
||||||
for lang, data in RECOMMENDATIONS.items():
|
for lang, data in RECOMMENDATIONS.items():
|
||||||
assert RecommendationSchema(**data)
|
assert RecommendationSchema(**data)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"value",
|
||||||
|
[
|
||||||
|
# fmt: off
|
||||||
|
"parser,textcat,tagger",
|
||||||
|
" parser, textcat ,tagger ",
|
||||||
|
'parser,textcat,tagger',
|
||||||
|
' parser, textcat ,tagger ',
|
||||||
|
' "parser"," textcat " ,"tagger "',
|
||||||
|
" 'parser',' textcat ' ,'tagger '",
|
||||||
|
'[parser,textcat,tagger]',
|
||||||
|
'["parser","textcat","tagger"]',
|
||||||
|
'[" parser" ,"textcat ", " tagger " ]',
|
||||||
|
"[parser,textcat,tagger]",
|
||||||
|
"[ parser, textcat , tagger]",
|
||||||
|
"['parser','textcat','tagger']",
|
||||||
|
"[' parser' , 'textcat', ' tagger ' ]",
|
||||||
|
# fmt: on
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_string_to_list(value):
|
||||||
|
assert string_to_list(value, intify=False) == ["parser", "textcat", "tagger"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"value",
|
||||||
|
[
|
||||||
|
# fmt: off
|
||||||
|
"1,2,3",
|
||||||
|
'[1,2,3]',
|
||||||
|
'["1","2","3"]',
|
||||||
|
'[" 1" ,"2 ", " 3 " ]',
|
||||||
|
"[' 1' , '2', ' 3 ' ]",
|
||||||
|
# fmt: on
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_string_to_list_intify(value):
|
||||||
|
assert string_to_list(value, intify=False) == ["1", "2", "3"]
|
||||||
|
assert string_to_list(value, intify=True) == [1, 2, 3]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user