mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 01:04:34 +03:00
Remove symlinks, data dir and related stuff
This commit is contained in:
parent
2ed49404e3
commit
09cbeaef27
|
@ -1,6 +1,7 @@
|
|||
from wasabi import msg
|
||||
|
||||
from .download import download # noqa: F401
|
||||
from .info import info # noqa: F401
|
||||
from .link import link # noqa: F401
|
||||
from .package import package # noqa: F401
|
||||
from .profile import profile # noqa: F401
|
||||
from .train import train # noqa: F401
|
||||
|
@ -11,3 +12,10 @@ from .evaluate import evaluate # noqa: F401
|
|||
from .convert import convert # noqa: F401
|
||||
from .init_model import init_model # noqa: F401
|
||||
from .validate import validate # noqa: F401
|
||||
|
||||
|
||||
def link(*args, **kwargs):
|
||||
msg.warn(
|
||||
"As of spaCy v3.0, model symlinks are deprecated. You can load models "
|
||||
"using their full names or from a directory path."
|
||||
)
|
||||
|
|
|
@ -4,8 +4,6 @@ import subprocess
|
|||
import sys
|
||||
from wasabi import msg
|
||||
|
||||
from .link import link
|
||||
from ..util import get_package_path
|
||||
from .. import about
|
||||
|
||||
|
||||
|
@ -15,9 +13,9 @@ def download(
|
|||
*pip_args: ("Additional arguments to be passed to `pip install` on model install"),
|
||||
):
|
||||
"""
|
||||
Download compatible model from default download path using pip. Model
|
||||
can be shortcut, model name or, if --direct flag is set, full model name
|
||||
with version. For direct downloads, the compatibility check will be skipped.
|
||||
Download compatible model from default download path using pip. If --direct
|
||||
flag is set, the command expects the full model name with version.
|
||||
For direct downloads, the compatibility check will be skipped.
|
||||
"""
|
||||
if not require_package("spacy") and "--no-deps" not in pip_args:
|
||||
msg.warn(
|
||||
|
@ -47,28 +45,6 @@ def download(
|
|||
"Download and installation successful",
|
||||
f"You can now load the model via spacy.load('{model_name}')",
|
||||
)
|
||||
# Only create symlink if the model is installed via a shortcut like 'en'.
|
||||
# There's no real advantage over an additional symlink for en_core_web_sm
|
||||
# and if anything, it's more error prone and causes more confusion.
|
||||
if model in shortcuts:
|
||||
try:
|
||||
# Get package path here because link uses
|
||||
# pip.get_installed_distributions() to check if model is a
|
||||
# package, which fails if model was just installed via
|
||||
# subprocess
|
||||
package_path = get_package_path(model_name)
|
||||
link(model_name, model, force=True, model_path=package_path)
|
||||
except: # noqa: E722
|
||||
# Dirty, but since spacy.download and the auto-linking is
|
||||
# mostly a convenience wrapper, it's best to show a success
|
||||
# message and loading instructions, even if linking fails.
|
||||
msg.warn(
|
||||
"Download successful but linking failed",
|
||||
f"Creating a shortcut link for '{model}' didn't work (maybe you "
|
||||
f"don't have admin permissions?), but you can still load "
|
||||
f"the model via its full package name: "
|
||||
f"nlp = spacy.load('{model_name}')",
|
||||
)
|
||||
# If a model is downloaded and then loaded within the same process, our
|
||||
# is_package check currently fails, because pkg_resources.working_set
|
||||
# is not refreshed automatically (see #3923). We're trying to work
|
||||
|
@ -114,8 +90,7 @@ def get_version(model, comp):
|
|||
model = model.rsplit(".dev", 1)[0]
|
||||
if model not in comp:
|
||||
msg.fail(
|
||||
f"No compatible model found for '{model}' "
|
||||
f"(spaCy v{about.__version__}).",
|
||||
f"No compatible model found for '{model}' (spaCy v{about.__version__})",
|
||||
exits=1,
|
||||
)
|
||||
return comp[model][0]
|
||||
|
|
|
@ -3,25 +3,26 @@ from pathlib import Path
|
|||
from wasabi import msg
|
||||
import srsly
|
||||
|
||||
from .validate import get_model_pkgs
|
||||
from .. import util
|
||||
from .. import about
|
||||
|
||||
|
||||
def info(
|
||||
model: ("Optional shortcut link of model", "positional", None, str) = None,
|
||||
model: ("Optional model name", "positional", None, str) = None,
|
||||
markdown: ("Generate Markdown for GitHub issues", "flag", "md", str) = False,
|
||||
silent: ("Don't print anything (just return)", "flag", "s") = False,
|
||||
):
|
||||
"""
|
||||
Print info about spaCy installation. If a model shortcut link is
|
||||
speficied as an argument, print model information. Flag --markdown
|
||||
prints details in Markdown for easy copy-pasting to GitHub issues.
|
||||
Print info about spaCy installation. If a model is speficied as an argument,
|
||||
print model information. Flag --markdown prints details in Markdown for easy
|
||||
copy-pasting to GitHub issues.
|
||||
"""
|
||||
if model:
|
||||
if util.is_package(model):
|
||||
model_path = util.get_package_path(model)
|
||||
else:
|
||||
model_path = util.get_data_path() / model
|
||||
model_path = model
|
||||
meta_path = model_path / "meta.json"
|
||||
if not meta_path.is_file():
|
||||
msg.fail("Can't find model meta.json", meta_path, exits=1)
|
||||
|
@ -41,12 +42,13 @@ def info(
|
|||
else:
|
||||
msg.table(model_meta, title=title)
|
||||
return meta
|
||||
all_models, _ = get_model_pkgs()
|
||||
data = {
|
||||
"spaCy version": about.__version__,
|
||||
"Location": str(Path(__file__).parent.parent),
|
||||
"Platform": platform.platform(),
|
||||
"Python version": platform.python_version(),
|
||||
"Models": list_models(),
|
||||
"Models": ", ".join(model["name"] for model in all_models.values()),
|
||||
}
|
||||
if not silent:
|
||||
title = "Info about spaCy"
|
||||
|
@ -57,19 +59,6 @@ def info(
|
|||
return data
|
||||
|
||||
|
||||
def list_models():
|
||||
def exclude_dir(dir_name):
|
||||
# exclude common cache directories and hidden directories
|
||||
exclude = ("cache", "pycache", "__pycache__")
|
||||
return dir_name in exclude or dir_name.startswith(".")
|
||||
|
||||
data_path = util.get_data_path()
|
||||
if data_path:
|
||||
models = [f.parts[-1] for f in data_path.iterdir() if f.is_dir()]
|
||||
return ", ".join([m for m in models if not exclude_dir(m)])
|
||||
return "-"
|
||||
|
||||
|
||||
def print_markdown(data, title=None):
|
||||
"""Print data in GitHub-flavoured Markdown format for issues etc.
|
||||
|
||||
|
|
|
@ -1,73 +0,0 @@
|
|||
from pathlib import Path
|
||||
from wasabi import msg
|
||||
|
||||
from ..compat import symlink_to
|
||||
from .. import util
|
||||
|
||||
|
||||
def link(
|
||||
origin: ("package name or local path to model", "positional", None, str),
|
||||
link_name: ("name of shortuct link to create", "positional", None, str),
|
||||
force: ("force overwriting of existing link", "flag", "f", bool) = False,
|
||||
model_path=None,
|
||||
):
|
||||
"""
|
||||
Create a symlink for models within the spacy/data directory. Accepts
|
||||
either the name of a pip package, or the local path to the model data
|
||||
directory. Linking models allows loading them via spacy.load(link_name).
|
||||
"""
|
||||
if util.is_package(origin):
|
||||
model_path = util.get_package_path(origin)
|
||||
else:
|
||||
model_path = Path(origin) if model_path is None else Path(model_path)
|
||||
if not model_path.exists():
|
||||
msg.fail(
|
||||
"Can't locate model data",
|
||||
f"The data should be located in {model_path}",
|
||||
exits=1,
|
||||
)
|
||||
data_path = util.get_data_path()
|
||||
if not data_path or not data_path.exists():
|
||||
spacy_loc = Path(__file__).parent.parent
|
||||
msg.fail(
|
||||
f"Can't find the spaCy data path to create model symlink",
|
||||
f"Make sure a directory `/data` exists within your spaCy "
|
||||
f"installation and try again. The data directory should be located "
|
||||
f"here: {spacy_loc}",
|
||||
exits=1,
|
||||
)
|
||||
link_path = util.get_data_path() / link_name
|
||||
if link_path.is_symlink() and not force:
|
||||
msg.fail(
|
||||
f"Link '{link_name}' already exists",
|
||||
"To overwrite an existing link, use the --force flag",
|
||||
exits=1,
|
||||
)
|
||||
elif link_path.is_symlink(): # does a symlink exist?
|
||||
# NB: It's important to check for is_symlink here and not for exists,
|
||||
# because invalid/outdated symlinks would return False otherwise.
|
||||
link_path.unlink()
|
||||
elif link_path.exists(): # does it exist otherwise?
|
||||
# NB: Check this last because valid symlinks also "exist".
|
||||
msg.fail(
|
||||
f"Can't overwrite symlink '{link_name}'",
|
||||
"This can happen if your data directory contains a directory or "
|
||||
"file of the same name.",
|
||||
exits=1,
|
||||
)
|
||||
details = f"{model_path} --> {link_path}"
|
||||
try:
|
||||
symlink_to(link_path, model_path)
|
||||
except: # noqa: E722
|
||||
# This is quite dirty, but just making sure other errors are caught.
|
||||
msg.fail(
|
||||
f"Couldn't link model to '{link_name}'",
|
||||
"Creating a symlink in spacy/data failed. Make sure you have the "
|
||||
"required permissions and try re-running the command as admin, or "
|
||||
"use a virtualenv. You can still import the model as a module and "
|
||||
"call its load() method, or create the symlink manually.",
|
||||
)
|
||||
msg.text(details)
|
||||
raise
|
||||
msg.good("Linking successful", details)
|
||||
msg.text(f"You can now load the model via spacy.load('{link_name}')")
|
|
@ -1,10 +1,8 @@
|
|||
from pathlib import Path
|
||||
import sys
|
||||
import requests
|
||||
import srsly
|
||||
from wasabi import msg
|
||||
|
||||
from ..util import get_data_path
|
||||
from .. import about
|
||||
|
||||
|
||||
|
@ -13,6 +11,50 @@ def validate():
|
|||
Validate that the currently installed version of spaCy is compatible
|
||||
with the installed models. Should be run after `pip install -U spacy`.
|
||||
"""
|
||||
model_pkgs, compat = get_model_pkgs()
|
||||
spacy_version = about.__version__.rsplit(".dev", 1)[0]
|
||||
current_compat = compat.get(spacy_version, {})
|
||||
if not current_compat:
|
||||
msg.warn(f"No compatible models found for v{spacy_version} of spaCy")
|
||||
incompat_models = {d["name"] for _, d in model_pkgs.items() if not d["compat"]}
|
||||
na_models = [m for m in incompat_models if m not in current_compat]
|
||||
update_models = [m for m in incompat_models if m in current_compat]
|
||||
spacy_dir = Path(__file__).parent.parent
|
||||
|
||||
msg.divider(f"Installed models (spaCy v{about.__version__})")
|
||||
msg.info(f"spaCy installation: {spacy_dir}")
|
||||
|
||||
if model_pkgs:
|
||||
header = ("NAME", "VERSION", "")
|
||||
rows = []
|
||||
for name, data in model_pkgs.items():
|
||||
if data["compat"]:
|
||||
comp = msg.text("", color="green", icon="good", no_print=True)
|
||||
version = msg.text(data["version"], color="green", no_print=True)
|
||||
else:
|
||||
version = msg.text(data["version"], color="red", no_print=True)
|
||||
comp = f"--> {compat.get(data['name'], ['n/a'])[0]}"
|
||||
rows.append((data["name"], version, comp))
|
||||
msg.table(rows, header=header)
|
||||
else:
|
||||
msg.text("No models found in your current environment.", exits=0)
|
||||
if update_models:
|
||||
msg.divider("Install updates")
|
||||
msg.text("Use the following commands to update the model packages:")
|
||||
cmd = "python -m spacy download {}"
|
||||
print("\n".join([cmd.format(pkg) for pkg in update_models]) + "\n")
|
||||
if na_models:
|
||||
msg.warn(
|
||||
f"The following models are not available for spaCy v{about.__version__}:",
|
||||
", ".join(na_models),
|
||||
)
|
||||
if incompat_models:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def get_model_pkgs():
|
||||
import pkg_resources
|
||||
|
||||
with msg.loading("Loading compatibility table..."):
|
||||
r = requests.get(about.__compatibility__)
|
||||
if r.status_code != 200:
|
||||
|
@ -23,88 +65,11 @@ def validate():
|
|||
)
|
||||
msg.good("Loaded compatibility table")
|
||||
compat = r.json()["spacy"]
|
||||
version = about.__version__
|
||||
version = version.rsplit(".dev", 1)[0]
|
||||
current_compat = compat.get(version)
|
||||
if not current_compat:
|
||||
msg.fail(
|
||||
f"Can't find spaCy v{version} in compatibility table",
|
||||
about.__compatibility__,
|
||||
exits=1,
|
||||
)
|
||||
all_models = set()
|
||||
for spacy_v, models in dict(compat).items():
|
||||
all_models.update(models.keys())
|
||||
for model, model_vs in models.items():
|
||||
compat[spacy_v][model] = [reformat_version(v) for v in model_vs]
|
||||
model_links = get_model_links(current_compat)
|
||||
model_pkgs = get_model_pkgs(current_compat, all_models)
|
||||
incompat_links = {l for l, d in model_links.items() if not d["compat"]}
|
||||
incompat_models = {d["name"] for _, d in model_pkgs.items() if not d["compat"]}
|
||||
incompat_models.update(
|
||||
[d["name"] for _, d in model_links.items() if not d["compat"]]
|
||||
)
|
||||
na_models = [m for m in incompat_models if m not in current_compat]
|
||||
update_models = [m for m in incompat_models if m in current_compat]
|
||||
spacy_dir = Path(__file__).parent.parent
|
||||
|
||||
msg.divider(f"Installed models (spaCy v{about.__version__})")
|
||||
msg.info(f"spaCy installation: {spacy_dir}")
|
||||
|
||||
if model_links or model_pkgs:
|
||||
header = ("TYPE", "NAME", "MODEL", "VERSION", "")
|
||||
rows = []
|
||||
for name, data in model_pkgs.items():
|
||||
rows.append(get_model_row(current_compat, name, data, msg))
|
||||
for name, data in model_links.items():
|
||||
rows.append(get_model_row(current_compat, name, data, msg, "link"))
|
||||
msg.table(rows, header=header)
|
||||
else:
|
||||
msg.text("No models found in your current environment.", exits=0)
|
||||
if update_models:
|
||||
msg.divider("Install updates")
|
||||
msg.text("Use the following commands to update the model packages:")
|
||||
cmd = "python -m spacy download {}"
|
||||
print("\n".join([cmd.format(pkg) for pkg in update_models]) + "\n")
|
||||
if na_models:
|
||||
msg.text(
|
||||
f"The following models are not available for spaCy "
|
||||
f"v{about.__version__}: {', '.join(na_models)}"
|
||||
)
|
||||
if incompat_links:
|
||||
msg.text(
|
||||
f"You may also want to overwrite the incompatible links using the "
|
||||
f"`python -m spacy link` command with `--force`, or remove them "
|
||||
f"from the data directory. "
|
||||
f"Data path: {get_data_path()}"
|
||||
)
|
||||
if incompat_models or incompat_links:
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def get_model_links(compat):
|
||||
links = {}
|
||||
data_path = get_data_path()
|
||||
if data_path:
|
||||
models = [p for p in data_path.iterdir() if is_model_path(p)]
|
||||
for model in models:
|
||||
meta_path = Path(model) / "meta.json"
|
||||
if not meta_path.exists():
|
||||
continue
|
||||
meta = srsly.read_json(meta_path)
|
||||
link = model.parts[-1]
|
||||
name = meta["lang"] + "_" + meta["name"]
|
||||
links[link] = {
|
||||
"name": name,
|
||||
"version": meta["version"],
|
||||
"compat": is_compat(compat, name, meta["version"]),
|
||||
}
|
||||
return links
|
||||
|
||||
|
||||
def get_model_pkgs(compat, all_models):
|
||||
import pkg_resources
|
||||
|
||||
pkgs = {}
|
||||
for pkg_name, pkg_data in pkg_resources.working_set.by_key.items():
|
||||
package = pkg_name.replace("-", "_")
|
||||
|
@ -113,29 +78,9 @@ def get_model_pkgs(compat, all_models):
|
|||
pkgs[pkg_name] = {
|
||||
"name": package,
|
||||
"version": version,
|
||||
"compat": is_compat(compat, package, version),
|
||||
"compat": package in compat and version in compat[package],
|
||||
}
|
||||
return pkgs
|
||||
|
||||
|
||||
def get_model_row(compat, name, data, msg, model_type="package"):
|
||||
if data["compat"]:
|
||||
comp = msg.text("", color="green", icon="good", no_print=True)
|
||||
version = msg.text(data["version"], color="green", no_print=True)
|
||||
else:
|
||||
version = msg.text(data["version"], color="red", no_print=True)
|
||||
comp = f"--> {compat.get(data['name'], ['n/a'])[0]}"
|
||||
return (model_type, name, data["name"], version, comp)
|
||||
|
||||
|
||||
def is_model_path(model_path):
|
||||
exclude = ["cache", "pycache", "__pycache__"]
|
||||
name = model_path.parts[-1]
|
||||
return model_path.is_dir() and name not in exclude and not name.startswith(".")
|
||||
|
||||
|
||||
def is_compat(compat, name, version):
|
||||
return name in compat and version in compat[name]
|
||||
return pkgs, compat
|
||||
|
||||
|
||||
def reformat_version(version):
|
||||
|
|
|
@ -5,7 +5,6 @@ e.g. `unicode_`.
|
|||
|
||||
DOCS: https://spacy.io/api/top-level#compat
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
|
||||
from thinc.util import copy_array
|
||||
|
@ -43,33 +42,6 @@ is_linux = sys.platform.startswith("linux")
|
|||
is_osx = sys.platform == "darwin"
|
||||
|
||||
|
||||
def symlink_to(orig, dest):
|
||||
"""Create a symlink. Used for model shortcut links.
|
||||
|
||||
orig (unicode / Path): The origin path.
|
||||
dest (unicode / Path): The destination path of the symlink.
|
||||
"""
|
||||
if is_windows:
|
||||
import subprocess
|
||||
|
||||
subprocess.check_call(["mklink", "/d", str(orig), str(dest)], shell=True)
|
||||
else:
|
||||
orig.symlink_to(dest)
|
||||
|
||||
|
||||
def symlink_remove(link):
|
||||
"""Remove a symlink. Used for model shortcut links.
|
||||
|
||||
link (unicode / Path): The path to the symlink.
|
||||
"""
|
||||
# https://stackoverflow.com/q/26554135/6400719
|
||||
if os.path.isdir(str(link)) and is_windows:
|
||||
# this should only be on Py2.7 and windows
|
||||
os.rmdir(str(link))
|
||||
else:
|
||||
os.unlink(str(link))
|
||||
|
||||
|
||||
def is_config(windows=None, linux=None, osx=None, **kwargs):
|
||||
"""Check if a specific configuration of Python version and operating system
|
||||
matches the user's setup. Mostly used to display targeted error messages.
|
||||
|
|
|
@ -224,13 +224,8 @@ class Errors(object):
|
|||
E047 = ("Can't assign a value to unregistered extension attribute "
|
||||
"'{name}'. Did you forget to call the `set_extension` method?")
|
||||
E048 = ("Can't import language {lang} from spacy.lang: {err}")
|
||||
E049 = ("Can't find spaCy data directory: '{path}'. Check your "
|
||||
"installation and permissions, or use spacy.util.set_data_path "
|
||||
"to customise the location if necessary.")
|
||||
E050 = ("Can't find model '{name}'. It doesn't seem to be a shortcut "
|
||||
"link, a Python package or a valid path to a data directory.")
|
||||
E051 = ("Cant' load '{name}'. If you're using a shortcut link, make sure "
|
||||
"it points to a valid package (not just a data directory).")
|
||||
E050 = ("Can't find model '{name}'. It doesn't seem to be a Python "
|
||||
"package or a valid path to a data directory.")
|
||||
E052 = ("Can't find model directory: {path}")
|
||||
E053 = ("Could not read meta.json from {path}")
|
||||
E054 = ("No valid '{setting}' setting found in model meta.json.")
|
||||
|
|
|
@ -4,36 +4,8 @@ import ctypes
|
|||
from pathlib import Path
|
||||
from spacy import util
|
||||
from spacy import prefer_gpu, require_gpu
|
||||
from spacy.compat import symlink_to, symlink_remove, is_windows
|
||||
from spacy.ml._layers import PrecomputableAffine
|
||||
from spacy.ml._layers import _backprop_precomputable_affine_padding
|
||||
from subprocess import CalledProcessError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def symlink_target():
|
||||
return Path("./foo-target")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def symlink():
|
||||
return Path("./foo-symlink")
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def symlink_setup_target(request, symlink_target, symlink):
|
||||
if not symlink_target.exists():
|
||||
os.mkdir(str(symlink_target))
|
||||
# yield -- need to cleanup even if assertion fails
|
||||
# https://github.com/pytest-dev/pytest/issues/2508#issuecomment-309934240
|
||||
|
||||
def cleanup():
|
||||
# Remove symlink only if it was created
|
||||
if symlink.exists():
|
||||
symlink_remove(symlink)
|
||||
os.rmdir(str(symlink_target))
|
||||
|
||||
request.addfinalizer(cleanup)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -109,25 +81,6 @@ def test_require_gpu():
|
|||
require_gpu()
|
||||
|
||||
|
||||
def test_create_symlink_windows(
|
||||
symlink_setup_target, symlink_target, symlink, is_admin
|
||||
):
|
||||
"""Test the creation of symlinks on windows. If run as admin or not on windows it should succeed, otherwise a CalledProcessError should be raised."""
|
||||
assert symlink_target.exists()
|
||||
|
||||
if is_admin or not is_windows:
|
||||
try:
|
||||
symlink_to(symlink, symlink_target)
|
||||
assert symlink.exists()
|
||||
except CalledProcessError as e:
|
||||
pytest.fail(e)
|
||||
else:
|
||||
with pytest.raises(CalledProcessError):
|
||||
symlink_to(symlink, symlink_target)
|
||||
|
||||
assert not symlink.exists()
|
||||
|
||||
|
||||
def test_ascii_filenames():
|
||||
"""Test that all filenames in the project are ASCII.
|
||||
See: https://twitter.com/_inesmontani/status/1177941471632211968
|
||||
|
|
|
@ -29,7 +29,6 @@ from .symbols import ORTH
|
|||
from .compat import cupy, CudaStream
|
||||
from .errors import Errors, Warnings, deprecation_warning, user_warning
|
||||
|
||||
_data_path = Path(__file__).parent / "data"
|
||||
_PRINT_ENV = False
|
||||
|
||||
|
||||
|
@ -84,27 +83,6 @@ def set_lang_class(name, cls):
|
|||
registry.languages.register(name, func=cls)
|
||||
|
||||
|
||||
def get_data_path(require_exists=True):
|
||||
"""Get path to spaCy data directory.
|
||||
|
||||
require_exists (bool): Only return path if it exists, otherwise None.
|
||||
RETURNS (Path or None): Data path or None.
|
||||
"""
|
||||
if not require_exists:
|
||||
return _data_path
|
||||
else:
|
||||
return _data_path if _data_path.exists() else None
|
||||
|
||||
|
||||
def set_data_path(path):
|
||||
"""Set path to spaCy data directory.
|
||||
|
||||
path (unicode or Path): Path to new data directory.
|
||||
"""
|
||||
global _data_path
|
||||
_data_path = ensure_path(path)
|
||||
|
||||
|
||||
def make_layer(arch_config):
|
||||
arch_func = registry.architectures.get(arch_config["arch"])
|
||||
return arch_func(arch_config["config"])
|
||||
|
@ -145,18 +123,13 @@ def get_module_path(module):
|
|||
|
||||
|
||||
def load_model(name, **overrides):
|
||||
"""Load a model from a shortcut link, package or data path.
|
||||
"""Load a model from a package or data path.
|
||||
|
||||
name (unicode): Package name, shortcut link or model path.
|
||||
name (unicode): Package name or model path.
|
||||
**overrides: Specific overrides, like pipeline components to disable.
|
||||
RETURNS (Language): `Language` class with the loaded model.
|
||||
"""
|
||||
data_path = get_data_path()
|
||||
if not data_path or not data_path.exists():
|
||||
raise IOError(Errors.E049.format(path=data_path))
|
||||
if isinstance(name, str): # in data dir / shortcut
|
||||
if name in set([d.name for d in data_path.iterdir()]):
|
||||
return load_model_from_link(name, **overrides)
|
||||
if isinstance(name, str): # name or string path
|
||||
if is_package(name): # installed as package
|
||||
return load_model_from_package(name, **overrides)
|
||||
if Path(name).exists(): # path to model data directory
|
||||
|
@ -166,16 +139,6 @@ def load_model(name, **overrides):
|
|||
raise IOError(Errors.E050.format(name=name))
|
||||
|
||||
|
||||
def load_model_from_link(name, **overrides):
|
||||
"""Load a model from a shortcut link, or directory in spaCy data path."""
|
||||
path = get_data_path() / name / "__init__.py"
|
||||
try:
|
||||
cls = import_file(name, path)
|
||||
except AttributeError:
|
||||
raise IOError(Errors.E051.format(name=name))
|
||||
return cls.load(**overrides)
|
||||
|
||||
|
||||
def load_model_from_package(name, **overrides):
|
||||
"""Load a model from an installed package."""
|
||||
cls = importlib.import_module(name)
|
||||
|
@ -797,5 +760,13 @@ def create_default_optimizer():
|
|||
eps = env_opt("optimizer_eps", 1e-8)
|
||||
L2 = env_opt("L2_penalty", 1e-6)
|
||||
grad_clip = env_opt("grad_norm_clip", 1.0)
|
||||
optimizer = Adam(learn_rate, L2=L2, beta1=beta1, beta2=beta2, eps=eps, ops=ops, grad_clip=grad_clip)
|
||||
optimizer = Adam(
|
||||
learn_rate,
|
||||
L2=L2,
|
||||
beta1=beta1,
|
||||
beta2=beta2,
|
||||
eps=eps,
|
||||
ops=ops,
|
||||
grad_clip=grad_clip,
|
||||
)
|
||||
return optimizer
|
||||
|
|
Loading…
Reference in New Issue
Block a user