mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 07:57:35 +03:00 
			
		
		
		
	* Add utils for working with remote storage * WIP add remote_cache for project * WIP add push and pull commands * Use pathy in remote_cache * Update util * Update remote_cache * Update util * Update project assets * Update pull script * Update push script * Fix type annotation in util * Work on remote storage * Remove site and env hash * Fix imports * Fix type annotation * Require pathy * Require pathy * Fix import * Add a util to handle project variable substitution * Import push and pull commands * Fix pull command * Fix push command * Fix tarfile in remote_storage * Improve printing * Fiddle with status messages * Set version to v3.0.0a9 * Draft docs for spacy project remote storages * Update docs [ci skip] * Use Thinc config to simplify and unify template variables * Auto-format * Don't import Pathy globally for now Causes slow and annoying Google Cloud warning * Tidy up test * Tidy up and update tests * Update to latest Thinc * Update docs * variables -> vars * Update docs [ci skip] * Update docs [ci skip] Co-authored-by: Ines Montani <ines@ines.io>
		
			
				
	
	
		
			170 lines
		
	
	
		
			6.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			170 lines
		
	
	
		
			6.4 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from typing import Optional, List, Dict, TYPE_CHECKING
 | |
| import os
 | |
| import site
 | |
| import hashlib
 | |
| import urllib.parse
 | |
| import tarfile
 | |
| from pathlib import Path
 | |
| 
 | |
| from .._util import get_hash, get_checksum, download_file, ensure_pathy
 | |
| from ...util import make_tempdir
 | |
| 
 | |
| if TYPE_CHECKING:
 | |
|     from pathy import Pathy  # noqa: F401
 | |
| 
 | |
| 
 | |
class RemoteStorage:
    """Push and pull outputs to and from a remote file storage.

    Remotes can be anything that `smart-open` can support: AWS, GCS, file system,
    ssh, etc.
    """

    def __init__(self, project_root: Path, url: str, *, compression="gz"):
        """Create a storage client rooted at a local project directory.

        project_root (Path): Local directory that stored paths are relative to.
        url (str): Base URL of the remote storage.
        compression (str): tarfile compression suffix ("gz", "bz2", "xz"),
            or a falsy value for an uncompressed archive.
        """
        self.root = project_root
        self.url = ensure_pathy(url)
        self.compression = compression

    def push(
        self, path: Path, command_hash: str, content_hash: str
    ) -> Optional["Pathy"]:
        """Compress a file or directory within a project and upload it to a remote
        storage. If an object exists at the full URL, nothing is done.

        Within the remote storage, files are addressed by their project path
        (url encoded) and two user-supplied hashes, representing their creation
        context and their file contents. If the URL already exists, the data is
        not uploaded. Paths are archived and compressed prior to upload.

        RETURNS (Optional[Pathy]): The remote URL, or None if the object
            already existed and nothing was uploaded.
        """
        loc = self.root / path
        if not loc.exists():
            raise IOError(f"Cannot push {loc}: does not exist.")
        url = self.make_url(path, command_hash, content_hash)
        if url.exists():
            # Same path + hashes already present remotely: nothing to do.
            return None
        tmp: Path
        with make_tempdir() as tmp:
            tar_loc = tmp / self.encode_name(str(path))
            mode_string = f"w:{self.compression}" if self.compression else "w"
            with tarfile.open(tar_loc, mode=mode_string) as tar_file:
                # Archive relative to the project root, so pull() can extract
                # the archive directly into self.root.
                tar_file.add(str(loc), arcname=str(path))
            with tar_loc.open(mode="rb") as input_file:
                with url.open(mode="wb") as output_file:
                    output_file.write(input_file.read())
        return url

    def pull(
        self,
        path: Path,
        *,
        command_hash: Optional[str] = None,
        content_hash: Optional[str] = None,
    ) -> Optional["Pathy"]:
        """Retrieve a file from the remote cache. If the file already exists,
        nothing is done.

        If the command_hash and/or content_hash are specified, only matching
        results are returned.

        RETURNS (Optional[Pathy]): The remote URL pulled from, or None if the
            destination already existed locally or no match was found.
        """
        dest = self.root / path
        if dest.exists():
            return None
        url = self.find(path, command_hash=command_hash, content_hash=content_hash)
        if url is None:
            return url
        else:
            # Make sure the destination exists
            if not dest.parent.exists():
                dest.parent.mkdir(parents=True)
            tmp: Path
            with make_tempdir() as tmp:
                tar_loc = tmp / url.parts[-1]
                download_file(url, tar_loc)
                mode_string = f"r:{self.compression}" if self.compression else "r"
                with tarfile.open(tar_loc, mode=mode_string) as tar_file:
                    # This requires that the path is added correctly, relative
                    # to root. This is how we set things up in push().
                    self._extract_safely(tar_file)
        return url

    def _extract_safely(self, tar_file: tarfile.TarFile) -> None:
        """Extract all members into self.root, refusing any member whose
        resolved destination falls outside it. Defends against tar path
        traversal ("tar slip") from a tampered remote archive; a bare
        extractall() would happily write to e.g. `../../etc/...`."""
        root = os.path.abspath(str(self.root))
        for member in tar_file.getmembers():
            member_dest = os.path.abspath(os.path.join(root, member.name))
            if member_dest != root and not member_dest.startswith(root + os.sep):
                raise IOError(
                    f"Tar member would extract outside project root: {member.name}"
                )
        tar_file.extractall(self.root)

    def find(
        self,
        path: Path,
        *,
        command_hash: Optional[str] = None,
        content_hash: Optional[str] = None,
    ) -> Optional["Pathy"]:
        """Find the best matching version of a file within the storage,
        or `None` if no match can be found. If both the creation and content hash
        are specified, only exact matches will be returned. Otherwise, the most
        recent matching file is preferred.
        """
        name = self.encode_name(str(path))
        if command_hash is not None and content_hash is not None:
            url = self.make_url(path, command_hash, content_hash)
            urls = [url] if url.exists() else []
        elif command_hash is not None:
            urls = list((self.url / name / command_hash).iterdir())
        else:
            urls = list((self.url / name).iterdir())
            if content_hash is not None:
                urls = [url for url in urls if url.parts[-1] == content_hash]
        # NOTE(review): "most recent" relies on the listing order of iterdir(),
        # which is not guaranteed to be chronological — TODO confirm.
        return urls[-1] if urls else None

    def make_url(self, path: Path, command_hash: str, content_hash: str) -> "Pathy":
        """Construct a URL from a subpath, a creation hash and a content hash."""
        return self.url / self.encode_name(str(path)) / command_hash / content_hash

    def encode_name(self, name: str) -> str:
        """Encode a subpath into a URL-safe name."""
        return urllib.parse.quote_plus(name)
 | |
