mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-03 20:00:21 +03:00
Merge remote-tracking branch 'explosion/feature/weasel' into feat/import-weasel
This commit is contained in:
commit
2d294e6eea
|
@ -3,5 +3,3 @@ __title__ = "spacy"
|
|||
__version__ = "3.5.0"
|
||||
__download_url__ = "https://github.com/explosion/spacy-models/releases/download"
|
||||
__compatibility__ = "https://raw.githubusercontent.com/explosion/spacy-models/master/compatibility.json"
|
||||
__projects__ = "https://github.com/explosion/projects"
|
||||
__projects_branch__ = "v3"
|
||||
|
|
|
@ -11,7 +11,7 @@ from click import NoSuchOption
|
|||
from click.parser import split_arg_string
|
||||
from typer.main import get_command
|
||||
from contextlib import contextmanager
|
||||
from thinc.api import Config, ConfigValidationError, require_gpu
|
||||
from thinc.api import ConfigValidationError, require_gpu
|
||||
from thinc.util import gpu_is_available
|
||||
from configparser import InterpolationError
|
||||
import os
|
||||
|
@ -19,10 +19,7 @@ import os
|
|||
from weasel import app as project_cli
|
||||
|
||||
from ..compat import Literal
|
||||
from ..schemas import ProjectConfigSchema, validate
|
||||
from ..util import import_file, run_command, make_tempdir, registry, logger
|
||||
from ..util import is_compatible_version, SimpleFrozenDict, ENV_VARS
|
||||
from .. import about
|
||||
from ..util import import_file, run_command, registry, logger, ENV_VARS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathy import FluidPath # noqa: F401
|
||||
|
@ -32,7 +29,6 @@ SDIST_SUFFIX = ".tar.gz"
|
|||
WHEEL_SUFFIX = "-py3-none-any.whl"
|
||||
|
||||
PROJECT_FILE = "project.yml"
|
||||
PROJECT_LOCK = "project.lock"
|
||||
COMMAND = "python -m spacy"
|
||||
NAME = "spacy"
|
||||
HELP = """spaCy Command-line Interface
|
||||
|
@ -136,148 +132,6 @@ def _parse_override(value: Any) -> Any:
|
|||
return str(value)
|
||||
|
||||
|
||||
def load_project_config(
|
||||
path: Path, interpolate: bool = True, overrides: Dict[str, Any] = SimpleFrozenDict()
|
||||
) -> Dict[str, Any]:
|
||||
"""Load the project.yml file from a directory and validate it. Also make
|
||||
sure that all directories defined in the config exist.
|
||||
|
||||
path (Path): The path to the project directory.
|
||||
interpolate (bool): Whether to substitute project variables.
|
||||
overrides (Dict[str, Any]): Optional config overrides.
|
||||
RETURNS (Dict[str, Any]): The loaded project.yml.
|
||||
"""
|
||||
config_path = path / PROJECT_FILE
|
||||
if not config_path.exists():
|
||||
msg.fail(f"Can't find {PROJECT_FILE}", config_path, exits=1)
|
||||
invalid_err = f"Invalid {PROJECT_FILE}. Double-check that the YAML is correct."
|
||||
try:
|
||||
config = srsly.read_yaml(config_path)
|
||||
except ValueError as e:
|
||||
msg.fail(invalid_err, e, exits=1)
|
||||
errors = validate(ProjectConfigSchema, config)
|
||||
if errors:
|
||||
msg.fail(invalid_err)
|
||||
print("\n".join(errors))
|
||||
sys.exit(1)
|
||||
validate_project_version(config)
|
||||
validate_project_commands(config)
|
||||
if interpolate:
|
||||
err = f"{PROJECT_FILE} validation error"
|
||||
with show_validation_error(title=err, hint_fill=False):
|
||||
config = substitute_project_variables(config, overrides)
|
||||
# Make sure directories defined in config exist
|
||||
for subdir in config.get("directories", []):
|
||||
dir_path = path / subdir
|
||||
if not dir_path.exists():
|
||||
dir_path.mkdir(parents=True)
|
||||
return config
|
||||
|
||||
|
||||
def substitute_project_variables(
|
||||
config: Dict[str, Any],
|
||||
overrides: Dict[str, Any] = SimpleFrozenDict(),
|
||||
key: str = "vars",
|
||||
env_key: str = "env",
|
||||
) -> Dict[str, Any]:
|
||||
"""Interpolate variables in the project file using the config system.
|
||||
|
||||
config (Dict[str, Any]): The project config.
|
||||
overrides (Dict[str, Any]): Optional config overrides.
|
||||
key (str): Key containing variables in project config.
|
||||
env_key (str): Key containing environment variable mapping in project config.
|
||||
RETURNS (Dict[str, Any]): The interpolated project config.
|
||||
"""
|
||||
config.setdefault(key, {})
|
||||
config.setdefault(env_key, {})
|
||||
# Substitute references to env vars with their values
|
||||
for config_var, env_var in config[env_key].items():
|
||||
config[env_key][config_var] = _parse_override(os.environ.get(env_var, ""))
|
||||
# Need to put variables in the top scope again so we can have a top-level
|
||||
# section "project" (otherwise, a list of commands in the top scope wouldn't)
|
||||
# be allowed by Thinc's config system
|
||||
cfg = Config({"project": config, key: config[key], env_key: config[env_key]})
|
||||
cfg = Config().from_str(cfg.to_str(), overrides=overrides)
|
||||
interpolated = cfg.interpolate()
|
||||
return dict(interpolated["project"])
|
||||
|
||||
|
||||
def validate_project_version(config: Dict[str, Any]) -> None:
|
||||
"""If the project defines a compatible spaCy version range, chec that it's
|
||||
compatible with the current version of spaCy.
|
||||
|
||||
config (Dict[str, Any]): The loaded config.
|
||||
"""
|
||||
spacy_version = config.get("spacy_version", None)
|
||||
if spacy_version and not is_compatible_version(about.__version__, spacy_version):
|
||||
err = (
|
||||
f"The {PROJECT_FILE} specifies a spaCy version range ({spacy_version}) "
|
||||
f"that's not compatible with the version of spaCy you're running "
|
||||
f"({about.__version__}). You can edit version requirement in the "
|
||||
f"{PROJECT_FILE} to load it, but the project may not run as expected."
|
||||
)
|
||||
msg.fail(err, exits=1)
|
||||
|
||||
|
||||
def validate_project_commands(config: Dict[str, Any]) -> None:
|
||||
"""Check that project commands and workflows are valid, don't contain
|
||||
duplicates, don't clash and only refer to commands that exist.
|
||||
|
||||
config (Dict[str, Any]): The loaded config.
|
||||
"""
|
||||
command_names = [cmd["name"] for cmd in config.get("commands", [])]
|
||||
workflows = config.get("workflows", {})
|
||||
duplicates = set([cmd for cmd in command_names if command_names.count(cmd) > 1])
|
||||
if duplicates:
|
||||
err = f"Duplicate commands defined in {PROJECT_FILE}: {', '.join(duplicates)}"
|
||||
msg.fail(err, exits=1)
|
||||
for workflow_name, workflow_steps in workflows.items():
|
||||
if workflow_name in command_names:
|
||||
err = f"Can't use workflow name '{workflow_name}': name already exists as a command"
|
||||
msg.fail(err, exits=1)
|
||||
for step in workflow_steps:
|
||||
if step not in command_names:
|
||||
msg.fail(
|
||||
f"Unknown command specified in workflow '{workflow_name}': {step}",
|
||||
f"Workflows can only refer to commands defined in the 'commands' "
|
||||
f"section of the {PROJECT_FILE}.",
|
||||
exits=1,
|
||||
)
|
||||
|
||||
|
||||
def get_hash(data, exclude: Iterable[str] = tuple()) -> str:
|
||||
"""Get the hash for a JSON-serializable object.
|
||||
|
||||
data: The data to hash.
|
||||
exclude (Iterable[str]): Top-level keys to exclude if data is a dict.
|
||||
RETURNS (str): The hash.
|
||||
"""
|
||||
if isinstance(data, dict):
|
||||
data = {k: v for k, v in data.items() if k not in exclude}
|
||||
data_str = srsly.json_dumps(data, sort_keys=True).encode("utf8")
|
||||
return hashlib.md5(data_str).hexdigest()
|
||||
|
||||
|
||||
def get_checksum(path: Union[Path, str]) -> str:
|
||||
"""Get the checksum for a file or directory given its file path. If a
|
||||
directory path is provided, this uses all files in that directory.
|
||||
|
||||
path (Union[Path, str]): The file or directory path.
|
||||
RETURNS (str): The checksum.
|
||||
"""
|
||||
path = Path(path)
|
||||
if not (path.is_file() or path.is_dir()):
|
||||
msg.fail(f"Can't get checksum for {path}: not a file or directory", exits=1)
|
||||
if path.is_file():
|
||||
return hashlib.md5(Path(path).read_bytes()).hexdigest()
|
||||
else:
|
||||
# TODO: this is currently pretty slow
|
||||
dir_checksum = hashlib.md5()
|
||||
for sub_file in sorted(fp for fp in path.rglob("*") if fp.is_file()):
|
||||
dir_checksum.update(sub_file.read_bytes())
|
||||
return dir_checksum.hexdigest()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def show_validation_error(
|
||||
file_path: Optional[Union[str, Path]] = None,
|
||||
|
@ -335,166 +189,10 @@ def import_code(code_path: Optional[Union[Path, str]]) -> None:
|
|||
msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1)
|
||||
|
||||
|
||||
def upload_file(src: Path, dest: Union[str, "FluidPath"]) -> None:
|
||||
"""Upload a file.
|
||||
|
||||
src (Path): The source path.
|
||||
url (str): The destination URL to upload to.
|
||||
"""
|
||||
import smart_open
|
||||
|
||||
# Create parent directories for local paths
|
||||
if isinstance(dest, Path):
|
||||
if not dest.parent.exists():
|
||||
dest.parent.mkdir(parents=True)
|
||||
|
||||
dest = str(dest)
|
||||
with smart_open.open(dest, mode="wb") as output_file:
|
||||
with src.open(mode="rb") as input_file:
|
||||
output_file.write(input_file.read())
|
||||
|
||||
|
||||
def download_file(
|
||||
src: Union[str, "FluidPath"], dest: Path, *, force: bool = False
|
||||
) -> None:
|
||||
"""Download a file using smart_open.
|
||||
|
||||
url (str): The URL of the file.
|
||||
dest (Path): The destination path.
|
||||
force (bool): Whether to force download even if file exists.
|
||||
If False, the download will be skipped.
|
||||
"""
|
||||
import smart_open
|
||||
|
||||
if dest.exists() and not force:
|
||||
return None
|
||||
src = str(src)
|
||||
with smart_open.open(src, mode="rb", compression="disable") as input_file:
|
||||
with dest.open(mode="wb") as output_file:
|
||||
shutil.copyfileobj(input_file, output_file)
|
||||
|
||||
|
||||
def ensure_pathy(path):
|
||||
"""Temporary helper to prevent importing Pathy globally (which can cause
|
||||
slow and annoying Google Cloud warning)."""
|
||||
from pathy import Pathy # noqa: F811
|
||||
|
||||
return Pathy.fluid(path)
|
||||
|
||||
|
||||
def git_checkout(
|
||||
repo: str, subpath: str, dest: Path, *, branch: str = "master", sparse: bool = False
|
||||
):
|
||||
git_version = get_git_version()
|
||||
if dest.exists():
|
||||
msg.fail("Destination of checkout must not exist", exits=1)
|
||||
if not dest.parent.exists():
|
||||
msg.fail("Parent of destination of checkout must exist", exits=1)
|
||||
if sparse and git_version >= (2, 22):
|
||||
return git_sparse_checkout(repo, subpath, dest, branch)
|
||||
elif sparse:
|
||||
# Only show warnings if the user explicitly wants sparse checkout but
|
||||
# the Git version doesn't support it
|
||||
err_old = (
|
||||
f"You're running an old version of Git (v{git_version[0]}.{git_version[1]}) "
|
||||
f"that doesn't fully support sparse checkout yet."
|
||||
)
|
||||
err_unk = "You're running an unknown version of Git, so sparse checkout has been disabled."
|
||||
msg.warn(
|
||||
f"{err_unk if git_version == (0, 0) else err_old} "
|
||||
f"This means that more files than necessary may be downloaded "
|
||||
f"temporarily. To only download the files needed, make sure "
|
||||
f"you're using Git v2.22 or above."
|
||||
)
|
||||
with make_tempdir() as tmp_dir:
|
||||
cmd = f"git -C {tmp_dir} clone {repo} . -b {branch}"
|
||||
run_command(cmd, capture=True)
|
||||
# We need Path(name) to make sure we also support subdirectories
|
||||
try:
|
||||
source_path = tmp_dir / Path(subpath)
|
||||
if not is_subpath_of(tmp_dir, source_path):
|
||||
err = f"'{subpath}' is a path outside of the cloned repository."
|
||||
msg.fail(err, repo, exits=1)
|
||||
shutil.copytree(str(source_path), str(dest))
|
||||
except FileNotFoundError:
|
||||
err = f"Can't clone {subpath}. Make sure the directory exists in the repo (branch '{branch}')"
|
||||
msg.fail(err, repo, exits=1)
|
||||
|
||||
|
||||
def git_sparse_checkout(repo, subpath, dest, branch):
|
||||
# We're using Git, partial clone and sparse checkout to
|
||||
# only clone the files we need
|
||||
# This ends up being RIDICULOUS. omg.
|
||||
# So, every tutorial and SO post talks about 'sparse checkout'...But they
|
||||
# go and *clone* the whole repo. Worthless. And cloning part of a repo
|
||||
# turns out to be completely broken. The only way to specify a "path" is..
|
||||
# a path *on the server*? The contents of which, specifies the paths. Wat.
|
||||
# Obviously this is hopelessly broken and insecure, because you can query
|
||||
# arbitrary paths on the server! So nobody enables this.
|
||||
# What we have to do is disable *all* files. We could then just checkout
|
||||
# the path, and it'd "work", but be hopelessly slow...Because it goes and
|
||||
# transfers every missing object one-by-one. So the final piece is that we
|
||||
# need to use some weird git internals to fetch the missings in bulk, and
|
||||
# *that* we can do by path.
|
||||
# We're using Git and sparse checkout to only clone the files we need
|
||||
with make_tempdir() as tmp_dir:
|
||||
# This is the "clone, but don't download anything" part.
|
||||
cmd = (
|
||||
f"git clone {repo} {tmp_dir} --no-checkout --depth 1 "
|
||||
f"-b {branch} --filter=blob:none"
|
||||
)
|
||||
run_command(cmd)
|
||||
# Now we need to find the missing filenames for the subpath we want.
|
||||
# Looking for this 'rev-list' command in the git --help? Hah.
|
||||
cmd = f"git -C {tmp_dir} rev-list --objects --all --missing=print -- {subpath}"
|
||||
ret = run_command(cmd, capture=True)
|
||||
git_repo = _http_to_git(repo)
|
||||
# Now pass those missings into another bit of git internals
|
||||
missings = " ".join([x[1:] for x in ret.stdout.split() if x.startswith("?")])
|
||||
if not missings:
|
||||
err = (
|
||||
f"Could not find any relevant files for '{subpath}'. "
|
||||
f"Did you specify a correct and complete path within repo '{repo}' "
|
||||
f"and branch {branch}?"
|
||||
)
|
||||
msg.fail(err, exits=1)
|
||||
cmd = f"git -C {tmp_dir} fetch-pack {git_repo} {missings}"
|
||||
run_command(cmd, capture=True)
|
||||
# And finally, we can checkout our subpath
|
||||
cmd = f"git -C {tmp_dir} checkout {branch} {subpath}"
|
||||
run_command(cmd, capture=True)
|
||||
|
||||
# Get a subdirectory of the cloned path, if appropriate
|
||||
source_path = tmp_dir / Path(subpath)
|
||||
if not is_subpath_of(tmp_dir, source_path):
|
||||
err = f"'{subpath}' is a path outside of the cloned repository."
|
||||
msg.fail(err, repo, exits=1)
|
||||
|
||||
shutil.move(str(source_path), str(dest))
|
||||
|
||||
|
||||
def git_repo_branch_exists(repo: str, branch: str) -> bool:
|
||||
"""Uses 'git ls-remote' to check if a repository and branch exists
|
||||
|
||||
repo (str): URL to get repo.
|
||||
branch (str): Branch on repo to check.
|
||||
RETURNS (bool): True if repo:branch exists.
|
||||
"""
|
||||
get_git_version()
|
||||
cmd = f"git ls-remote {repo} {branch}"
|
||||
# We might be tempted to use `--exit-code` with `git ls-remote`, but
|
||||
# `run_command` handles the `returncode` for us, so we'll rely on
|
||||
# the fact that stdout returns '' if the requested branch doesn't exist
|
||||
ret = run_command(cmd, capture=True)
|
||||
exists = ret.stdout != ""
|
||||
return exists
|
||||
|
||||
|
||||
def get_git_version(
|
||||
error: str = "Could not run 'git'. Make sure it's installed and the executable is available.",
|
||||
) -> Tuple[int, int]:
|
||||
"""Get the version of git and raise an error if calling 'git --version' fails.
|
||||
|
||||
error (str): The error message to show.
|
||||
RETURNS (Tuple[int, int]): The version as a (major, minor) tuple. Returns
|
||||
(0, 0) if the version couldn't be determined.
|
||||
|
@ -510,30 +208,6 @@ def get_git_version(
|
|||
return int(version[0]), int(version[1])
|
||||
|
||||
|
||||
def _http_to_git(repo: str) -> str:
|
||||
if repo.startswith("http://"):
|
||||
repo = repo.replace(r"http://", r"https://")
|
||||
if repo.startswith(r"https://"):
|
||||
repo = repo.replace("https://", "git@").replace("/", ":", 1)
|
||||
if repo.endswith("/"):
|
||||
repo = repo[:-1]
|
||||
repo = f"{repo}.git"
|
||||
return repo
|
||||
|
||||
|
||||
def is_subpath_of(parent, child):
|
||||
"""
|
||||
Check whether `child` is a path contained within `parent`.
|
||||
"""
|
||||
# Based on https://stackoverflow.com/a/37095733 .
|
||||
|
||||
# In Python 3.9, the `Path.is_relative_to()` method will supplant this, so
|
||||
# we can stop using crusty old os.path functions.
|
||||
parent_realpath = os.path.realpath(parent)
|
||||
child_realpath = os.path.realpath(child)
|
||||
return os.path.commonpath([parent_realpath, child_realpath]) == parent_realpath
|
||||
|
||||
|
||||
@overload
|
||||
def string_to_list(value: str, intify: Literal[False] = ...) -> List[str]:
|
||||
...
|
||||
|
|
|
@ -553,8 +553,6 @@ class Errors(metaclass=ErrorsWithCodes):
|
|||
"vectors, not {mode} vectors.")
|
||||
E851 = ("The 'textcat' component labels should only have values of 0 or 1, "
|
||||
"but found value of '{val}'.")
|
||||
E852 = ("The tar file pulled from the remote attempted an unsafe path "
|
||||
"traversal.")
|
||||
E853 = ("Unsupported component factory name '{name}'. The character '.' is "
|
||||
"not permitted in factory names.")
|
||||
E854 = ("Unable to set doc.ents. Check that the 'ents_filter' does not "
|
||||
|
|
|
@ -443,66 +443,6 @@ CONFIG_SCHEMAS = {
|
|||
"initialize": ConfigSchemaInit,
|
||||
}
|
||||
|
||||
|
||||
# Project config Schema
|
||||
|
||||
|
||||
class ProjectConfigAssetGitItem(BaseModel):
|
||||
# fmt: off
|
||||
repo: StrictStr = Field(..., title="URL of Git repo to download from")
|
||||
path: StrictStr = Field(..., title="File path or sub-directory to download (used for sparse checkout)")
|
||||
branch: StrictStr = Field("master", title="Branch to clone from")
|
||||
# fmt: on
|
||||
|
||||
|
||||
class ProjectConfigAssetURL(BaseModel):
|
||||
# fmt: off
|
||||
dest: StrictStr = Field(..., title="Destination of downloaded asset")
|
||||
url: Optional[StrictStr] = Field(None, title="URL of asset")
|
||||
checksum: Optional[str] = Field(None, title="MD5 hash of file", regex=r"([a-fA-F\d]{32})")
|
||||
description: StrictStr = Field("", title="Description of asset")
|
||||
# fmt: on
|
||||
|
||||
|
||||
class ProjectConfigAssetGit(BaseModel):
|
||||
# fmt: off
|
||||
git: ProjectConfigAssetGitItem = Field(..., title="Git repo information")
|
||||
checksum: Optional[str] = Field(None, title="MD5 hash of file", regex=r"([a-fA-F\d]{32})")
|
||||
description: Optional[StrictStr] = Field(None, title="Description of asset")
|
||||
# fmt: on
|
||||
|
||||
|
||||
class ProjectConfigCommand(BaseModel):
|
||||
# fmt: off
|
||||
name: StrictStr = Field(..., title="Name of command")
|
||||
help: Optional[StrictStr] = Field(None, title="Command description")
|
||||
script: List[StrictStr] = Field([], title="List of CLI commands to run, in order")
|
||||
deps: List[StrictStr] = Field([], title="File dependencies required by this command")
|
||||
outputs: List[StrictStr] = Field([], title="Outputs produced by this command")
|
||||
outputs_no_cache: List[StrictStr] = Field([], title="Outputs not tracked by DVC (DVC only)")
|
||||
no_skip: bool = Field(False, title="Never skip this command, even if nothing changed")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
title = "A single named command specified in a project config"
|
||||
extra = "forbid"
|
||||
|
||||
|
||||
class ProjectConfigSchema(BaseModel):
|
||||
# fmt: off
|
||||
vars: Dict[StrictStr, Any] = Field({}, title="Optional variables to substitute in commands")
|
||||
env: Dict[StrictStr, Any] = Field({}, title="Optional variable names to substitute in commands, mapped to environment variable names")
|
||||
assets: List[Union[ProjectConfigAssetURL, ProjectConfigAssetGit]] = Field([], title="Data assets")
|
||||
workflows: Dict[StrictStr, List[StrictStr]] = Field({}, title="Named workflows, mapped to list of project commands to run in order")
|
||||
commands: List[ProjectConfigCommand] = Field([], title="Project command shortucts")
|
||||
title: Optional[str] = Field(None, title="Project title")
|
||||
spacy_version: Optional[StrictStr] = Field(None, title="spaCy version range that the project is compatible with")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
title = "Schema for project configuration file"
|
||||
|
||||
|
||||
# Recommendations for init config workflows
|
||||
|
||||
|
||||
|
|
|
@ -2,7 +2,6 @@ import os
|
|||
import math
|
||||
from collections import Counter
|
||||
from typing import Tuple, List, Dict, Any
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import spacy
|
||||
|
@ -17,11 +16,8 @@ from weasel.cli.run import _check_requirements
|
|||
|
||||
from spacy import about
|
||||
from spacy.cli import info
|
||||
from spacy.cli._util import is_subpath_of, load_project_config, walk_directory
|
||||
from spacy.cli._util import walk_directory
|
||||
from spacy.cli._util import parse_config_overrides, string_to_list
|
||||
from spacy.cli._util import substitute_project_variables
|
||||
from spacy.cli._util import validate_project_commands
|
||||
from spacy.cli._util import upload_file, download_file
|
||||
from spacy.cli.debug_data import _compile_gold, _get_labels_from_model
|
||||
from spacy.cli.debug_data import _get_labels_from_spancat
|
||||
from spacy.cli.debug_data import _get_distribution, _get_kl_divergence
|
||||
|
@ -39,7 +35,7 @@ from spacy.cli.find_threshold import find_threshold
|
|||
from spacy.lang.en import English
|
||||
from spacy.lang.nl import Dutch
|
||||
from spacy.language import Language
|
||||
from spacy.schemas import ProjectConfigSchema, RecommendationSchema, validate
|
||||
from spacy.schemas import RecommendationSchema
|
||||
from spacy.tokens import Doc, DocBin
|
||||
from spacy.tokens.span import Span
|
||||
from spacy.training import Example, docs_to_json, offsets_to_biluo_tags
|
||||
|
@ -125,25 +121,6 @@ def test_issue7055():
|
|||
assert "model" in filled_cfg["components"]["ner"]
|
||||
|
||||
|
||||
@pytest.mark.issue(11235)
|
||||
def test_issue11235():
|
||||
"""
|
||||
Test that the cli handles interpolation in the directory names correctly when loading project config.
|
||||
"""
|
||||
lang_var = "en"
|
||||
variables = {"lang": lang_var}
|
||||
commands = [{"name": "x", "script": ["hello ${vars.lang}"]}]
|
||||
directories = ["cfg", "${vars.lang}_model"]
|
||||
project = {"commands": commands, "vars": variables, "directories": directories}
|
||||
with make_tempdir() as d:
|
||||
srsly.write_yaml(d / "project.yml", project)
|
||||
cfg = load_project_config(d)
|
||||
# Check that the directories are interpolated and created correctly
|
||||
assert os.path.exists(d / "cfg")
|
||||
assert os.path.exists(d / f"{lang_var}_model")
|
||||
assert cfg["commands"][0]["script"][0] == f"hello {lang_var}"
|
||||
|
||||
|
||||
def test_cli_info():
|
||||
nlp = Dutch()
|
||||
nlp.add_pipe("textcat")
|
||||
|
@ -370,136 +347,6 @@ def test_cli_converters_conll_ner_to_docs():
|
|||
assert ent.text in ["New York City", "London"]
|
||||
|
||||
|
||||
def test_project_config_validation_full():
|
||||
config = {
|
||||
"vars": {"some_var": 20},
|
||||
"directories": ["assets", "configs", "corpus", "scripts", "training"],
|
||||
"assets": [
|
||||
{
|
||||
"dest": "x",
|
||||
"extra": True,
|
||||
"url": "https://example.com",
|
||||
"checksum": "63373dd656daa1fd3043ce166a59474c",
|
||||
},
|
||||
{
|
||||
"dest": "y",
|
||||
"git": {
|
||||
"repo": "https://github.com/example/repo",
|
||||
"branch": "develop",
|
||||
"path": "y",
|
||||
},
|
||||
},
|
||||
{
|
||||
"dest": "z",
|
||||
"extra": False,
|
||||
"url": "https://example.com",
|
||||
"checksum": "63373dd656daa1fd3043ce166a59474c",
|
||||
},
|
||||
],
|
||||
"commands": [
|
||||
{
|
||||
"name": "train",
|
||||
"help": "Train a model",
|
||||
"script": ["python -m spacy train config.cfg -o training"],
|
||||
"deps": ["config.cfg", "corpus/training.spcy"],
|
||||
"outputs": ["training/model-best"],
|
||||
},
|
||||
{"name": "test", "script": ["pytest", "custom.py"], "no_skip": True},
|
||||
],
|
||||
"workflows": {"all": ["train", "test"], "train": ["train"]},
|
||||
}
|
||||
errors = validate(ProjectConfigSchema, config)
|
||||
assert not errors
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config",
|
||||
[
|
||||
{"commands": [{"name": "a"}, {"name": "a"}]},
|
||||
{"commands": [{"name": "a"}], "workflows": {"a": []}},
|
||||
{"commands": [{"name": "a"}], "workflows": {"b": ["c"]}},
|
||||
],
|
||||
)
|
||||
def test_project_config_validation1(config):
|
||||
with pytest.raises(SystemExit):
|
||||
validate_project_commands(config)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config,n_errors",
|
||||
[
|
||||
({"commands": {"a": []}}, 1),
|
||||
({"commands": [{"help": "..."}]}, 1),
|
||||
({"commands": [{"name": "a", "extra": "b"}]}, 1),
|
||||
({"commands": [{"extra": "b"}]}, 2),
|
||||
({"commands": [{"name": "a", "deps": [123]}]}, 1),
|
||||
],
|
||||
)
|
||||
def test_project_config_validation2(config, n_errors):
|
||||
errors = validate(ProjectConfigSchema, config)
|
||||
assert len(errors) == n_errors
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"int_value",
|
||||
[10, pytest.param("10", marks=pytest.mark.xfail)],
|
||||
)
|
||||
def test_project_config_interpolation(int_value):
|
||||
variables = {"a": int_value, "b": {"c": "foo", "d": True}}
|
||||
commands = [
|
||||
{"name": "x", "script": ["hello ${vars.a} ${vars.b.c}"]},
|
||||
{"name": "y", "script": ["${vars.b.c} ${vars.b.d}"]},
|
||||
]
|
||||
project = {"commands": commands, "vars": variables}
|
||||
with make_tempdir() as d:
|
||||
srsly.write_yaml(d / "project.yml", project)
|
||||
cfg = load_project_config(d)
|
||||
assert type(cfg) == dict
|
||||
assert type(cfg["commands"]) == list
|
||||
assert cfg["commands"][0]["script"][0] == "hello 10 foo"
|
||||
assert cfg["commands"][1]["script"][0] == "foo true"
|
||||
commands = [{"name": "x", "script": ["hello ${vars.a} ${vars.b.e}"]}]
|
||||
project = {"commands": commands, "vars": variables}
|
||||
with pytest.raises(ConfigValidationError):
|
||||
substitute_project_variables(project)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"greeting",
|
||||
[342, "everyone", "tout le monde", pytest.param("42", marks=pytest.mark.xfail)],
|
||||
)
|
||||
def test_project_config_interpolation_override(greeting):
|
||||
variables = {"a": "world"}
|
||||
commands = [
|
||||
{"name": "x", "script": ["hello ${vars.a}"]},
|
||||
]
|
||||
overrides = {"vars.a": greeting}
|
||||
project = {"commands": commands, "vars": variables}
|
||||
with make_tempdir() as d:
|
||||
srsly.write_yaml(d / "project.yml", project)
|
||||
cfg = load_project_config(d, overrides=overrides)
|
||||
assert type(cfg) == dict
|
||||
assert type(cfg["commands"]) == list
|
||||
assert cfg["commands"][0]["script"][0] == f"hello {greeting}"
|
||||
|
||||
|
||||
def test_project_config_interpolation_env():
|
||||
variables = {"a": 10}
|
||||
env_var = "SPACY_TEST_FOO"
|
||||
env_vars = {"foo": env_var}
|
||||
commands = [{"name": "x", "script": ["hello ${vars.a} ${env.foo}"]}]
|
||||
project = {"commands": commands, "vars": variables, "env": env_vars}
|
||||
with make_tempdir() as d:
|
||||
srsly.write_yaml(d / "project.yml", project)
|
||||
cfg = load_project_config(d)
|
||||
assert cfg["commands"][0]["script"][0] == "hello 10 "
|
||||
os.environ[env_var] = "123"
|
||||
with make_tempdir() as d:
|
||||
srsly.write_yaml(d / "project.yml", project)
|
||||
cfg = load_project_config(d)
|
||||
assert cfg["commands"][0]["script"][0] == "hello 10 123"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"args,expected",
|
||||
[
|
||||
|
@ -709,21 +556,6 @@ def test_get_third_party_dependencies():
|
|||
get_third_party_dependencies(nlp.config)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"parent,child,expected",
|
||||
[
|
||||
("/tmp", "/tmp", True),
|
||||
("/tmp", "/", False),
|
||||
("/tmp", "/tmp/subdir", True),
|
||||
("/tmp", "/tmpdir", False),
|
||||
("/tmp", "/tmp/subdir/..", True),
|
||||
("/tmp", "/tmp/..", False),
|
||||
],
|
||||
)
|
||||
def test_is_subpath_of(parent, child, expected):
|
||||
assert is_subpath_of(parent, child) == expected
|
||||
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize(
|
||||
"factory_name,pipe_name",
|
||||
|
@ -968,60 +800,6 @@ def test_applycli_user_data():
|
|||
assert result[0]._.ext == val
|
||||
|
||||
|
||||
def test_local_remote_storage():
|
||||
with make_tempdir() as d:
|
||||
filename = "a.txt"
|
||||
|
||||
content_hashes = ("aaaa", "cccc", "bbbb")
|
||||
for i, content_hash in enumerate(content_hashes):
|
||||
# make sure that each subsequent file has a later timestamp
|
||||
if i > 0:
|
||||
time.sleep(1)
|
||||
content = f"{content_hash} content"
|
||||
loc_file = d / "root" / filename
|
||||
if not loc_file.parent.exists():
|
||||
loc_file.parent.mkdir(parents=True)
|
||||
with loc_file.open(mode="w") as file_:
|
||||
file_.write(content)
|
||||
|
||||
# push first version to remote storage
|
||||
remote = RemoteStorage(d / "root", str(d / "remote"))
|
||||
remote.push(filename, "aaaa", content_hash)
|
||||
|
||||
# retrieve with full hashes
|
||||
loc_file.unlink()
|
||||
remote.pull(filename, command_hash="aaaa", content_hash=content_hash)
|
||||
with loc_file.open(mode="r") as file_:
|
||||
assert file_.read() == content
|
||||
|
||||
# retrieve with command hash
|
||||
loc_file.unlink()
|
||||
remote.pull(filename, command_hash="aaaa")
|
||||
with loc_file.open(mode="r") as file_:
|
||||
assert file_.read() == content
|
||||
|
||||
# retrieve with content hash
|
||||
loc_file.unlink()
|
||||
remote.pull(filename, content_hash=content_hash)
|
||||
with loc_file.open(mode="r") as file_:
|
||||
assert file_.read() == content
|
||||
|
||||
# retrieve with no hashes
|
||||
loc_file.unlink()
|
||||
remote.pull(filename)
|
||||
with loc_file.open(mode="r") as file_:
|
||||
assert file_.read() == content
|
||||
|
||||
|
||||
def test_local_remote_storage_pull_missing():
|
||||
# pulling from a non-existent remote pulls nothing gracefully
|
||||
with make_tempdir() as d:
|
||||
filename = "a.txt"
|
||||
remote = RemoteStorage(d / "root", str(d / "remote"))
|
||||
assert remote.pull(filename, command_hash="aaaa") is None
|
||||
assert remote.pull(filename) is None
|
||||
|
||||
|
||||
def test_cli_find_threshold(capsys):
|
||||
def make_examples(nlp: Language) -> List[Example]:
|
||||
docs: List[Example] = []
|
||||
|
@ -1132,63 +910,6 @@ def test_cli_find_threshold(capsys):
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
|
||||
@pytest.mark.parametrize(
|
||||
"reqs,output",
|
||||
[
|
||||
[
|
||||
"""
|
||||
spacy
|
||||
|
||||
# comment
|
||||
|
||||
thinc""",
|
||||
(False, False),
|
||||
],
|
||||
[
|
||||
"""# comment
|
||||
--some-flag
|
||||
spacy""",
|
||||
(False, False),
|
||||
],
|
||||
[
|
||||
"""# comment
|
||||
--some-flag
|
||||
spacy; python_version >= '3.6'""",
|
||||
(False, False),
|
||||
],
|
||||
[
|
||||
"""# comment
|
||||
spacyunknowndoesnotexist12345""",
|
||||
(True, False),
|
||||
],
|
||||
],
|
||||
)
|
||||
def test_project_check_requirements(reqs, output):
|
||||
import pkg_resources
|
||||
|
||||
# excessive guard against unlikely package name
|
||||
try:
|
||||
pkg_resources.require("spacyunknowndoesnotexist12345")
|
||||
except pkg_resources.DistributionNotFound:
|
||||
assert output == _check_requirements([req.strip() for req in reqs.split("\n")])
|
||||
|
||||
|
||||
def test_upload_download_local_file():
|
||||
with make_tempdir() as d1, make_tempdir() as d2:
|
||||
filename = "f.txt"
|
||||
content = "content"
|
||||
local_file = d1 / filename
|
||||
remote_file = d2 / filename
|
||||
with local_file.open(mode="w") as file_:
|
||||
file_.write(content)
|
||||
upload_file(local_file, remote_file)
|
||||
local_file.unlink()
|
||||
download_file(remote_file, local_file)
|
||||
with local_file.open(mode="r") as file_:
|
||||
assert file_.read() == content
|
||||
|
||||
|
||||
def test_walk_directory():
|
||||
with make_tempdir() as d:
|
||||
files = [
|
||||
|
|
|
@ -78,7 +78,6 @@ logger.addHandler(logger_stream_handler)
|
|||
|
||||
class ENV_VARS:
|
||||
CONFIG_OVERRIDES = "SPACY_CONFIG_OVERRIDES"
|
||||
PROJECT_USE_GIT_VERSION = "SPACY_PROJECT_USE_GIT_VERSION"
|
||||
|
||||
|
||||
class registry(thinc.registry):
|
||||
|
@ -951,23 +950,12 @@ def replace_model_node(model: Model, target: Model, replacement: Model) -> None:
|
|||
|
||||
def split_command(command: str) -> List[str]:
|
||||
"""Split a string command using shlex. Handles platform compatibility.
|
||||
|
||||
command (str) : The command to split
|
||||
RETURNS (List[str]): The split command.
|
||||
"""
|
||||
return shlex.split(command, posix=not is_windows)
|
||||
|
||||
|
||||
def join_command(command: List[str]) -> str:
|
||||
"""Join a command using shlex. shlex.join is only available for Python 3.8+,
|
||||
so we're using a workaround here.
|
||||
|
||||
command (List[str]): The command to join.
|
||||
RETURNS (str): The joined command
|
||||
"""
|
||||
return " ".join(shlex.quote(cmd) for cmd in command)
|
||||
|
||||
|
||||
def run_command(
|
||||
command: Union[str, List[str]],
|
||||
*,
|
||||
|
@ -976,7 +964,6 @@ def run_command(
|
|||
) -> subprocess.CompletedProcess:
|
||||
"""Run a command on the command line as a subprocess. If the subprocess
|
||||
returns a non-zero exit code, a system exit is performed.
|
||||
|
||||
command (str / List[str]): The command. If provided as a string, the
|
||||
string will be split using shlex.split.
|
||||
stdin (Optional[Any]): stdin to read from or None.
|
||||
|
@ -1027,7 +1014,6 @@ def run_command(
|
|||
@contextmanager
|
||||
def working_dir(path: Union[str, Path]) -> Iterator[Path]:
|
||||
"""Change current working directory and returns to previous on exit.
|
||||
|
||||
path (str / Path): The directory to navigate to.
|
||||
YIELDS (Path): The absolute path to the current working directory. This
|
||||
should be used if the block needs to perform actions within the working
|
||||
|
@ -1046,7 +1032,6 @@ def working_dir(path: Union[str, Path]) -> Iterator[Path]:
|
|||
def make_tempdir() -> Generator[Path, None, None]:
|
||||
"""Execute a block in a temporary directory and remove the directory and
|
||||
its contents at the end of the with block.
|
||||
|
||||
YIELDS (Path): The path of the temp directory.
|
||||
"""
|
||||
d = Path(tempfile.mkdtemp())
|
||||
|
@ -1064,15 +1049,6 @@ def make_tempdir() -> Generator[Path, None, None]:
|
|||
warnings.warn(Warnings.W091.format(dir=d, msg=e))
|
||||
|
||||
|
||||
def is_cwd(path: Union[Path, str]) -> bool:
|
||||
"""Check whether a path is the current working directory.
|
||||
|
||||
path (Union[Path, str]): The directory path.
|
||||
RETURNS (bool): Whether the path is the current working directory.
|
||||
"""
|
||||
return str(Path(path).resolve()).lower() == str(Path.cwd().resolve()).lower()
|
||||
|
||||
|
||||
def is_in_jupyter() -> bool:
|
||||
"""Check if user is running spaCy from a Jupyter notebook by detecting the
|
||||
IPython kernel. Mainly used for the displaCy visualizer.
|
||||
|
|
Loading…
Reference in New Issue
Block a user