mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 16:07:41 +03:00 
			
		
		
		
	Update CLI utils, project.yml schema and add test
This commit is contained in:
		
							parent
							
								
									8ac5ef1284
								
							
						
					
					
						commit
						dd84577a98
					
				|  | @ -195,7 +195,7 @@ def get_checksum(path: Union[Path, str]) -> str: | |||
|         for sub_file in sorted(fp for fp in path.rglob("*") if fp.is_file()): | ||||
|             dir_checksum.update(sub_file.read_bytes()) | ||||
|         return dir_checksum.hexdigest() | ||||
|     raise ValueError(f"Can't get checksum for {path}: not a file or directory") | ||||
|     msg.fail(f"Can't get checksum for {path}: not a file or directory", exits=1) | ||||
| 
 | ||||
| 
 | ||||
| @contextmanager | ||||
|  | @ -320,9 +320,9 @@ def git_sparse_checkout( | |||
|     repo: str, subpath: str, dest: Path, *, branch: Optional[str] = None | ||||
| ): | ||||
|     if dest.exists(): | ||||
|         raise IOError("Destination of checkout must not exist") | ||||
|         msg.fail("Destination of checkout must not exist", exits=1) | ||||
|     if not dest.parent.exists(): | ||||
|         raise IOError("Parent of destination of checkout must exist") | ||||
|         msg.fail("Parent of destination of checkout must exist", exits=1) | ||||
|     # We're using Git and sparse checkout to only clone the files we need | ||||
|     with make_tempdir() as tmp_dir: | ||||
|         cmd = ( | ||||
|  |  | |||
|  | @ -283,7 +283,15 @@ class ConfigSchema(BaseModel): | |||
| # Project config Schema | ||||
| 
 | ||||
| 
 | ||||
| class ProjectConfigAsset(BaseModel): | ||||
| class ProjectConfigAssetGitItem(BaseModel): | ||||
|     # fmt: off | ||||
|     repo: StrictStr = Field(..., title="URL of Git repo to download from") | ||||
|     path: StrictStr = Field(..., title="File path or sub-directory to download (used for sparse checkout)") | ||||
|     branch: StrictStr = Field("master", title="Branch to clone from") | ||||
|     # fmt: on | ||||
| 
 | ||||
| 
 | ||||
| class ProjectConfigAssetURL(BaseModel): | ||||
|     # fmt: off | ||||
|     dest: StrictStr = Field(..., title="Destination of downloaded asset") | ||||
|     url: Optional[StrictStr] = Field(None, title="URL of asset") | ||||
|  | @ -291,6 +299,13 @@ class ProjectConfigAsset(BaseModel): | |||
|     # fmt: on | ||||
| 
 | ||||
| 
 | ||||
| class ProjectConfigAssetGit(BaseModel): | ||||
|     # fmt: off | ||||
|     git: ProjectConfigAssetGitItem = Field(..., title="Git repo information") | ||||
|     checksum: str = Field(None, title="MD5 hash of file", regex=r"([a-fA-F\d]{32})") | ||||
|     # fmt: on | ||||
| 
 | ||||
| 
 | ||||
| class ProjectConfigCommand(BaseModel): | ||||
|     # fmt: off | ||||
|     name: StrictStr = Field(..., title="Name of command") | ||||
|  | @ -310,7 +325,7 @@ class ProjectConfigCommand(BaseModel): | |||
| class ProjectConfigSchema(BaseModel): | ||||
|     # fmt: off | ||||
|     vars: Dict[StrictStr, Any] = Field({}, title="Optional variables to substitute in commands") | ||||
|     assets: List[ProjectConfigAsset] = Field([], title="Data assets") | ||||
|     assets: List[Union[ProjectConfigAssetURL, ProjectConfigAssetGit]] = Field([], title="Data assets") | ||||
|     workflows: Dict[StrictStr, List[StrictStr]] = Field({}, title="Named workflows, mapped to list of project commands to run in order") | ||||
|     commands: List[ProjectConfigCommand] = Field([], title="Project command shortucts") | ||||
|     # fmt: on | ||||
|  |  | |||
|  | @ -270,6 +270,41 @@ def test_pretrain_make_docs(): | |||
|     assert skip_count == 0 | ||||
| 
 | ||||
| 
 | ||||
| def test_project_config_validation_full(): | ||||
|     config = { | ||||
|         "vars": {"some_var": 20}, | ||||
|         "directories": ["assets", "configs", "corpus", "scripts", "training"], | ||||
|         "assets": [ | ||||
|             { | ||||
|                 "dest": "x", | ||||
|                 "url": "https://example.com", | ||||
|                 "checksum": "63373dd656daa1fd3043ce166a59474c", | ||||
|             }, | ||||
|             { | ||||
|                 "dest": "y", | ||||
|                 "git": { | ||||
|                     "repo": "https://github.com/example/repo", | ||||
|                     "branch": "develop", | ||||
|                     "path": "y", | ||||
|                 }, | ||||
|             }, | ||||
|         ], | ||||
|         "commands": [ | ||||
|             { | ||||
|                 "name": "train", | ||||
|                 "help": "Train a model", | ||||
|                 "script": ["python -m spacy train config.cfg -o training"], | ||||
|                 "deps": ["config.cfg", "corpus/training.spcy"], | ||||
|                 "outputs": ["training/model-best"], | ||||
|             }, | ||||
|             {"name": "test", "script": ["pytest", "custom.py"], "no_skip": True}, | ||||
|         ], | ||||
|         "workflows": {"all": ["train", "test"], "train": ["train"]}, | ||||
|     } | ||||
|     errors = validate(ProjectConfigSchema, config) | ||||
|     assert not errors | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize( | ||||
|     "config", | ||||
|     [ | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user