spaCy/spacy/cli/validate.py
Daniël de Kok e2b70df012
Configure isort to use the Black profile, recursively isort the spacy module (#12721)
* Use isort with Black profile

* isort all the things

* Fix import cycles as a result of import sorting

* Add DOCBIN_ALL_ATTRS type definition

* Add isort to requirements

* Remove isort from build dependencies check

* Typo
2023-06-14 17:48:41 +02:00

124 lines
4.5 KiB
Python

import sys
import warnings
from pathlib import Path
from typing import Tuple
import requests
from wasabi import Printer, msg
from .. import about
from ..util import (
get_installed_models,
get_minor_version,
get_model_meta,
get_package_path,
get_package_version,
is_compatible_version,
)
from ._util import app
@app.command("validate")
def validate_cli():
"""
Validate the currently installed pipeline packages and spaCy version. Checks
if the installed packages are compatible and shows upgrade instructions if
available. Should be run after `pip install -U spacy`.
DOCS: https://spacy.io/api/cli#validate
"""
validate()
def validate() -> None:
model_pkgs, compat = get_model_pkgs()
spacy_version = get_minor_version(about.__version__)
current_compat = compat.get(spacy_version, {})
if not current_compat:
msg.warn(f"No compatible packages found for v{spacy_version} of spaCy")
incompat_models = {d["name"] for _, d in model_pkgs.items() if not d["compat"]}
na_models = [m for m in incompat_models if m not in current_compat]
update_models = [m for m in incompat_models if m in current_compat]
spacy_dir = Path(__file__).parent.parent
msg.divider(f"Installed pipeline packages (spaCy v{about.__version__})")
msg.info(f"spaCy installation: {spacy_dir}")
if model_pkgs:
header = ("NAME", "SPACY", "VERSION", "")
rows = []
for name, data in model_pkgs.items():
if data["compat"]:
comp = msg.text("", color="green", icon="good", no_print=True)
version = msg.text(data["version"], color="green", no_print=True)
else:
version = msg.text(data["version"], color="yellow", no_print=True)
comp = f"--> {current_compat.get(data['name'], ['n/a'])[0]}"
rows.append((data["name"], data["spacy"], version, comp))
msg.table(rows, header=header)
else:
msg.text("No pipeline packages found in your current environment.", exits=0)
if update_models:
msg.divider("Install updates")
msg.text("Use the following commands to update the packages:")
cmd = "python -m spacy download {}"
print("\n".join([cmd.format(pkg) for pkg in update_models]) + "\n")
if na_models:
msg.info(
f"The following packages are custom spaCy pipelines or not "
f"available for spaCy v{about.__version__}:",
", ".join(na_models),
)
if incompat_models:
sys.exit(1)
def get_model_pkgs(silent: bool = False) -> Tuple[dict, dict]:
msg = Printer(no_print=silent, pretty=not silent)
with msg.loading("Loading compatibility table..."):
r = requests.get(about.__compatibility__)
if r.status_code != 200:
msg.fail(
f"Server error ({r.status_code})",
"Couldn't fetch compatibility table.",
exits=1,
)
msg.good("Loaded compatibility table")
compat = r.json()["spacy"]
all_models = set()
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="\\[W09[45]")
installed_models = get_installed_models()
for spacy_v, models in dict(compat).items():
all_models.update(models.keys())
for model, model_vs in models.items():
compat[spacy_v][model] = [reformat_version(v) for v in model_vs]
pkgs = {}
for pkg_name in installed_models:
package = pkg_name.replace("-", "_")
version = get_package_version(pkg_name)
if package in compat:
is_compat = version in compat[package]
spacy_version = about.__version__
else:
model_path = get_package_path(package)
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="\\[W09[45]")
model_meta = get_model_meta(model_path)
spacy_version = model_meta.get("spacy_version", "n/a")
is_compat = is_compatible_version(about.__version__, spacy_version) # type: ignore[assignment]
pkgs[pkg_name] = {
"name": package,
"version": version,
"spacy": spacy_version,
"compat": is_compat,
}
return pkgs, compat
def reformat_version(version: str) -> str:
"""Hack to reformat old versions ending on '-alpha' to match pip format."""
if version.endswith("-alpha"):
return version.replace("-alpha", "a0")
return version.replace("-alpha", "a")