mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Refactor download script and about.py to use new download method
This commit is contained in:
parent
f5d1a39a5b
commit
58b884b6d4
|
@ -1,5 +1,4 @@
|
|||
# inspired from:
|
||||
|
||||
# https://python-packaging-user-guide.readthedocs.org/en/latest/single_source_version/
|
||||
# https://github.com/pypa/warehouse/blob/master/warehouse/__about__.py
|
||||
|
||||
|
@ -10,7 +9,8 @@ __uri__ = 'https://spacy.io'
|
|||
__author__ = 'Matthew Honnibal'
|
||||
__email__ = 'matt@explosion.ai'
|
||||
__license__ = 'MIT'
|
||||
__models__ = {
|
||||
'en': 'en>=1.1.0,<1.2.0',
|
||||
'de': 'de>=1.0.0,<1.1.0',
|
||||
}
|
||||
|
||||
__docs__ = 'https://spacy.io/docs/usage'
|
||||
__download_url__ = 'https://github.com/explosion/spacy-models/releases/download'
|
||||
__compatibility__ = 'https://raw.githubusercontent.com/explosion/spacy-models/master/compatibility.json?token=ANAt54fi5zcUtnwGhMLw2klWwcAyHkZGks5Y0nw1wA%3D%3D'
|
||||
__shortcuts__ = {'en': 'en_core_web_md', 'de': 'de_core_web_md', 'vectors': 'en_vectors_glove_md'}
|
||||
|
|
|
@ -1,47 +1,79 @@
|
|||
from __future__ import print_function
|
||||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import sys
|
||||
import shutil
|
||||
|
||||
import sputnik
|
||||
from sputnik.package_list import (PackageNotFoundException,
|
||||
CompatiblePackageNotFoundException)
|
||||
|
||||
import pip
|
||||
import plac
|
||||
import requests
|
||||
from os import path
|
||||
from . import about
|
||||
from . import util
|
||||
|
||||
|
||||
def download(lang, force=False, fail_on_exist=True, data_path=None):
|
||||
if not data_path:
|
||||
data_path = util.get_data_path(require_exists=False)
|
||||
@plac.annotations(
    model=("Model to download", "positional", None, str),
    direct=("Force direct download", "flag", "d", bool)
)
def download(model=None, direct=False):
    """Download and install a model package.

    model: Model name or shortcut. It is validated by check_error_depr()
        before anything else happens.
    direct: If set, `model` is used verbatim to build the archive name and
        the shortcut and compatibility lookups are skipped.
    """
    check_error_depr(model)

    if direct:
        archive = '{m}/{m}.tar.gz'.format(m=model)
    else:
        # Expand a shortcut (a key of about.__shortcuts__) to the full
        # package name; any other value is used unchanged.
        model = about.__shortcuts__.get(model, model)
        version = get_version(model, get_compatibility())
        archive = '{m}-{v}/{m}-{v}.tar.gz'.format(m=model, v=version)

    download_model(archive)
|
||||
|
||||
try:
|
||||
pkg = sputnik.package(about.__title__, about.__version__,
|
||||
about.__models__.get(lang, lang), data_path)
|
||||
if force:
|
||||
shutil.rmtree(pkg.path)
|
||||
elif fail_on_exist:
|
||||
print("Model already installed. Please run 'python -m "
|
||||
"spacy.%s.download --force' to reinstall." % lang, file=sys.stderr)
|
||||
sys.exit(0)
|
||||
except (PackageNotFoundException, CompatiblePackageNotFoundException):
|
||||
pass
|
||||
|
||||
package = sputnik.install(about.__title__, about.__version__,
|
||||
about.__models__.get(lang, lang), data_path)
|
||||
def get_compatibility():
    """Fetch the compatibility table entry for the installed spaCy version.

    Requests about.__compatibility__ and reads the 'spacy' key of the JSON
    response. Terminates through exit() when the response status is not 200
    or when about.__version__ has no entry in the table.

    RETURNS: The table entry stored under about.__version__.
    """
    spacy_version = about.__version__
    response = requests.get(about.__compatibility__)

    if response.status_code != 200:
        exit("Couldn't fetch compatibility table. Please find the right model for "
             "your spaCy installation (v{v}), and download it manually:".format(v=spacy_version),
             "python -m spacy.download [full model name + version] --direct",
             title="Server error ({c})".format(c=response.status_code))

    table = response.json()['spacy']
    if spacy_version in table:
        return table[spacy_version]

    exit("No compatible models found for v{v} of spaCy.".format(v=spacy_version),
         title="Compatibility error")
|
||||
|
||||
print("Model successfully installed to %s" % data_path, file=sys.stderr)
|
||||
|
||||
def get_version(model, comp):
    """Look up the version to download for a model.

    model: Full model name, used as a key into `comp`.
    comp: Mapping from model name to a sequence of versions.
    RETURNS: The first version listed for `model`. Terminates through
        exit() when `model` is not a key of `comp`.
    """
    if model not in comp:
        message = "No compatible model found for {m} (spaCy v{v}).".format(
            m=model, v=about.__version__)
        exit(message, title="Compatibility error")

    compatible_versions = comp[model]
    return compatible_versions[0]
|
||||
|
||||
|
||||
def download_model(filename):
    """Install a model archive with pip.

    filename: Path of the archive relative to about.__download_url__,
        e.g. '{name}-{version}/{name}-{version}.tar.gz'.
    """
    util.print_msg("Downloading {f}".format(f=filename))
    # Join with an explicit '/': os.path.join uses the platform separator,
    # which is a backslash on Windows and would produce an invalid URL.
    download_url = '{base}/{name}'.format(
        base=about.__download_url__.rstrip('/'), name=filename)
    pip.main(['install', download_url])
|
||||
|
||||
|
||||
def check_error_depr(model):
    """Reject a missing model name and the deprecated 'all' command.

    model: Value passed on the command line. A falsy value or the string
        'all' terminates through exit() with an explanatory message; any
        other value passes through and the function returns None.
    """
    if not model:
        exit("python -m spacy.download [name or shortcut]",
             title="Missing model name or shortcut")

    if model == 'all':
        deprecation = (
            "As of v1.7.0, the download all command is deprecated. Please "
            "download the models individually via spacy.download [model name] "
            "or pip install. For more info on this, see the "
            "documentation: {d}"
        ).format(d=about.__docs__)
        exit(deprecation, title="Deprecated command")
|
||||
|
||||
|
||||
def exit(*messages, **kwargs):
    """Print the given messages and terminate the process with an error status.

    messages: Positional arguments forwarded to util.print_msg.
    kwargs: Keyword arguments forwarded to util.print_msg (callers in this
        module pass `title`).

    Every caller in this module uses this function to abort on a failure
    (missing argument, deprecated command, server or compatibility error),
    so the exit status is non-zero to let shells and scripts detect it.
    """
    util.print_msg(*messages, **kwargs)
    sys.exit(1)
|
||||
|
||||
|
||||
# Script entry point: plac.call invokes download() with the command-line
# arguments, following the annotations declared on download().
if __name__ == '__main__':
    plac.call(download)
|
||||
|
|
Loading…
Reference in New Issue
Block a user