mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-27 10:26:35 +03:00
0518c36f04
The 'direct' option in 'spacy download' is supposed to only download from our model releases repository. However, users were able to pass in a relative path, allowing download from arbitrary repositories. This meant that a service that sourced strings from user input and which used the direct option would allow users to install arbitrary packages.
177 lines
6.3 KiB
Python
177 lines
6.3 KiB
Python
import sys
|
|
from typing import Optional, Sequence
|
|
from urllib.parse import urljoin
|
|
|
|
import requests
|
|
import typer
|
|
from wasabi import msg
|
|
|
|
from .. import about
|
|
from ..errors import OLD_MODEL_SHORTCUTS
|
|
from ..util import (
|
|
get_minor_version,
|
|
is_in_interactive,
|
|
is_in_jupyter,
|
|
is_package,
|
|
is_prerelease_version,
|
|
run_command,
|
|
)
|
|
from ._util import SDIST_SUFFIX, WHEEL_SUFFIX, Arg, Opt, app
|
|
|
|
|
|
@app.command(
|
|
"download",
|
|
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
|
|
)
|
|
def download_cli(
|
|
# fmt: off
|
|
ctx: typer.Context,
|
|
model: str = Arg(..., help="Name of pipeline package to download"),
|
|
direct: bool = Opt(False, "--direct", "-d", "-D", help="Force direct download of name + version"),
|
|
sdist: bool = Opt(False, "--sdist", "-S", help="Download sdist (.tar.gz) archive instead of pre-built binary wheel"),
|
|
# fmt: on
|
|
):
|
|
"""
|
|
Download compatible trained pipeline from the default download path using
|
|
pip. If --direct flag is set, the command expects the full package name with
|
|
version. For direct downloads, the compatibility check will be skipped. All
|
|
additional arguments provided to this command will be passed to `pip install`
|
|
on package installation.
|
|
|
|
DOCS: https://spacy.io/api/cli#download
|
|
AVAILABLE PACKAGES: https://spacy.io/models
|
|
"""
|
|
download(model, direct, sdist, *ctx.args)
|
|
|
|
|
|
def download(
|
|
model: str,
|
|
direct: bool = False,
|
|
sdist: bool = False,
|
|
*pip_args,
|
|
) -> None:
|
|
if (
|
|
not (is_package("spacy") or is_package("spacy-nightly"))
|
|
and "--no-deps" not in pip_args
|
|
):
|
|
msg.warn(
|
|
"Skipping pipeline package dependencies and setting `--no-deps`. "
|
|
"You don't seem to have the spaCy package itself installed "
|
|
"(maybe because you've built from source?), so installing the "
|
|
"package dependencies would cause spaCy to be downloaded, which "
|
|
"probably isn't what you want. If the pipeline package has other "
|
|
"dependencies, you'll have to install them manually."
|
|
)
|
|
pip_args = pip_args + ("--no-deps",)
|
|
if direct:
|
|
# Reject model names with '/', in order to prevent shenanigans.
|
|
if "/" in model:
|
|
msg.fail(
|
|
title="Model download rejected",
|
|
text=f"Cannot download model '{model}'. Models are expected to be file names, not URLs or fragments",
|
|
exits=True,
|
|
)
|
|
components = model.split("-")
|
|
model_name = "".join(components[:-1])
|
|
version = components[-1]
|
|
else:
|
|
model_name = model
|
|
if model in OLD_MODEL_SHORTCUTS:
|
|
msg.warn(
|
|
f"As of spaCy v3.0, shortcuts like '{model}' are deprecated. Please "
|
|
f"use the full pipeline package name '{OLD_MODEL_SHORTCUTS[model]}' instead."
|
|
)
|
|
model_name = OLD_MODEL_SHORTCUTS[model]
|
|
compatibility = get_compatibility()
|
|
version = get_version(model_name, compatibility)
|
|
|
|
filename = get_model_filename(model_name, version, sdist)
|
|
|
|
download_model(filename, pip_args)
|
|
msg.good(
|
|
"Download and installation successful",
|
|
f"You can now load the package via spacy.load('{model_name}')",
|
|
)
|
|
if is_in_jupyter():
|
|
reload_deps_msg = (
|
|
"If you are in a Jupyter or Colab notebook, you may need to "
|
|
"restart Python in order to load all the package's dependencies. "
|
|
"You can do this by selecting the 'Restart kernel' or 'Restart "
|
|
"runtime' option."
|
|
)
|
|
msg.warn(
|
|
"Restart to reload dependencies",
|
|
reload_deps_msg,
|
|
)
|
|
elif is_in_interactive():
|
|
reload_deps_msg = (
|
|
"If you are in an interactive Python session, you may need to "
|
|
"exit and restart Python to load all the package's dependencies. "
|
|
"You can exit with Ctrl-D (or Ctrl-Z and Enter on Windows)."
|
|
)
|
|
msg.warn(
|
|
"Restart to reload dependencies",
|
|
reload_deps_msg,
|
|
)
|
|
|
|
|
|
def get_model_filename(model_name: str, version: str, sdist: bool = False) -> str:
|
|
dl_tpl = "{m}-{v}/{m}-{v}{s}"
|
|
suffix = SDIST_SUFFIX if sdist else WHEEL_SUFFIX
|
|
filename = dl_tpl.format(m=model_name, v=version, s=suffix)
|
|
return filename
|
|
|
|
|
|
def get_compatibility() -> dict:
|
|
if is_prerelease_version(about.__version__):
|
|
version: Optional[str] = about.__version__
|
|
else:
|
|
version = get_minor_version(about.__version__)
|
|
r = requests.get(about.__compatibility__)
|
|
if r.status_code != 200:
|
|
msg.fail(
|
|
f"Server error ({r.status_code})",
|
|
f"Couldn't fetch compatibility table. Please find a package for your spaCy "
|
|
f"installation (v{about.__version__}), and download it manually. "
|
|
f"For more details, see the documentation: "
|
|
f"https://spacy.io/usage/models",
|
|
exits=1,
|
|
)
|
|
comp_table = r.json()
|
|
comp = comp_table["spacy"]
|
|
if version not in comp:
|
|
msg.fail(f"No compatible packages found for v{version} of spaCy", exits=1)
|
|
return comp[version]
|
|
|
|
|
|
def get_version(model: str, comp: dict) -> str:
|
|
if model not in comp:
|
|
msg.fail(
|
|
f"No compatible package found for '{model}' (spaCy v{about.__version__})",
|
|
exits=1,
|
|
)
|
|
return comp[model][0]
|
|
|
|
|
|
def get_latest_version(model: str) -> str:
|
|
comp = get_compatibility()
|
|
return get_version(model, comp)
|
|
|
|
|
|
def download_model(
|
|
filename: str, user_pip_args: Optional[Sequence[str]] = None
|
|
) -> None:
|
|
# Construct the download URL carefully. We need to make sure we don't
|
|
# allow relative paths or other shenanigans to trick us into download
|
|
# from outside our own repo.
|
|
base_url = about.__download_url__
|
|
# urljoin requires that the path ends with /, or the last path part will be dropped
|
|
if not base_url.endswith("/"):
|
|
base_url = about.__download_url__ + "/"
|
|
download_url = urljoin(base_url, filename)
|
|
if not download_url.startswith(about.__download_url__):
|
|
raise ValueError(f"Download from {filename} rejected. Was it a relative path?")
|
|
pip_args = list(user_pip_args) if user_pip_args is not None else []
|
|
cmd = [sys.executable, "-m", "pip", "install"] + pip_args + [download_url]
|
|
run_command(cmd)
|