mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Let --data-path be specified when running download.py scripts
Resolves https://github.com/explosion/spaCy/issues/637
This commit is contained in:
parent
c413ffe26d
commit
3871007c72
|
@ -4,9 +4,10 @@ from ..download import download
|
|||
|
||||
@plac.annotations(
|
||||
force=("Force overwrite", "flag", "f", bool),
|
||||
data_path=("Path to download model", "option", "d", str)
|
||||
)
|
||||
def main(data_size='all', force=False):
|
||||
download('de', force)
|
||||
def main(data_size='all', force=False, data_path=None):
|
||||
download('de', force=force, data_path=data_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -10,10 +10,19 @@ from . import about
|
|||
from . import util
|
||||
|
||||
|
||||
def download(lang, force=False, fail_on_exist=True):
|
||||
def download(lang, force=False, fail_on_exist=True, data_path=None):
|
||||
if not data_path:
|
||||
data_path = util.get_data_path()
|
||||
|
||||
# spaCy uses pathlib, and util.get_data_path returns a pathlib.Path object,
|
||||
# but sputnik (which we're using below) doesn't use pathlib and requires
|
||||
# its data_path parameters to be strings, so we coerce the data_path to a
|
||||
# str here.
|
||||
data_path = str(data_path)
|
||||
|
||||
try:
|
||||
pkg = sputnik.package(about.__title__, about.__version__,
|
||||
about.__models__.get(lang, lang))
|
||||
about.__models__.get(lang, lang), data_path)
|
||||
if force:
|
||||
shutil.rmtree(pkg.path)
|
||||
elif fail_on_exist:
|
||||
|
@ -24,15 +33,14 @@ def download(lang, force=False, fail_on_exist=True):
|
|||
pass
|
||||
|
||||
package = sputnik.install(about.__title__, about.__version__,
|
||||
about.__models__.get(lang, lang))
|
||||
about.__models__.get(lang, lang), data_path)
|
||||
|
||||
try:
|
||||
sputnik.package(about.__title__, about.__version__,
|
||||
about.__models__.get(lang, lang))
|
||||
about.__models__.get(lang, lang), data_path)
|
||||
except (PackageNotFoundException, CompatiblePackageNotFoundException):
|
||||
print("Model failed to install. Please run 'python -m "
|
||||
"spacy.%s.download --force'." % lang, file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
data_path = util.get_data_path()
|
||||
print("Model successfully installed to %s" % data_path, file=sys.stderr)
|
||||
|
|
|
@ -7,17 +7,18 @@ from .. import about
|
|||
|
||||
@plac.annotations(
|
||||
force=("Force overwrite", "flag", "f", bool),
|
||||
data_path=("Path to download model", "option", "d", str)
|
||||
)
|
||||
def main(data_size='all', force=False):
|
||||
def main(data_size='all', force=False, data_path=None):
|
||||
if force:
|
||||
sputnik.purge(about.__title__, about.__version__)
|
||||
|
||||
if data_size in ('all', 'parser'):
|
||||
print("Downloading parsing model")
|
||||
download('en', False)
|
||||
download('en', force=False, data_path=data_path)
|
||||
if data_size in ('all', 'glove'):
|
||||
print("Downloading GloVe vectors")
|
||||
download('en_glove_cc_300_1m_vectors', False)
|
||||
download('en_glove_cc_300_1m_vectors', force=False, data_path=data_path)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Loading…
Reference in New Issue
Block a user