mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	Let --data-path be specified when running download.py scripts
Resolves https://github.com/explosion/spaCy/issues/637
This commit is contained in:
		
							parent
							
								
									c413ffe26d
								
							
						
					
					
						commit
						3871007c72
					
				| 
						 | 
				
			
			@ -4,9 +4,10 @@ from ..download import download
 | 
			
		|||
 | 
			
		||||
@plac.annotations(
 | 
			
		||||
    force=("Force overwrite", "flag", "f", bool),
 | 
			
		||||
    data_path=("Path to download model", "option", "d", str)
 | 
			
		||||
)
 | 
			
		||||
def main(data_size='all', force=False):
 | 
			
		||||
    download('de', force)
 | 
			
		||||
def main(data_size='all', force=False, data_path=None):
 | 
			
		||||
    download('de', force=force, data_path=data_path)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -10,10 +10,19 @@ from . import about
 | 
			
		|||
from . import util
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def download(lang, force=False, fail_on_exist=True):
 | 
			
		||||
def download(lang, force=False, fail_on_exist=True, data_path=None):
 | 
			
		||||
    if not data_path:
 | 
			
		||||
        data_path = util.get_data_path()
 | 
			
		||||
 | 
			
		||||
    # spaCy uses pathlib, and util.get_data_path returns a pathlib.Path object,
 | 
			
		||||
    # but sputnik (which we're using below) doesn't use pathlib and requires
 | 
			
		||||
    # its data_path parameters to be strings, so we coerce the data_path to a
 | 
			
		||||
    # str here.
 | 
			
		||||
    data_path = str(data_path)
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        pkg = sputnik.package(about.__title__, about.__version__,
 | 
			
		||||
                        about.__models__.get(lang, lang))
 | 
			
		||||
                        about.__models__.get(lang, lang), data_path)
 | 
			
		||||
        if force:
 | 
			
		||||
            shutil.rmtree(pkg.path)
 | 
			
		||||
        elif fail_on_exist:
 | 
			
		||||
| 
						 | 
				
			
			@ -24,15 +33,14 @@ def download(lang, force=False, fail_on_exist=True):
 | 
			
		|||
        pass
 | 
			
		||||
 | 
			
		||||
    package = sputnik.install(about.__title__, about.__version__,
 | 
			
		||||
                              about.__models__.get(lang, lang))
 | 
			
		||||
                              about.__models__.get(lang, lang), data_path)
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        sputnik.package(about.__title__, about.__version__,
 | 
			
		||||
                        about.__models__.get(lang, lang))
 | 
			
		||||
                        about.__models__.get(lang, lang), data_path)
 | 
			
		||||
    except (PackageNotFoundException, CompatiblePackageNotFoundException):
 | 
			
		||||
        print("Model failed to install. Please run 'python -m "
 | 
			
		||||
              "spacy.%s.download --force'." % lang, file=sys.stderr)
 | 
			
		||||
        sys.exit(1)
 | 
			
		||||
 | 
			
		||||
    data_path = util.get_data_path()
 | 
			
		||||
    print("Model successfully installed to %s" % data_path, file=sys.stderr)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -7,17 +7,18 @@ from .. import about
 | 
			
		|||
 | 
			
		||||
@plac.annotations(
 | 
			
		||||
    force=("Force overwrite", "flag", "f", bool),
 | 
			
		||||
    data_path=("Path to download model", "option", "d", str)
 | 
			
		||||
)
 | 
			
		||||
def main(data_size='all', force=False):
 | 
			
		||||
def main(data_size='all', force=False, data_path=None):
 | 
			
		||||
    if force:
 | 
			
		||||
        sputnik.purge(about.__title__, about.__version__)
 | 
			
		||||
 | 
			
		||||
    if data_size in ('all', 'parser'):
 | 
			
		||||
        print("Downloading parsing model")
 | 
			
		||||
        download('en', False)
 | 
			
		||||
        download('en', force=False, data_path=data_path)
 | 
			
		||||
    if data_size in ('all', 'glove'):
 | 
			
		||||
        print("Downloading GloVe vectors")
 | 
			
		||||
        download('en_glove_cc_300_1m_vectors', False)
 | 
			
		||||
        download('en_glove_cc_300_1m_vectors', force=False, data_path=data_path)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if __name__ == '__main__':
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user