| 
 | |
| 
 | |
def get_content_hash(loc: Path) -> str:
    """Return the content hash for a file or directory: a thin wrapper that
    delegates to the shared `get_checksum` util."""
    return get_checksum(loc)
 | |
| 
 | |
| 
 | |
def get_command_hash(
    site_hash: str, env_hash: str, deps: List[Path], cmd: List[str]
) -> str:
    """Create a hash representing the execution of a command. This includes the
    currently installed packages, whatever environment variables have been marked
    as relevant, and the command.
    """
    # Sort the dependencies so the hash is stable across listing order.
    dep_checksums = [get_checksum(dep) for dep in sorted(deps)]
    parts = [site_hash, env_hash, *dep_checksums, *cmd]
    return hashlib.md5("".join(parts).encode("utf8")).hexdigest()
 | |
| 
 | |
| 
 | |
def get_site_hash():
    """Hash the current Python environment's site-packages contents, including
    the name and version of the libraries. The list we're hashing is what
    `pip freeze` would output.

    RETURNS (str): Hex digest of the sorted set of installed dist names.
    """
    site_dirs = site.getsitepackages()
    if site.ENABLE_USER_SITE:
        # getusersitepackages() returns a single path string; the original
        # used extend(), which iterated the string character-by-character.
        site_dirs.append(site.getusersitepackages())
    packages = set()
    for site_dir in site_dirs:
        site_dir = Path(site_dir)
        if not site_dir.exists():
            # The user site dir (and some platform dirs) may not exist.
            continue
        for subpath in site_dir.iterdir():
            if subpath.parts[-1].endswith("dist-info"):
                packages.add(subpath.parts[-1].replace(".dist-info", ""))
    package_bytes = "".join(sorted(packages)).encode("utf8")
    # hashlib has md5, not md5sum: the original raised AttributeError.
    return hashlib.md5(package_bytes).hexdigest()
 | |
| 
 | |
| 
 | |
def get_env_hash(env: Dict[str, str]) -> str:
    """Construct a hash of the environment variables that will be passed into
    the commands.

    Values in the env dict may be references to the current os.environ, using
    the syntax $ENV_VAR to mean os.environ[ENV_VAR]
    """
    resolved = {
        key: os.environ.get(value[1:], "") if value.startswith("$") else value
        for key, value in env.items()
    }
    return get_hash(resolved)
 |