mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-04 12:20:20 +03:00
Merge remote-tracking branch 'explosion/feature/weasel' into feat/import-weasel
This commit is contained in:
commit
2d294e6eea
|
@ -3,5 +3,3 @@ __title__ = "spacy"
|
||||||
__version__ = "3.5.0"
|
__version__ = "3.5.0"
|
||||||
__download_url__ = "https://github.com/explosion/spacy-models/releases/download"
|
__download_url__ = "https://github.com/explosion/spacy-models/releases/download"
|
||||||
__compatibility__ = "https://raw.githubusercontent.com/explosion/spacy-models/master/compatibility.json"
|
__compatibility__ = "https://raw.githubusercontent.com/explosion/spacy-models/master/compatibility.json"
|
||||||
__projects__ = "https://github.com/explosion/projects"
|
|
||||||
__projects_branch__ = "v3"
|
|
||||||
|
|
|
@ -11,7 +11,7 @@ from click import NoSuchOption
|
||||||
from click.parser import split_arg_string
|
from click.parser import split_arg_string
|
||||||
from typer.main import get_command
|
from typer.main import get_command
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from thinc.api import Config, ConfigValidationError, require_gpu
|
from thinc.api import ConfigValidationError, require_gpu
|
||||||
from thinc.util import gpu_is_available
|
from thinc.util import gpu_is_available
|
||||||
from configparser import InterpolationError
|
from configparser import InterpolationError
|
||||||
import os
|
import os
|
||||||
|
@ -19,10 +19,7 @@ import os
|
||||||
from weasel import app as project_cli
|
from weasel import app as project_cli
|
||||||
|
|
||||||
from ..compat import Literal
|
from ..compat import Literal
|
||||||
from ..schemas import ProjectConfigSchema, validate
|
from ..util import import_file, run_command, registry, logger, ENV_VARS
|
||||||
from ..util import import_file, run_command, make_tempdir, registry, logger
|
|
||||||
from ..util import is_compatible_version, SimpleFrozenDict, ENV_VARS
|
|
||||||
from .. import about
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from pathy import FluidPath # noqa: F401
|
from pathy import FluidPath # noqa: F401
|
||||||
|
@ -32,7 +29,6 @@ SDIST_SUFFIX = ".tar.gz"
|
||||||
WHEEL_SUFFIX = "-py3-none-any.whl"
|
WHEEL_SUFFIX = "-py3-none-any.whl"
|
||||||
|
|
||||||
PROJECT_FILE = "project.yml"
|
PROJECT_FILE = "project.yml"
|
||||||
PROJECT_LOCK = "project.lock"
|
|
||||||
COMMAND = "python -m spacy"
|
COMMAND = "python -m spacy"
|
||||||
NAME = "spacy"
|
NAME = "spacy"
|
||||||
HELP = """spaCy Command-line Interface
|
HELP = """spaCy Command-line Interface
|
||||||
|
@ -136,148 +132,6 @@ def _parse_override(value: Any) -> Any:
|
||||||
return str(value)
|
return str(value)
|
||||||
|
|
||||||
|
|
||||||
def load_project_config(
|
|
||||||
path: Path, interpolate: bool = True, overrides: Dict[str, Any] = SimpleFrozenDict()
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Load the project.yml file from a directory and validate it. Also make
|
|
||||||
sure that all directories defined in the config exist.
|
|
||||||
|
|
||||||
path (Path): The path to the project directory.
|
|
||||||
interpolate (bool): Whether to substitute project variables.
|
|
||||||
overrides (Dict[str, Any]): Optional config overrides.
|
|
||||||
RETURNS (Dict[str, Any]): The loaded project.yml.
|
|
||||||
"""
|
|
||||||
config_path = path / PROJECT_FILE
|
|
||||||
if not config_path.exists():
|
|
||||||
msg.fail(f"Can't find {PROJECT_FILE}", config_path, exits=1)
|
|
||||||
invalid_err = f"Invalid {PROJECT_FILE}. Double-check that the YAML is correct."
|
|
||||||
try:
|
|
||||||
config = srsly.read_yaml(config_path)
|
|
||||||
except ValueError as e:
|
|
||||||
msg.fail(invalid_err, e, exits=1)
|
|
||||||
errors = validate(ProjectConfigSchema, config)
|
|
||||||
if errors:
|
|
||||||
msg.fail(invalid_err)
|
|
||||||
print("\n".join(errors))
|
|
||||||
sys.exit(1)
|
|
||||||
validate_project_version(config)
|
|
||||||
validate_project_commands(config)
|
|
||||||
if interpolate:
|
|
||||||
err = f"{PROJECT_FILE} validation error"
|
|
||||||
with show_validation_error(title=err, hint_fill=False):
|
|
||||||
config = substitute_project_variables(config, overrides)
|
|
||||||
# Make sure directories defined in config exist
|
|
||||||
for subdir in config.get("directories", []):
|
|
||||||
dir_path = path / subdir
|
|
||||||
if not dir_path.exists():
|
|
||||||
dir_path.mkdir(parents=True)
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
def substitute_project_variables(
|
|
||||||
config: Dict[str, Any],
|
|
||||||
overrides: Dict[str, Any] = SimpleFrozenDict(),
|
|
||||||
key: str = "vars",
|
|
||||||
env_key: str = "env",
|
|
||||||
) -> Dict[str, Any]:
|
|
||||||
"""Interpolate variables in the project file using the config system.
|
|
||||||
|
|
||||||
config (Dict[str, Any]): The project config.
|
|
||||||
overrides (Dict[str, Any]): Optional config overrides.
|
|
||||||
key (str): Key containing variables in project config.
|
|
||||||
env_key (str): Key containing environment variable mapping in project config.
|
|
||||||
RETURNS (Dict[str, Any]): The interpolated project config.
|
|
||||||
"""
|
|
||||||
config.setdefault(key, {})
|
|
||||||
config.setdefault(env_key, {})
|
|
||||||
# Substitute references to env vars with their values
|
|
||||||
for config_var, env_var in config[env_key].items():
|
|
||||||
config[env_key][config_var] = _parse_override(os.environ.get(env_var, ""))
|
|
||||||
# Need to put variables in the top scope again so we can have a top-level
|
|
||||||
# section "project" (otherwise, a list of commands in the top scope wouldn't)
|
|
||||||
# be allowed by Thinc's config system
|
|
||||||
cfg = Config({"project": config, key: config[key], env_key: config[env_key]})
|
|
||||||
cfg = Config().from_str(cfg.to_str(), overrides=overrides)
|
|
||||||
interpolated = cfg.interpolate()
|
|
||||||
return dict(interpolated["project"])
|
|
||||||
|
|
||||||
|
|
||||||
def validate_project_version(config: Dict[str, Any]) -> None:
|
|
||||||
"""If the project defines a compatible spaCy version range, chec that it's
|
|
||||||
compatible with the current version of spaCy.
|
|
||||||
|
|
||||||
config (Dict[str, Any]): The loaded config.
|
|
||||||
"""
|
|
||||||
spacy_version = config.get("spacy_version", None)
|
|
||||||
if spacy_version and not is_compatible_version(about.__version__, spacy_version):
|
|
||||||
err = (
|
|
||||||
f"The {PROJECT_FILE} specifies a spaCy version range ({spacy_version}) "
|
|
||||||
f"that's not compatible with the version of spaCy you're running "
|
|
||||||
f"({about.__version__}). You can edit version requirement in the "
|
|
||||||
f"{PROJECT_FILE} to load it, but the project may not run as expected."
|
|
||||||
)
|
|
||||||
msg.fail(err, exits=1)
|
|
||||||
|
|
||||||
|
|
||||||
def validate_project_commands(config: Dict[str, Any]) -> None:
|
|
||||||
"""Check that project commands and workflows are valid, don't contain
|
|
||||||
duplicates, don't clash and only refer to commands that exist.
|
|
||||||
|
|
||||||
config (Dict[str, Any]): The loaded config.
|
|
||||||
"""
|
|
||||||
command_names = [cmd["name"] for cmd in config.get("commands", [])]
|
|
||||||
workflows = config.get("workflows", {})
|
|
||||||
duplicates = set([cmd for cmd in command_names if command_names.count(cmd) > 1])
|
|
||||||
if duplicates:
|
|
||||||
err = f"Duplicate commands defined in {PROJECT_FILE}: {', '.join(duplicates)}"
|
|
||||||
msg.fail(err, exits=1)
|
|
||||||
for workflow_name, workflow_steps in workflows.items():
|
|
||||||
if workflow_name in command_names:
|
|
||||||
err = f"Can't use workflow name '{workflow_name}': name already exists as a command"
|
|
||||||
msg.fail(err, exits=1)
|
|
||||||
for step in workflow_steps:
|
|
||||||
if step not in command_names:
|
|
||||||
msg.fail(
|
|
||||||
f"Unknown command specified in workflow '{workflow_name}': {step}",
|
|
||||||
f"Workflows can only refer to commands defined in the 'commands' "
|
|
||||||
f"section of the {PROJECT_FILE}.",
|
|
||||||
exits=1,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_hash(data, exclude: Iterable[str] = tuple()) -> str:
|
|
||||||
"""Get the hash for a JSON-serializable object.
|
|
||||||
|
|
||||||
data: The data to hash.
|
|
||||||
exclude (Iterable[str]): Top-level keys to exclude if data is a dict.
|
|
||||||
RETURNS (str): The hash.
|
|
||||||
"""
|
|
||||||
if isinstance(data, dict):
|
|
||||||
data = {k: v for k, v in data.items() if k not in exclude}
|
|
||||||
data_str = srsly.json_dumps(data, sort_keys=True).encode("utf8")
|
|
||||||
return hashlib.md5(data_str).hexdigest()
|
|
||||||
|
|
||||||
|
|
||||||
def get_checksum(path: Union[Path, str]) -> str:
|
|
||||||
"""Get the checksum for a file or directory given its file path. If a
|
|
||||||
directory path is provided, this uses all files in that directory.
|
|
||||||
|
|
||||||
path (Union[Path, str]): The file or directory path.
|
|
||||||
RETURNS (str): The checksum.
|
|
||||||
"""
|
|
||||||
path = Path(path)
|
|
||||||
if not (path.is_file() or path.is_dir()):
|
|
||||||
msg.fail(f"Can't get checksum for {path}: not a file or directory", exits=1)
|
|
||||||
if path.is_file():
|
|
||||||
return hashlib.md5(Path(path).read_bytes()).hexdigest()
|
|
||||||
else:
|
|
||||||
# TODO: this is currently pretty slow
|
|
||||||
dir_checksum = hashlib.md5()
|
|
||||||
for sub_file in sorted(fp for fp in path.rglob("*") if fp.is_file()):
|
|
||||||
dir_checksum.update(sub_file.read_bytes())
|
|
||||||
return dir_checksum.hexdigest()
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def show_validation_error(
|
def show_validation_error(
|
||||||
file_path: Optional[Union[str, Path]] = None,
|
file_path: Optional[Union[str, Path]] = None,
|
||||||
|
@ -335,166 +189,10 @@ def import_code(code_path: Optional[Union[Path, str]]) -> None:
|
||||||
msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1)
|
msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1)
|
||||||
|
|
||||||
|
|
||||||
def upload_file(src: Path, dest: Union[str, "FluidPath"]) -> None:
|
|
||||||
"""Upload a file.
|
|
||||||
|
|
||||||
src (Path): The source path.
|
|
||||||
url (str): The destination URL to upload to.
|
|
||||||
"""
|
|
||||||
import smart_open
|
|
||||||
|
|
||||||
# Create parent directories for local paths
|
|
||||||
if isinstance(dest, Path):
|
|
||||||
if not dest.parent.exists():
|
|
||||||
dest.parent.mkdir(parents=True)
|
|
||||||
|
|
||||||
dest = str(dest)
|
|
||||||
with smart_open.open(dest, mode="wb") as output_file:
|
|
||||||
with src.open(mode="rb") as input_file:
|
|
||||||
output_file.write(input_file.read())
|
|
||||||
|
|
||||||
|
|
||||||
def download_file(
|
|
||||||
src: Union[str, "FluidPath"], dest: Path, *, force: bool = False
|
|
||||||
) -> None:
|
|
||||||
"""Download a file using smart_open.
|
|
||||||
|
|
||||||
url (str): The URL of the file.
|
|
||||||
dest (Path): The destination path.
|
|
||||||
force (bool): Whether to force download even if file exists.
|
|
||||||
If False, the download will be skipped.
|
|
||||||
"""
|
|
||||||
import smart_open
|
|
||||||
|
|
||||||
if dest.exists() and not force:
|
|
||||||
return None
|
|
||||||
src = str(src)
|
|
||||||
with smart_open.open(src, mode="rb", compression="disable") as input_file:
|
|
||||||
with dest.open(mode="wb") as output_file:
|
|
||||||
shutil.copyfileobj(input_file, output_file)
|
|
||||||
|
|
||||||
|
|
||||||
def ensure_pathy(path):
|
|
||||||
"""Temporary helper to prevent importing Pathy globally (which can cause
|
|
||||||
slow and annoying Google Cloud warning)."""
|
|
||||||
from pathy import Pathy # noqa: F811
|
|
||||||
|
|
||||||
return Pathy.fluid(path)
|
|
||||||
|
|
||||||
|
|
||||||
def git_checkout(
|
|
||||||
repo: str, subpath: str, dest: Path, *, branch: str = "master", sparse: bool = False
|
|
||||||
):
|
|
||||||
git_version = get_git_version()
|
|
||||||
if dest.exists():
|
|
||||||
msg.fail("Destination of checkout must not exist", exits=1)
|
|
||||||
if not dest.parent.exists():
|
|
||||||
msg.fail("Parent of destination of checkout must exist", exits=1)
|
|
||||||
if sparse and git_version >= (2, 22):
|
|
||||||
return git_sparse_checkout(repo, subpath, dest, branch)
|
|
||||||
elif sparse:
|
|
||||||
# Only show warnings if the user explicitly wants sparse checkout but
|
|
||||||
# the Git version doesn't support it
|
|
||||||
err_old = (
|
|
||||||
f"You're running an old version of Git (v{git_version[0]}.{git_version[1]}) "
|
|
||||||
f"that doesn't fully support sparse checkout yet."
|
|
||||||
)
|
|
||||||
err_unk = "You're running an unknown version of Git, so sparse checkout has been disabled."
|
|
||||||
msg.warn(
|
|
||||||
f"{err_unk if git_version == (0, 0) else err_old} "
|
|
||||||
f"This means that more files than necessary may be downloaded "
|
|
||||||
f"temporarily. To only download the files needed, make sure "
|
|
||||||
f"you're using Git v2.22 or above."
|
|
||||||
)
|
|
||||||
with make_tempdir() as tmp_dir:
|
|
||||||
cmd = f"git -C {tmp_dir} clone {repo} . -b {branch}"
|
|
||||||
run_command(cmd, capture=True)
|
|
||||||
# We need Path(name) to make sure we also support subdirectories
|
|
||||||
try:
|
|
||||||
source_path = tmp_dir / Path(subpath)
|
|
||||||
if not is_subpath_of(tmp_dir, source_path):
|
|
||||||
err = f"'{subpath}' is a path outside of the cloned repository."
|
|
||||||
msg.fail(err, repo, exits=1)
|
|
||||||
shutil.copytree(str(source_path), str(dest))
|
|
||||||
except FileNotFoundError:
|
|
||||||
err = f"Can't clone {subpath}. Make sure the directory exists in the repo (branch '{branch}')"
|
|
||||||
msg.fail(err, repo, exits=1)
|
|
||||||
|
|
||||||
|
|
||||||
def git_sparse_checkout(repo, subpath, dest, branch):
|
|
||||||
# We're using Git, partial clone and sparse checkout to
|
|
||||||
# only clone the files we need
|
|
||||||
# This ends up being RIDICULOUS. omg.
|
|
||||||
# So, every tutorial and SO post talks about 'sparse checkout'...But they
|
|
||||||
# go and *clone* the whole repo. Worthless. And cloning part of a repo
|
|
||||||
# turns out to be completely broken. The only way to specify a "path" is..
|
|
||||||
# a path *on the server*? The contents of which, specifies the paths. Wat.
|
|
||||||
# Obviously this is hopelessly broken and insecure, because you can query
|
|
||||||
# arbitrary paths on the server! So nobody enables this.
|
|
||||||
# What we have to do is disable *all* files. We could then just checkout
|
|
||||||
# the path, and it'd "work", but be hopelessly slow...Because it goes and
|
|
||||||
# transfers every missing object one-by-one. So the final piece is that we
|
|
||||||
# need to use some weird git internals to fetch the missings in bulk, and
|
|
||||||
# *that* we can do by path.
|
|
||||||
# We're using Git and sparse checkout to only clone the files we need
|
|
||||||
with make_tempdir() as tmp_dir:
|
|
||||||
# This is the "clone, but don't download anything" part.
|
|
||||||
cmd = (
|
|
||||||
f"git clone {repo} {tmp_dir} --no-checkout --depth 1 "
|
|
||||||
f"-b {branch} --filter=blob:none"
|
|
||||||
)
|
|
||||||
run_command(cmd)
|
|
||||||
# Now we need to find the missing filenames for the subpath we want.
|
|
||||||
# Looking for this 'rev-list' command in the git --help? Hah.
|
|
||||||
cmd = f"git -C {tmp_dir} rev-list --objects --all --missing=print -- {subpath}"
|
|
||||||
ret = run_command(cmd, capture=True)
|
|
||||||
git_repo = _http_to_git(repo)
|
|
||||||
# Now pass those missings into another bit of git internals
|
|
||||||
missings = " ".join([x[1:] for x in ret.stdout.split() if x.startswith("?")])
|
|
||||||
if not missings:
|
|
||||||
err = (
|
|
||||||
f"Could not find any relevant files for '{subpath}'. "
|
|
||||||
f"Did you specify a correct and complete path within repo '{repo}' "
|
|
||||||
f"and branch {branch}?"
|
|
||||||
)
|
|
||||||
msg.fail(err, exits=1)
|
|
||||||
cmd = f"git -C {tmp_dir} fetch-pack {git_repo} {missings}"
|
|
||||||
run_command(cmd, capture=True)
|
|
||||||
# And finally, we can checkout our subpath
|
|
||||||
cmd = f"git -C {tmp_dir} checkout {branch} {subpath}"
|
|
||||||
run_command(cmd, capture=True)
|
|
||||||
|
|
||||||
# Get a subdirectory of the cloned path, if appropriate
|
|
||||||
source_path = tmp_dir / Path(subpath)
|
|
||||||
if not is_subpath_of(tmp_dir, source_path):
|
|
||||||
err = f"'{subpath}' is a path outside of the cloned repository."
|
|
||||||
msg.fail(err, repo, exits=1)
|
|
||||||
|
|
||||||
shutil.move(str(source_path), str(dest))
|
|
||||||
|
|
||||||
|
|
||||||
def git_repo_branch_exists(repo: str, branch: str) -> bool:
|
|
||||||
"""Uses 'git ls-remote' to check if a repository and branch exists
|
|
||||||
|
|
||||||
repo (str): URL to get repo.
|
|
||||||
branch (str): Branch on repo to check.
|
|
||||||
RETURNS (bool): True if repo:branch exists.
|
|
||||||
"""
|
|
||||||
get_git_version()
|
|
||||||
cmd = f"git ls-remote {repo} {branch}"
|
|
||||||
# We might be tempted to use `--exit-code` with `git ls-remote`, but
|
|
||||||
# `run_command` handles the `returncode` for us, so we'll rely on
|
|
||||||
# the fact that stdout returns '' if the requested branch doesn't exist
|
|
||||||
ret = run_command(cmd, capture=True)
|
|
||||||
exists = ret.stdout != ""
|
|
||||||
return exists
|
|
||||||
|
|
||||||
|
|
||||||
def get_git_version(
|
def get_git_version(
|
||||||
error: str = "Could not run 'git'. Make sure it's installed and the executable is available.",
|
error: str = "Could not run 'git'. Make sure it's installed and the executable is available.",
|
||||||
) -> Tuple[int, int]:
|
) -> Tuple[int, int]:
|
||||||
"""Get the version of git and raise an error if calling 'git --version' fails.
|
"""Get the version of git and raise an error if calling 'git --version' fails.
|
||||||
|
|
||||||
error (str): The error message to show.
|
error (str): The error message to show.
|
||||||
RETURNS (Tuple[int, int]): The version as a (major, minor) tuple. Returns
|
RETURNS (Tuple[int, int]): The version as a (major, minor) tuple. Returns
|
||||||
(0, 0) if the version couldn't be determined.
|
(0, 0) if the version couldn't be determined.
|
||||||
|
@ -510,30 +208,6 @@ def get_git_version(
|
||||||
return int(version[0]), int(version[1])
|
return int(version[0]), int(version[1])
|
||||||
|
|
||||||
|
|
||||||
def _http_to_git(repo: str) -> str:
|
|
||||||
if repo.startswith("http://"):
|
|
||||||
repo = repo.replace(r"http://", r"https://")
|
|
||||||
if repo.startswith(r"https://"):
|
|
||||||
repo = repo.replace("https://", "git@").replace("/", ":", 1)
|
|
||||||
if repo.endswith("/"):
|
|
||||||
repo = repo[:-1]
|
|
||||||
repo = f"{repo}.git"
|
|
||||||
return repo
|
|
||||||
|
|
||||||
|
|
||||||
def is_subpath_of(parent, child):
|
|
||||||
"""
|
|
||||||
Check whether `child` is a path contained within `parent`.
|
|
||||||
"""
|
|
||||||
# Based on https://stackoverflow.com/a/37095733 .
|
|
||||||
|
|
||||||
# In Python 3.9, the `Path.is_relative_to()` method will supplant this, so
|
|
||||||
# we can stop using crusty old os.path functions.
|
|
||||||
parent_realpath = os.path.realpath(parent)
|
|
||||||
child_realpath = os.path.realpath(child)
|
|
||||||
return os.path.commonpath([parent_realpath, child_realpath]) == parent_realpath
|
|
||||||
|
|
||||||
|
|
||||||
@overload
|
@overload
|
||||||
def string_to_list(value: str, intify: Literal[False] = ...) -> List[str]:
|
def string_to_list(value: str, intify: Literal[False] = ...) -> List[str]:
|
||||||
...
|
...
|
||||||
|
|
|
@ -553,8 +553,6 @@ class Errors(metaclass=ErrorsWithCodes):
|
||||||
"vectors, not {mode} vectors.")
|
"vectors, not {mode} vectors.")
|
||||||
E851 = ("The 'textcat' component labels should only have values of 0 or 1, "
|
E851 = ("The 'textcat' component labels should only have values of 0 or 1, "
|
||||||
"but found value of '{val}'.")
|
"but found value of '{val}'.")
|
||||||
E852 = ("The tar file pulled from the remote attempted an unsafe path "
|
|
||||||
"traversal.")
|
|
||||||
E853 = ("Unsupported component factory name '{name}'. The character '.' is "
|
E853 = ("Unsupported component factory name '{name}'. The character '.' is "
|
||||||
"not permitted in factory names.")
|
"not permitted in factory names.")
|
||||||
E854 = ("Unable to set doc.ents. Check that the 'ents_filter' does not "
|
E854 = ("Unable to set doc.ents. Check that the 'ents_filter' does not "
|
||||||
|
|
|
@ -443,66 +443,6 @@ CONFIG_SCHEMAS = {
|
||||||
"initialize": ConfigSchemaInit,
|
"initialize": ConfigSchemaInit,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
# Project config Schema
|
|
||||||
|
|
||||||
|
|
||||||
class ProjectConfigAssetGitItem(BaseModel):
|
|
||||||
# fmt: off
|
|
||||||
repo: StrictStr = Field(..., title="URL of Git repo to download from")
|
|
||||||
path: StrictStr = Field(..., title="File path or sub-directory to download (used for sparse checkout)")
|
|
||||||
branch: StrictStr = Field("master", title="Branch to clone from")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
|
|
||||||
class ProjectConfigAssetURL(BaseModel):
|
|
||||||
# fmt: off
|
|
||||||
dest: StrictStr = Field(..., title="Destination of downloaded asset")
|
|
||||||
url: Optional[StrictStr] = Field(None, title="URL of asset")
|
|
||||||
checksum: Optional[str] = Field(None, title="MD5 hash of file", regex=r"([a-fA-F\d]{32})")
|
|
||||||
description: StrictStr = Field("", title="Description of asset")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
|
|
||||||
class ProjectConfigAssetGit(BaseModel):
|
|
||||||
# fmt: off
|
|
||||||
git: ProjectConfigAssetGitItem = Field(..., title="Git repo information")
|
|
||||||
checksum: Optional[str] = Field(None, title="MD5 hash of file", regex=r"([a-fA-F\d]{32})")
|
|
||||||
description: Optional[StrictStr] = Field(None, title="Description of asset")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
|
|
||||||
class ProjectConfigCommand(BaseModel):
|
|
||||||
# fmt: off
|
|
||||||
name: StrictStr = Field(..., title="Name of command")
|
|
||||||
help: Optional[StrictStr] = Field(None, title="Command description")
|
|
||||||
script: List[StrictStr] = Field([], title="List of CLI commands to run, in order")
|
|
||||||
deps: List[StrictStr] = Field([], title="File dependencies required by this command")
|
|
||||||
outputs: List[StrictStr] = Field([], title="Outputs produced by this command")
|
|
||||||
outputs_no_cache: List[StrictStr] = Field([], title="Outputs not tracked by DVC (DVC only)")
|
|
||||||
no_skip: bool = Field(False, title="Never skip this command, even if nothing changed")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
title = "A single named command specified in a project config"
|
|
||||||
extra = "forbid"
|
|
||||||
|
|
||||||
|
|
||||||
class ProjectConfigSchema(BaseModel):
|
|
||||||
# fmt: off
|
|
||||||
vars: Dict[StrictStr, Any] = Field({}, title="Optional variables to substitute in commands")
|
|
||||||
env: Dict[StrictStr, Any] = Field({}, title="Optional variable names to substitute in commands, mapped to environment variable names")
|
|
||||||
assets: List[Union[ProjectConfigAssetURL, ProjectConfigAssetGit]] = Field([], title="Data assets")
|
|
||||||
workflows: Dict[StrictStr, List[StrictStr]] = Field({}, title="Named workflows, mapped to list of project commands to run in order")
|
|
||||||
commands: List[ProjectConfigCommand] = Field([], title="Project command shortucts")
|
|
||||||
title: Optional[str] = Field(None, title="Project title")
|
|
||||||
spacy_version: Optional[StrictStr] = Field(None, title="spaCy version range that the project is compatible with")
|
|
||||||
# fmt: on
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
title = "Schema for project configuration file"
|
|
||||||
|
|
||||||
|
|
||||||
# Recommendations for init config workflows
|
# Recommendations for init config workflows
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,6 @@ import os
|
||||||
import math
|
import math
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from typing import Tuple, List, Dict, Any
|
from typing import Tuple, List, Dict, Any
|
||||||
import time
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import spacy
|
import spacy
|
||||||
|
@ -17,11 +16,8 @@ from weasel.cli.run import _check_requirements
|
||||||
|
|
||||||
from spacy import about
|
from spacy import about
|
||||||
from spacy.cli import info
|
from spacy.cli import info
|
||||||
from spacy.cli._util import is_subpath_of, load_project_config, walk_directory
|
from spacy.cli._util import walk_directory
|
||||||
from spacy.cli._util import parse_config_overrides, string_to_list
|
from spacy.cli._util import parse_config_overrides, string_to_list
|
||||||
from spacy.cli._util import substitute_project_variables
|
|
||||||
from spacy.cli._util import validate_project_commands
|
|
||||||
from spacy.cli._util import upload_file, download_file
|
|
||||||
from spacy.cli.debug_data import _compile_gold, _get_labels_from_model
|
from spacy.cli.debug_data import _compile_gold, _get_labels_from_model
|
||||||
from spacy.cli.debug_data import _get_labels_from_spancat
|
from spacy.cli.debug_data import _get_labels_from_spancat
|
||||||
from spacy.cli.debug_data import _get_distribution, _get_kl_divergence
|
from spacy.cli.debug_data import _get_distribution, _get_kl_divergence
|
||||||
|
@ -39,7 +35,7 @@ from spacy.cli.find_threshold import find_threshold
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
from spacy.lang.nl import Dutch
|
from spacy.lang.nl import Dutch
|
||||||
from spacy.language import Language
|
from spacy.language import Language
|
||||||
from spacy.schemas import ProjectConfigSchema, RecommendationSchema, validate
|
from spacy.schemas import RecommendationSchema
|
||||||
from spacy.tokens import Doc, DocBin
|
from spacy.tokens import Doc, DocBin
|
||||||
from spacy.tokens.span import Span
|
from spacy.tokens.span import Span
|
||||||
from spacy.training import Example, docs_to_json, offsets_to_biluo_tags
|
from spacy.training import Example, docs_to_json, offsets_to_biluo_tags
|
||||||
|
@ -125,25 +121,6 @@ def test_issue7055():
|
||||||
assert "model" in filled_cfg["components"]["ner"]
|
assert "model" in filled_cfg["components"]["ner"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.issue(11235)
|
|
||||||
def test_issue11235():
|
|
||||||
"""
|
|
||||||
Test that the cli handles interpolation in the directory names correctly when loading project config.
|
|
||||||
"""
|
|
||||||
lang_var = "en"
|
|
||||||
variables = {"lang": lang_var}
|
|
||||||
commands = [{"name": "x", "script": ["hello ${vars.lang}"]}]
|
|
||||||
directories = ["cfg", "${vars.lang}_model"]
|
|
||||||
project = {"commands": commands, "vars": variables, "directories": directories}
|
|
||||||
with make_tempdir() as d:
|
|
||||||
srsly.write_yaml(d / "project.yml", project)
|
|
||||||
cfg = load_project_config(d)
|
|
||||||
# Check that the directories are interpolated and created correctly
|
|
||||||
assert os.path.exists(d / "cfg")
|
|
||||||
assert os.path.exists(d / f"{lang_var}_model")
|
|
||||||
assert cfg["commands"][0]["script"][0] == f"hello {lang_var}"
|
|
||||||
|
|
||||||
|
|
||||||
def test_cli_info():
|
def test_cli_info():
|
||||||
nlp = Dutch()
|
nlp = Dutch()
|
||||||
nlp.add_pipe("textcat")
|
nlp.add_pipe("textcat")
|
||||||
|
@ -370,136 +347,6 @@ def test_cli_converters_conll_ner_to_docs():
|
||||||
assert ent.text in ["New York City", "London"]
|
assert ent.text in ["New York City", "London"]
|
||||||
|
|
||||||
|
|
||||||
def test_project_config_validation_full():
|
|
||||||
config = {
|
|
||||||
"vars": {"some_var": 20},
|
|
||||||
"directories": ["assets", "configs", "corpus", "scripts", "training"],
|
|
||||||
"assets": [
|
|
||||||
{
|
|
||||||
"dest": "x",
|
|
||||||
"extra": True,
|
|
||||||
"url": "https://example.com",
|
|
||||||
"checksum": "63373dd656daa1fd3043ce166a59474c",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dest": "y",
|
|
||||||
"git": {
|
|
||||||
"repo": "https://github.com/example/repo",
|
|
||||||
"branch": "develop",
|
|
||||||
"path": "y",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"dest": "z",
|
|
||||||
"extra": False,
|
|
||||||
"url": "https://example.com",
|
|
||||||
"checksum": "63373dd656daa1fd3043ce166a59474c",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"commands": [
|
|
||||||
{
|
|
||||||
"name": "train",
|
|
||||||
"help": "Train a model",
|
|
||||||
"script": ["python -m spacy train config.cfg -o training"],
|
|
||||||
"deps": ["config.cfg", "corpus/training.spcy"],
|
|
||||||
"outputs": ["training/model-best"],
|
|
||||||
},
|
|
||||||
{"name": "test", "script": ["pytest", "custom.py"], "no_skip": True},
|
|
||||||
],
|
|
||||||
"workflows": {"all": ["train", "test"], "train": ["train"]},
|
|
||||||
}
|
|
||||||
errors = validate(ProjectConfigSchema, config)
|
|
||||||
assert not errors
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"config",
|
|
||||||
[
|
|
||||||
{"commands": [{"name": "a"}, {"name": "a"}]},
|
|
||||||
{"commands": [{"name": "a"}], "workflows": {"a": []}},
|
|
||||||
{"commands": [{"name": "a"}], "workflows": {"b": ["c"]}},
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_project_config_validation1(config):
|
|
||||||
with pytest.raises(SystemExit):
|
|
||||||
validate_project_commands(config)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"config,n_errors",
|
|
||||||
[
|
|
||||||
({"commands": {"a": []}}, 1),
|
|
||||||
({"commands": [{"help": "..."}]}, 1),
|
|
||||||
({"commands": [{"name": "a", "extra": "b"}]}, 1),
|
|
||||||
({"commands": [{"extra": "b"}]}, 2),
|
|
||||||
({"commands": [{"name": "a", "deps": [123]}]}, 1),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_project_config_validation2(config, n_errors):
|
|
||||||
errors = validate(ProjectConfigSchema, config)
|
|
||||||
assert len(errors) == n_errors
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"int_value",
|
|
||||||
[10, pytest.param("10", marks=pytest.mark.xfail)],
|
|
||||||
)
|
|
||||||
def test_project_config_interpolation(int_value):
|
|
||||||
variables = {"a": int_value, "b": {"c": "foo", "d": True}}
|
|
||||||
commands = [
|
|
||||||
{"name": "x", "script": ["hello ${vars.a} ${vars.b.c}"]},
|
|
||||||
{"name": "y", "script": ["${vars.b.c} ${vars.b.d}"]},
|
|
||||||
]
|
|
||||||
project = {"commands": commands, "vars": variables}
|
|
||||||
with make_tempdir() as d:
|
|
||||||
srsly.write_yaml(d / "project.yml", project)
|
|
||||||
cfg = load_project_config(d)
|
|
||||||
assert type(cfg) == dict
|
|
||||||
assert type(cfg["commands"]) == list
|
|
||||||
assert cfg["commands"][0]["script"][0] == "hello 10 foo"
|
|
||||||
assert cfg["commands"][1]["script"][0] == "foo true"
|
|
||||||
commands = [{"name": "x", "script": ["hello ${vars.a} ${vars.b.e}"]}]
|
|
||||||
project = {"commands": commands, "vars": variables}
|
|
||||||
with pytest.raises(ConfigValidationError):
|
|
||||||
substitute_project_variables(project)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"greeting",
|
|
||||||
[342, "everyone", "tout le monde", pytest.param("42", marks=pytest.mark.xfail)],
|
|
||||||
)
|
|
||||||
def test_project_config_interpolation_override(greeting):
|
|
||||||
variables = {"a": "world"}
|
|
||||||
commands = [
|
|
||||||
{"name": "x", "script": ["hello ${vars.a}"]},
|
|
||||||
]
|
|
||||||
overrides = {"vars.a": greeting}
|
|
||||||
project = {"commands": commands, "vars": variables}
|
|
||||||
with make_tempdir() as d:
|
|
||||||
srsly.write_yaml(d / "project.yml", project)
|
|
||||||
cfg = load_project_config(d, overrides=overrides)
|
|
||||||
assert type(cfg) == dict
|
|
||||||
assert type(cfg["commands"]) == list
|
|
||||||
assert cfg["commands"][0]["script"][0] == f"hello {greeting}"
|
|
||||||
|
|
||||||
|
|
||||||
def test_project_config_interpolation_env():
|
|
||||||
variables = {"a": 10}
|
|
||||||
env_var = "SPACY_TEST_FOO"
|
|
||||||
env_vars = {"foo": env_var}
|
|
||||||
commands = [{"name": "x", "script": ["hello ${vars.a} ${env.foo}"]}]
|
|
||||||
project = {"commands": commands, "vars": variables, "env": env_vars}
|
|
||||||
with make_tempdir() as d:
|
|
||||||
srsly.write_yaml(d / "project.yml", project)
|
|
||||||
cfg = load_project_config(d)
|
|
||||||
assert cfg["commands"][0]["script"][0] == "hello 10 "
|
|
||||||
os.environ[env_var] = "123"
|
|
||||||
with make_tempdir() as d:
|
|
||||||
srsly.write_yaml(d / "project.yml", project)
|
|
||||||
cfg = load_project_config(d)
|
|
||||||
assert cfg["commands"][0]["script"][0] == "hello 10 123"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"args,expected",
|
"args,expected",
|
||||||
[
|
[
|
||||||
|
@ -709,21 +556,6 @@ def test_get_third_party_dependencies():
|
||||||
get_third_party_dependencies(nlp.config)
|
get_third_party_dependencies(nlp.config)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"parent,child,expected",
|
|
||||||
[
|
|
||||||
("/tmp", "/tmp", True),
|
|
||||||
("/tmp", "/", False),
|
|
||||||
("/tmp", "/tmp/subdir", True),
|
|
||||||
("/tmp", "/tmpdir", False),
|
|
||||||
("/tmp", "/tmp/subdir/..", True),
|
|
||||||
("/tmp", "/tmp/..", False),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_is_subpath_of(parent, child, expected):
|
|
||||||
assert is_subpath_of(parent, child) == expected
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.slow
|
@pytest.mark.slow
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"factory_name,pipe_name",
|
"factory_name,pipe_name",
|
||||||
|
@ -968,60 +800,6 @@ def test_applycli_user_data():
|
||||||
assert result[0]._.ext == val
|
assert result[0]._.ext == val
|
||||||
|
|
||||||
|
|
||||||
def test_local_remote_storage():
|
|
||||||
with make_tempdir() as d:
|
|
||||||
filename = "a.txt"
|
|
||||||
|
|
||||||
content_hashes = ("aaaa", "cccc", "bbbb")
|
|
||||||
for i, content_hash in enumerate(content_hashes):
|
|
||||||
# make sure that each subsequent file has a later timestamp
|
|
||||||
if i > 0:
|
|
||||||
time.sleep(1)
|
|
||||||
content = f"{content_hash} content"
|
|
||||||
loc_file = d / "root" / filename
|
|
||||||
if not loc_file.parent.exists():
|
|
||||||
loc_file.parent.mkdir(parents=True)
|
|
||||||
with loc_file.open(mode="w") as file_:
|
|
||||||
file_.write(content)
|
|
||||||
|
|
||||||
# push first version to remote storage
|
|
||||||
remote = RemoteStorage(d / "root", str(d / "remote"))
|
|
||||||
remote.push(filename, "aaaa", content_hash)
|
|
||||||
|
|
||||||
# retrieve with full hashes
|
|
||||||
loc_file.unlink()
|
|
||||||
remote.pull(filename, command_hash="aaaa", content_hash=content_hash)
|
|
||||||
with loc_file.open(mode="r") as file_:
|
|
||||||
assert file_.read() == content
|
|
||||||
|
|
||||||
# retrieve with command hash
|
|
||||||
loc_file.unlink()
|
|
||||||
remote.pull(filename, command_hash="aaaa")
|
|
||||||
with loc_file.open(mode="r") as file_:
|
|
||||||
assert file_.read() == content
|
|
||||||
|
|
||||||
# retrieve with content hash
|
|
||||||
loc_file.unlink()
|
|
||||||
remote.pull(filename, content_hash=content_hash)
|
|
||||||
with loc_file.open(mode="r") as file_:
|
|
||||||
assert file_.read() == content
|
|
||||||
|
|
||||||
# retrieve with no hashes
|
|
||||||
loc_file.unlink()
|
|
||||||
remote.pull(filename)
|
|
||||||
with loc_file.open(mode="r") as file_:
|
|
||||||
assert file_.read() == content
|
|
||||||
|
|
||||||
|
|
||||||
def test_local_remote_storage_pull_missing():
|
|
||||||
# pulling from a non-existent remote pulls nothing gracefully
|
|
||||||
with make_tempdir() as d:
|
|
||||||
filename = "a.txt"
|
|
||||||
remote = RemoteStorage(d / "root", str(d / "remote"))
|
|
||||||
assert remote.pull(filename, command_hash="aaaa") is None
|
|
||||||
assert remote.pull(filename) is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_cli_find_threshold(capsys):
|
def test_cli_find_threshold(capsys):
|
||||||
def make_examples(nlp: Language) -> List[Example]:
|
def make_examples(nlp: Language) -> List[Example]:
|
||||||
docs: List[Example] = []
|
docs: List[Example] = []
|
||||||
|
@ -1132,63 +910,6 @@ def test_cli_find_threshold(capsys):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"reqs,output",
|
|
||||||
[
|
|
||||||
[
|
|
||||||
"""
|
|
||||||
spacy
|
|
||||||
|
|
||||||
# comment
|
|
||||||
|
|
||||||
thinc""",
|
|
||||||
(False, False),
|
|
||||||
],
|
|
||||||
[
|
|
||||||
"""# comment
|
|
||||||
--some-flag
|
|
||||||
spacy""",
|
|
||||||
(False, False),
|
|
||||||
],
|
|
||||||
[
|
|
||||||
"""# comment
|
|
||||||
--some-flag
|
|
||||||
spacy; python_version >= '3.6'""",
|
|
||||||
(False, False),
|
|
||||||
],
|
|
||||||
[
|
|
||||||
"""# comment
|
|
||||||
spacyunknowndoesnotexist12345""",
|
|
||||||
(True, False),
|
|
||||||
],
|
|
||||||
],
|
|
||||||
)
|
|
||||||
def test_project_check_requirements(reqs, output):
|
|
||||||
import pkg_resources
|
|
||||||
|
|
||||||
# excessive guard against unlikely package name
|
|
||||||
try:
|
|
||||||
pkg_resources.require("spacyunknowndoesnotexist12345")
|
|
||||||
except pkg_resources.DistributionNotFound:
|
|
||||||
assert output == _check_requirements([req.strip() for req in reqs.split("\n")])
|
|
||||||
|
|
||||||
|
|
||||||
def test_upload_download_local_file():
|
|
||||||
with make_tempdir() as d1, make_tempdir() as d2:
|
|
||||||
filename = "f.txt"
|
|
||||||
content = "content"
|
|
||||||
local_file = d1 / filename
|
|
||||||
remote_file = d2 / filename
|
|
||||||
with local_file.open(mode="w") as file_:
|
|
||||||
file_.write(content)
|
|
||||||
upload_file(local_file, remote_file)
|
|
||||||
local_file.unlink()
|
|
||||||
download_file(remote_file, local_file)
|
|
||||||
with local_file.open(mode="r") as file_:
|
|
||||||
assert file_.read() == content
|
|
||||||
|
|
||||||
|
|
||||||
def test_walk_directory():
|
def test_walk_directory():
|
||||||
with make_tempdir() as d:
|
with make_tempdir() as d:
|
||||||
files = [
|
files = [
|
||||||
|
|
|
@ -78,7 +78,6 @@ logger.addHandler(logger_stream_handler)
|
||||||
|
|
||||||
class ENV_VARS:
|
class ENV_VARS:
|
||||||
CONFIG_OVERRIDES = "SPACY_CONFIG_OVERRIDES"
|
CONFIG_OVERRIDES = "SPACY_CONFIG_OVERRIDES"
|
||||||
PROJECT_USE_GIT_VERSION = "SPACY_PROJECT_USE_GIT_VERSION"
|
|
||||||
|
|
||||||
|
|
||||||
class registry(thinc.registry):
|
class registry(thinc.registry):
|
||||||
|
@ -951,23 +950,12 @@ def replace_model_node(model: Model, target: Model, replacement: Model) -> None:
|
||||||
|
|
||||||
def split_command(command: str) -> List[str]:
|
def split_command(command: str) -> List[str]:
|
||||||
"""Split a string command using shlex. Handles platform compatibility.
|
"""Split a string command using shlex. Handles platform compatibility.
|
||||||
|
|
||||||
command (str) : The command to split
|
command (str) : The command to split
|
||||||
RETURNS (List[str]): The split command.
|
RETURNS (List[str]): The split command.
|
||||||
"""
|
"""
|
||||||
return shlex.split(command, posix=not is_windows)
|
return shlex.split(command, posix=not is_windows)
|
||||||
|
|
||||||
|
|
||||||
def join_command(command: List[str]) -> str:
|
|
||||||
"""Join a command using shlex. shlex.join is only available for Python 3.8+,
|
|
||||||
so we're using a workaround here.
|
|
||||||
|
|
||||||
command (List[str]): The command to join.
|
|
||||||
RETURNS (str): The joined command
|
|
||||||
"""
|
|
||||||
return " ".join(shlex.quote(cmd) for cmd in command)
|
|
||||||
|
|
||||||
|
|
||||||
def run_command(
|
def run_command(
|
||||||
command: Union[str, List[str]],
|
command: Union[str, List[str]],
|
||||||
*,
|
*,
|
||||||
|
@ -976,7 +964,6 @@ def run_command(
|
||||||
) -> subprocess.CompletedProcess:
|
) -> subprocess.CompletedProcess:
|
||||||
"""Run a command on the command line as a subprocess. If the subprocess
|
"""Run a command on the command line as a subprocess. If the subprocess
|
||||||
returns a non-zero exit code, a system exit is performed.
|
returns a non-zero exit code, a system exit is performed.
|
||||||
|
|
||||||
command (str / List[str]): The command. If provided as a string, the
|
command (str / List[str]): The command. If provided as a string, the
|
||||||
string will be split using shlex.split.
|
string will be split using shlex.split.
|
||||||
stdin (Optional[Any]): stdin to read from or None.
|
stdin (Optional[Any]): stdin to read from or None.
|
||||||
|
@ -1027,7 +1014,6 @@ def run_command(
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def working_dir(path: Union[str, Path]) -> Iterator[Path]:
|
def working_dir(path: Union[str, Path]) -> Iterator[Path]:
|
||||||
"""Change current working directory and returns to previous on exit.
|
"""Change current working directory and returns to previous on exit.
|
||||||
|
|
||||||
path (str / Path): The directory to navigate to.
|
path (str / Path): The directory to navigate to.
|
||||||
YIELDS (Path): The absolute path to the current working directory. This
|
YIELDS (Path): The absolute path to the current working directory. This
|
||||||
should be used if the block needs to perform actions within the working
|
should be used if the block needs to perform actions within the working
|
||||||
|
@ -1046,7 +1032,6 @@ def working_dir(path: Union[str, Path]) -> Iterator[Path]:
|
||||||
def make_tempdir() -> Generator[Path, None, None]:
|
def make_tempdir() -> Generator[Path, None, None]:
|
||||||
"""Execute a block in a temporary directory and remove the directory and
|
"""Execute a block in a temporary directory and remove the directory and
|
||||||
its contents at the end of the with block.
|
its contents at the end of the with block.
|
||||||
|
|
||||||
YIELDS (Path): The path of the temp directory.
|
YIELDS (Path): The path of the temp directory.
|
||||||
"""
|
"""
|
||||||
d = Path(tempfile.mkdtemp())
|
d = Path(tempfile.mkdtemp())
|
||||||
|
@ -1064,15 +1049,6 @@ def make_tempdir() -> Generator[Path, None, None]:
|
||||||
warnings.warn(Warnings.W091.format(dir=d, msg=e))
|
warnings.warn(Warnings.W091.format(dir=d, msg=e))
|
||||||
|
|
||||||
|
|
||||||
def is_cwd(path: Union[Path, str]) -> bool:
|
|
||||||
"""Check whether a path is the current working directory.
|
|
||||||
|
|
||||||
path (Union[Path, str]): The directory path.
|
|
||||||
RETURNS (bool): Whether the path is the current working directory.
|
|
||||||
"""
|
|
||||||
return str(Path(path).resolve()).lower() == str(Path.cwd().resolve()).lower()
|
|
||||||
|
|
||||||
|
|
||||||
def is_in_jupyter() -> bool:
|
def is_in_jupyter() -> bool:
|
||||||
"""Check if user is running spaCy from a Jupyter notebook by detecting the
|
"""Check if user is running spaCy from a Jupyter notebook by detecting the
|
||||||
IPython kernel. Mainly used for the displaCy visualizer.
|
IPython kernel. Mainly used for the displaCy visualizer.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user