Update CLI utils, project.yml schema and add test

This commit is contained in:
Ines Montani 2020-08-25 11:54:53 +02:00
parent 8ac5ef1284
commit dd84577a98
3 changed files with 55 additions and 5 deletions

View File

@ -195,7 +195,7 @@ def get_checksum(path: Union[Path, str]) -> str:
for sub_file in sorted(fp for fp in path.rglob("*") if fp.is_file()): for sub_file in sorted(fp for fp in path.rglob("*") if fp.is_file()):
dir_checksum.update(sub_file.read_bytes()) dir_checksum.update(sub_file.read_bytes())
return dir_checksum.hexdigest() return dir_checksum.hexdigest()
raise ValueError(f"Can't get checksum for {path}: not a file or directory") msg.fail(f"Can't get checksum for {path}: not a file or directory", exits=1)
@contextmanager @contextmanager
@ -320,9 +320,9 @@ def git_sparse_checkout(
repo: str, subpath: str, dest: Path, *, branch: Optional[str] = None repo: str, subpath: str, dest: Path, *, branch: Optional[str] = None
): ):
if dest.exists(): if dest.exists():
raise IOError("Destination of checkout must not exist") msg.fail("Destination of checkout must not exist", exits=1)
if not dest.parent.exists(): if not dest.parent.exists():
raise IOError("Parent of destination of checkout must exist") msg.fail("Parent of destination of checkout must exist", exits=1)
# We're using Git and sparse checkout to only clone the files we need # We're using Git and sparse checkout to only clone the files we need
with make_tempdir() as tmp_dir: with make_tempdir() as tmp_dir:
cmd = ( cmd = (

View File

@ -283,7 +283,15 @@ class ConfigSchema(BaseModel):
# Project config Schema # Project config Schema
class ProjectConfigAsset(BaseModel): class ProjectConfigAssetGitItem(BaseModel):
# fmt: off
repo: StrictStr = Field(..., title="URL of Git repo to download from")
path: StrictStr = Field(..., title="File path or sub-directory to download (used for sparse checkout)")
branch: StrictStr = Field("master", title="Branch to clone from")
# fmt: on
class ProjectConfigAssetURL(BaseModel):
# fmt: off # fmt: off
dest: StrictStr = Field(..., title="Destination of downloaded asset") dest: StrictStr = Field(..., title="Destination of downloaded asset")
url: Optional[StrictStr] = Field(None, title="URL of asset") url: Optional[StrictStr] = Field(None, title="URL of asset")
@ -291,6 +299,13 @@ class ProjectConfigAsset(BaseModel):
# fmt: on # fmt: on
class ProjectConfigAssetGit(BaseModel):
# fmt: off
git: ProjectConfigAssetGitItem = Field(..., title="Git repo information")
checksum: str = Field(None, title="MD5 hash of file", regex=r"([a-fA-F\d]{32})")
# fmt: on
class ProjectConfigCommand(BaseModel): class ProjectConfigCommand(BaseModel):
# fmt: off # fmt: off
name: StrictStr = Field(..., title="Name of command") name: StrictStr = Field(..., title="Name of command")
@ -310,7 +325,7 @@ class ProjectConfigCommand(BaseModel):
class ProjectConfigSchema(BaseModel): class ProjectConfigSchema(BaseModel):
# fmt: off # fmt: off
vars: Dict[StrictStr, Any] = Field({}, title="Optional variables to substitute in commands") vars: Dict[StrictStr, Any] = Field({}, title="Optional variables to substitute in commands")
assets: List[ProjectConfigAsset] = Field([], title="Data assets") assets: List[Union[ProjectConfigAssetURL, ProjectConfigAssetGit]] = Field([], title="Data assets")
workflows: Dict[StrictStr, List[StrictStr]] = Field({}, title="Named workflows, mapped to list of project commands to run in order") workflows: Dict[StrictStr, List[StrictStr]] = Field({}, title="Named workflows, mapped to list of project commands to run in order")
commands: List[ProjectConfigCommand] = Field([], title="Project command shortucts") commands: List[ProjectConfigCommand] = Field([], title="Project command shortucts")
# fmt: on # fmt: on

View File

@ -270,6 +270,41 @@ def test_pretrain_make_docs():
assert skip_count == 0 assert skip_count == 0
def test_project_config_validation_full():
config = {
"vars": {"some_var": 20},
"directories": ["assets", "configs", "corpus", "scripts", "training"],
"assets": [
{
"dest": "x",
"url": "https://example.com",
"checksum": "63373dd656daa1fd3043ce166a59474c",
},
{
"dest": "y",
"git": {
"repo": "https://github.com/example/repo",
"branch": "develop",
"path": "y",
},
},
],
"commands": [
{
"name": "train",
"help": "Train a model",
"script": ["python -m spacy train config.cfg -o training"],
"deps": ["config.cfg", "corpus/training.spcy"],
"outputs": ["training/model-best"],
},
{"name": "test", "script": ["pytest", "custom.py"], "no_skip": True},
],
"workflows": {"all": ["train", "test"], "train": ["train"]},
}
errors = validate(ProjectConfigSchema, config)
assert not errors
@pytest.mark.parametrize( @pytest.mark.parametrize(
"config", "config",
[ [