mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 16:07:41 +03:00 
			
		
		
		
	* Support local filesystem remotes for projects
* Fix support for local filesystem remotes for projects
  * Use `FluidPath` instead of `Pathy` to support both filesystem and
    remote paths
  * Create missing parent directories if required for local filesystem
  * Add a more general `_file_exists` method to support `Pathy`,
    `Path`, and `smart_open`-compatible URLs
* Add explicit `smart_open` dependency starting with support for
  `compression` flag
* Update `pathy` dependency to exclude older versions that aren't
  compatible with required `smart_open` version
* Update docs to refer to `Pathy` instead of `smart_open` for project
  remotes (technically you can still push to any `smart_open`-compatible
  path but you can't pull from them)
* Add tests for local filesystem remotes
* Update pathy for general BlobStat sorting
* Add import
* Remove _file_exists since only Pathy remotes are supported
* Format CLI docs
* Clean up merge
		
	
			
		
			
				
	
	
		
			206 lines
		
	
	
		
			8.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			206 lines
		
	
	
		
			8.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from typing import Optional, List, Dict, TYPE_CHECKING
 | |
| import os
 | |
| import site
 | |
| import hashlib
 | |
| import urllib.parse
 | |
| import tarfile
 | |
| from pathlib import Path
 | |
| from wasabi import msg
 | |
| 
 | |
| from .._util import get_hash, get_checksum, upload_file, download_file
 | |
| from .._util import ensure_pathy, make_tempdir
 | |
| from ...util import get_minor_version, ENV_VARS, check_bool_env_var
 | |
| from ...git_info import GIT_VERSION
 | |
| from ... import about
 | |
| from ...errors import Errors
 | |
| 
 | |
| if TYPE_CHECKING:
 | |
|     from pathy import FluidPath  # noqa: F401
 | |
| 
 | |
| 
 | |
class RemoteStorage:
    """Push and pull outputs to and from a remote file storage.

    The remote URL is converted with `ensure_pathy`, so the supported remotes
    are whatever that helper accepts (presumably Pathy bucket URLs and local
    file system paths -- confirm against `ensure_pathy`).
    """

    def __init__(
        self, project_root: Path, url: str, *, compression: Optional[str] = "gz"
    ):
        """Create a storage handle.

        project_root: Directory that pushed/pulled paths are relative to.
        url: Location of the remote storage; converted with `ensure_pathy`.
        compression: Suffix for the tarfile mode (e.g. "gz"). A falsy value
            means the archive is written and read without compression.
        """
        self.root = project_root
        self.url = ensure_pathy(url)
        self.compression = compression

    def push(self, path: Path, command_hash: str, content_hash: str) -> "FluidPath":
        """Compress a file or directory within a project and upload it to a remote
        storage. If an object exists at the full URL, nothing is done.

        Within the remote storage, files are addressed by their project path
        (url encoded) and two user-supplied hashes, representing their creation
        context and their file contents. If the URL already exists, the data is
        not uploaded. Paths are archived and compressed prior to upload.

        path: Path of the file or directory, relative to the project root.
        command_hash: Hash representing the creation context.
        content_hash: Hash representing the file contents.
        RETURNS: The remote URL the archive lives at.
        RAISES: IOError if the local path does not exist.
        """
        loc = self.root / path
        if not loc.exists():
            raise IOError(f"Cannot push {loc}: does not exist.")
        url = self.make_url(path, command_hash, content_hash)
        if url.exists():
            return url
        tmp: Path
        with make_tempdir() as tmp:
            tar_loc = tmp / self.encode_name(str(path))
            mode_string = f"w:{self.compression}" if self.compression else "w"
            with tarfile.open(tar_loc, mode=mode_string) as tar_file:
                # Store the member under its project-relative name so that
                # pull() can extract it directly into the project root.
                tar_file.add(str(loc), arcname=str(path))
            upload_file(tar_loc, url)
        return url

    def pull(
        self,
        path: Path,
        *,
        command_hash: Optional[str] = None,
        content_hash: Optional[str] = None,
    ) -> Optional["FluidPath"]:
        """Retrieve a file from the remote cache. If the file already exists
        locally, nothing is done and None is returned.

        If the command_hash and/or content_hash are specified, only matching
        results are considered. If no matching result is available, None is
        returned.

        path: Path of the file or directory, relative to the project root.
        command_hash: Optional hash of the creation context to match.
        content_hash: Optional hash of the file contents to match.
        RETURNS: The remote URL that was downloaded, or None if nothing was
            pulled.
        RAISES: ValueError if the archive contains a member that would be
            extracted outside of the project root.
        """
        dest = self.root / path
        if dest.exists():
            return None
        url = self.find(path, command_hash=command_hash, content_hash=content_hash)
        if url is None:
            return url
        # Make sure the destination's parent exists. exist_ok avoids failing
        # if the directory is created between a check and the mkdir call.
        dest.parent.mkdir(parents=True, exist_ok=True)
        tmp: Path
        with make_tempdir() as tmp:
            tar_loc = tmp / url.parts[-1]
            download_file(url, tar_loc)
            mode_string = f"r:{self.compression}" if self.compression else "r"
            with tarfile.open(tar_loc, mode=mode_string) as tar_file:
                # This requires that the path is added correctly, relative
                # to root. This is how we set things up in push()
                self._safe_extract(tar_file, self.root)
        return url

    @staticmethod
    def _is_within_directory(directory, target) -> bool:
        """Check whether `target` is `directory` itself or located below it.

        The comparison is made on whole path components (os.path.commonpath).
        A character-wise prefix comparison would wrongly accept siblings such
        as "/project-evil" for the directory "/project".
        """
        abs_directory = os.path.abspath(directory)
        abs_target = os.path.abspath(target)
        try:
            common = os.path.commonpath([abs_directory, abs_target])
        except ValueError:
            # commonpath raises if the paths cannot be compared, e.g. when
            # they are on different drives. Such a target is not inside.
            return False
        return common == abs_directory

    @staticmethod
    def _safe_extract(tar, path) -> None:
        """Extract all members of the open tar file `tar` into `path`.

        Disallows member paths outside the target directory (CVE-2007-4559,
        directory traversal vulnerability). All members are validated before
        anything is extracted.

        RAISES: ValueError if a member would end up outside of `path`.
        """
        for member in tar.getmembers():
            member_path = os.path.join(path, member.name)
            if not RemoteStorage._is_within_directory(path, member_path):
                raise ValueError(Errors.E852)
        tar.extractall(path)

    def find(
        self,
        path: Path,
        *,
        command_hash: Optional[str] = None,
        content_hash: Optional[str] = None,
    ) -> Optional["FluidPath"]:
        """Find the best matching version of a file within the storage,
        or `None` if no match can be found. If both the creation and content hash
        are specified, only exact matches will be returned. Otherwise, the most
        recent matching file is preferred.

        path: Path of the file or directory, relative to the project root.
        command_hash: Optional hash of the creation context to match.
        content_hash: Optional hash of the file contents to match.
        RETURNS: The remote URL of the best match, or None.
        """
        name = self.encode_name(str(path))
        urls = []
        if command_hash is not None and content_hash is not None:
            url = self.url / name / command_hash / content_hash
            urls = [url] if url.exists() else []
        elif command_hash is not None:
            if (self.url / name / command_hash).exists():
                urls = list((self.url / name / command_hash).iterdir())
        else:
            if (self.url / name).exists():
                # Layout is <name>/<command_hash>/<content_hash>: collect the
                # entries of every command hash directory.
                for sub_dir in (self.url / name).iterdir():
                    urls.extend(sub_dir.iterdir())
                if content_hash is not None:
                    urls = [url for url in urls if url.parts[-1] == content_hash]
        if len(urls) >= 2:
            try:
                # Ascending sort, so the last entry is the most recent one.
                urls.sort(key=lambda x: x.stat().last_modified)  # type: ignore
            except Exception:
                # Best effort: fall back to the unsorted order and warn.
                msg.warn(
                    "Unable to sort remote files by last modified. The file(s) "
                    "pulled from the cache may not be the most recent."
                )
        return urls[-1] if urls else None

    def make_url(self, path: Path, command_hash: str, content_hash: str) -> "FluidPath":
        """Construct a URL from a subpath, a creation hash and a content hash."""
        return self.url / self.encode_name(str(path)) / command_hash / content_hash

    def encode_name(self, name: str) -> str:
        """Encode a subpath into a URL-safe name."""
        return urllib.parse.quote_plus(name)
 | |
| 
 | |
| 
 | |
def get_content_hash(loc: Path) -> str:
    """Return the content hash for `loc`, i.e. whatever `get_checksum`
    computes for that path.
    """
    checksum = get_checksum(loc)
    return checksum
 | |
| 
 | |
| 
 | |
def get_command_hash(
    site_hash: str, env_hash: str, deps: List[Path], cmd: List[str]
) -> str:
    """Build an MD5 hex digest that identifies one execution of a command.

    The digest covers, in this order: the spaCy version (the git version if
    the PROJECT_USE_GIT_VERSION environment flag is set, otherwise the minor
    version), the site hash, the env hash, the checksums of the sorted
    dependencies, and the parts of the command itself.
    """
    use_git_version = check_bool_env_var(ENV_VARS.PROJECT_USE_GIT_VERSION)
    version = (
        GIT_VERSION
        if use_git_version
        else str(get_minor_version(about.__version__) or "")
    )
    pieces = [
        version,
        site_hash,
        env_hash,
        *[get_checksum(dependency) for dependency in sorted(deps)],
        *cmd,
    ]
    digest = hashlib.md5("".join(pieces).encode("utf8"))
    return digest.hexdigest()
 | |
| 
 | |
| 
 | |
def get_site_hash() -> str:
    """Hash the current Python environment's site-packages contents, including
    the name and version of the libraries. The list we're hashing is roughly
    what `pip freeze` would output: the names of all `*dist-info` entries found
    in the site-packages directories, with the ".dist-info" part removed.

    RETURNS: MD5 hex digest of the sorted, concatenated package names.
    """
    # Copy so the list returned by the site module is never mutated.
    site_dirs = list(site.getsitepackages())
    if site.ENABLE_USER_SITE:
        # getusersitepackages() returns a single path string, so it has to be
        # appended; extending would add one entry per character.
        site_dirs.append(site.getusersitepackages())
    packages = set()
    for site_dir in site_dirs:
        site_dir = Path(site_dir)
        # Site directories (the user site in particular) are not guaranteed
        # to exist; skip them instead of failing in iterdir().
        if not site_dir.is_dir():
            continue
        for subpath in site_dir.iterdir():
            if subpath.parts[-1].endswith("dist-info"):
                packages.add(subpath.parts[-1].replace(".dist-info", ""))
    package_bytes = "".join(sorted(packages)).encode("utf8")
    return hashlib.md5(package_bytes).hexdigest()
 | |
| 
 | |
| 
 | |
def get_env_hash(env: Dict[str, str]) -> str:
    """Hash the environment variables that will be passed into the commands.

    A value starting with "$" is a reference to the current process
    environment: $ENV_VAR is replaced by os.environ[ENV_VAR], or by an empty
    string if that variable is not set. All other values are used as-is. The
    resolved mapping is hashed with `get_hash`.
    """
    resolved = {
        name: os.environ.get(setting[1:], "") if setting.startswith("$") else setting
        for name, setting in env.items()
    }
    return get_hash(resolved)
 |