mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
Update CLI utils, project.yml schema and add test
This commit is contained in:
parent
8ac5ef1284
commit
dd84577a98
|
@ -195,7 +195,7 @@ def get_checksum(path: Union[Path, str]) -> str:
|
||||||
for sub_file in sorted(fp for fp in path.rglob("*") if fp.is_file()):
|
for sub_file in sorted(fp for fp in path.rglob("*") if fp.is_file()):
|
||||||
dir_checksum.update(sub_file.read_bytes())
|
dir_checksum.update(sub_file.read_bytes())
|
||||||
return dir_checksum.hexdigest()
|
return dir_checksum.hexdigest()
|
||||||
raise ValueError(f"Can't get checksum for {path}: not a file or directory")
|
msg.fail(f"Can't get checksum for {path}: not a file or directory", exits=1)
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
|
@ -320,9 +320,9 @@ def git_sparse_checkout(
|
||||||
repo: str, subpath: str, dest: Path, *, branch: Optional[str] = None
|
repo: str, subpath: str, dest: Path, *, branch: Optional[str] = None
|
||||||
):
|
):
|
||||||
if dest.exists():
|
if dest.exists():
|
||||||
raise IOError("Destination of checkout must not exist")
|
msg.fail("Destination of checkout must not exist", exits=1)
|
||||||
if not dest.parent.exists():
|
if not dest.parent.exists():
|
||||||
raise IOError("Parent of destination of checkout must exist")
|
msg.fail("Parent of destination of checkout must exist", exits=1)
|
||||||
# We're using Git and sparse checkout to only clone the files we need
|
# We're using Git and sparse checkout to only clone the files we need
|
||||||
with make_tempdir() as tmp_dir:
|
with make_tempdir() as tmp_dir:
|
||||||
cmd = (
|
cmd = (
|
||||||
|
|
|
@ -283,7 +283,15 @@ class ConfigSchema(BaseModel):
|
||||||
# Project config Schema
|
# Project config Schema
|
||||||
|
|
||||||
|
|
||||||
class ProjectConfigAsset(BaseModel):
|
class ProjectConfigAssetGitItem(BaseModel):
|
||||||
|
# fmt: off
|
||||||
|
repo: StrictStr = Field(..., title="URL of Git repo to download from")
|
||||||
|
path: StrictStr = Field(..., title="File path or sub-directory to download (used for sparse checkout)")
|
||||||
|
branch: StrictStr = Field("master", title="Branch to clone from")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectConfigAssetURL(BaseModel):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
dest: StrictStr = Field(..., title="Destination of downloaded asset")
|
dest: StrictStr = Field(..., title="Destination of downloaded asset")
|
||||||
url: Optional[StrictStr] = Field(None, title="URL of asset")
|
url: Optional[StrictStr] = Field(None, title="URL of asset")
|
||||||
|
@ -291,6 +299,13 @@ class ProjectConfigAsset(BaseModel):
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
|
class ProjectConfigAssetGit(BaseModel):
|
||||||
|
# fmt: off
|
||||||
|
git: ProjectConfigAssetGitItem = Field(..., title="Git repo information")
|
||||||
|
checksum: str = Field(None, title="MD5 hash of file", regex=r"([a-fA-F\d]{32})")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
class ProjectConfigCommand(BaseModel):
|
class ProjectConfigCommand(BaseModel):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
name: StrictStr = Field(..., title="Name of command")
|
name: StrictStr = Field(..., title="Name of command")
|
||||||
|
@ -310,7 +325,7 @@ class ProjectConfigCommand(BaseModel):
|
||||||
class ProjectConfigSchema(BaseModel):
|
class ProjectConfigSchema(BaseModel):
|
||||||
# fmt: off
|
# fmt: off
|
||||||
vars: Dict[StrictStr, Any] = Field({}, title="Optional variables to substitute in commands")
|
vars: Dict[StrictStr, Any] = Field({}, title="Optional variables to substitute in commands")
|
||||||
assets: List[ProjectConfigAsset] = Field([], title="Data assets")
|
assets: List[Union[ProjectConfigAssetURL, ProjectConfigAssetGit]] = Field([], title="Data assets")
|
||||||
workflows: Dict[StrictStr, List[StrictStr]] = Field({}, title="Named workflows, mapped to list of project commands to run in order")
|
workflows: Dict[StrictStr, List[StrictStr]] = Field({}, title="Named workflows, mapped to list of project commands to run in order")
|
||||||
commands: List[ProjectConfigCommand] = Field([], title="Project command shortucts")
|
commands: List[ProjectConfigCommand] = Field([], title="Project command shortucts")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
|
@ -270,6 +270,41 @@ def test_pretrain_make_docs():
|
||||||
assert skip_count == 0
|
assert skip_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_project_config_validation_full():
|
||||||
|
config = {
|
||||||
|
"vars": {"some_var": 20},
|
||||||
|
"directories": ["assets", "configs", "corpus", "scripts", "training"],
|
||||||
|
"assets": [
|
||||||
|
{
|
||||||
|
"dest": "x",
|
||||||
|
"url": "https://example.com",
|
||||||
|
"checksum": "63373dd656daa1fd3043ce166a59474c",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"dest": "y",
|
||||||
|
"git": {
|
||||||
|
"repo": "https://github.com/example/repo",
|
||||||
|
"branch": "develop",
|
||||||
|
"path": "y",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"commands": [
|
||||||
|
{
|
||||||
|
"name": "train",
|
||||||
|
"help": "Train a model",
|
||||||
|
"script": ["python -m spacy train config.cfg -o training"],
|
||||||
|
"deps": ["config.cfg", "corpus/training.spcy"],
|
||||||
|
"outputs": ["training/model-best"],
|
||||||
|
},
|
||||||
|
{"name": "test", "script": ["pytest", "custom.py"], "no_skip": True},
|
||||||
|
],
|
||||||
|
"workflows": {"all": ["train", "test"], "train": ["train"]},
|
||||||
|
}
|
||||||
|
errors = validate(ProjectConfigSchema, config)
|
||||||
|
assert not errors
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"config",
|
"config",
|
||||||
[
|
[
|
||||||
|
|
Loading…
Reference in New Issue
Block a user