mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 01:04:34 +03:00
Merge pull request #305 from henningpeters/master
multiple langs in download script
This commit is contained in:
commit
8c77a994c6
|
@ -1,8 +1,15 @@
|
|||
from . import util
|
||||
from .util import set_lang_class, get_lang_class, get_package, get_package_by_name
|
||||
|
||||
from .en import English
|
||||
from .de import German
|
||||
|
||||
|
||||
set_lang_class(English.lang, English)
|
||||
set_lang_class(German.lang, German)
|
||||
|
||||
|
||||
def load(name, vectors=None, via=None):
    """Build the language pipeline that matches the model ``name``.

    The language class is looked up from ``name`` with ``get_lang_class``
    instead of always instantiating ``English``, so any registered
    language can be loaded through this entry point.

    name -- model/package identifier, e.g. ``'en'`` or ``'de'``.
    vectors -- optional identifier of a separate package to pass on as
        ``vectors_package``.
    via -- optional data path forwarded to ``get_package_by_name``.

    Returns an instance of the language class registered for ``name``.
    Raises RuntimeError (from the helpers) when the language is not
    registered or the package cannot be resolved.
    """
    package = get_package_by_name(name, via=via)
    # NOTE(review): with ``vectors=None`` the helper falls back to the
    # default-language model regardless of ``name`` -- confirm that this is
    # the intended vectors package for non-default languages.
    vectors_package = get_package_by_name(vectors, via=via)
    cls = get_lang_class(name)
    return cls(package=package, vectors_package=vectors_package)
|
||||
|
|
|
@ -10,4 +10,8 @@ __uri__ = 'https://spacy.io'
|
|||
__author__ = 'Matthew Honnibal'
|
||||
__email__ = 'matt@spacy.io'
|
||||
__license__ = 'MIT'
|
||||
__default_model__ = 'en>=1.0.0,<1.1.0'
|
||||
__models__ = {
|
||||
'en': 'en>=1.0.0,<1.1.0',
|
||||
'de': 'de>=1.0.0,<1.1.0',
|
||||
}
|
||||
__default_lang__ = 'en'
|
||||
|
|
13
spacy/de/download.py
Normal file
13
spacy/de/download.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
import plac
|
||||
from ..download import download
|
||||
|
||||
|
||||
@plac.annotations(
    force=("Force overwrite", "flag", "f", bool),
)
def main(data_size='all', force=False):
    """Command-line entry point: install the German ('de') model.

    data_size -- accepted but not used by this function.
    force -- forwarded to ``download`` as its ``force`` argument.
    """
    download('de', force=force)


if __name__ == '__main__':
    # Let plac build the command-line interface from the annotations above.
    plac.call(main)
|
33
spacy/download.py
Normal file
33
spacy/download.py
Normal file
|
@ -0,0 +1,33 @@
|
|||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
|
||||
import sputnik
|
||||
from sputnik.package_list import (PackageNotFoundException,
|
||||
CompatiblePackageNotFoundException)
|
||||
|
||||
from . import about
|
||||
|
||||
|
||||
def download(lang, force=False):
    """Install the model package registered for ``lang`` via sputnik.

    lang -- key into ``about.__models__`` (e.g. ``'en'``, ``'de'``).
    force -- when true, purge before installing so the model is
        re-downloaded even if it is already present.

    Progress and error messages go to stderr. The process exits with
    status 1 when the language is unknown, when the model is already
    installed, or when the model cannot be found after installation.
    """
    # Resolve the model spec before doing anything destructive: an unknown
    # language must not purge installed packages and then die on a KeyError.
    if lang not in about.__models__:
        print("Language not supported: %s" % lang, file=sys.stderr)
        sys.exit(1)
    model = about.__models__[lang]

    title, version = about.__title__, about.__version__
    # Both exceptions are treated as "no usable model is installed".
    not_installed = (PackageNotFoundException,
                     CompatiblePackageNotFoundException)

    if force:
        # NOTE(review): purge receives only the app title and version, not
        # the model spec -- presumably it removes every installed package,
        # including other languages' models; confirm against sputnik.
        sputnik.purge(title, version)

    try:
        sputnik.package(title, version, model)
    except not_installed:
        pass
    else:
        print("Model already installed. Please run 'python -m "
              "spacy.%s.download --force' to reinstall." % lang,
              file=sys.stderr)
        sys.exit(1)

    sputnik.install(title, version, model)

    # Look the package up again to verify that the install succeeded.
    try:
        sputnik.package(title, version, model)
    except not_installed:
        print("Model failed to install. Please run 'python -m "
              "spacy.%s.download --force'." % lang, file=sys.stderr)
        sys.exit(1)

    print("Model successfully installed.", file=sys.stderr)
|
|
@ -1,57 +1,12 @@
|
|||
from __future__ import print_function
|
||||
|
||||
import sys
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import plac
|
||||
import sputnik
|
||||
from sputnik.package_list import (PackageNotFoundException,
|
||||
CompatiblePackageNotFoundException)
|
||||
|
||||
from .. import about
|
||||
|
||||
|
||||
def migrate(path):
    """Remove old-style model data stored directly under ``path``.

    Deletes the ``data`` entry inside ``path`` (unlinking it when it is a
    symbolic link, removing the tree when it is a real directory) and then
    deletes every file in ``path`` whose name ends with ``.tgz``.

    path -- directory to clean up; it must exist, because its contents
        are listed with ``os.listdir``.
    """
    data_path = os.path.join(path, 'data')
    # Test for a link first: os.path.isdir() follows links, so a dangling
    # ``data`` symlink would otherwise be skipped and left behind.
    if os.path.islink(data_path):
        os.unlink(data_path)
    elif os.path.isdir(data_path):
        shutil.rmtree(data_path)

    for filename in os.listdir(path):
        if filename.endswith('.tgz'):
            os.unlink(os.path.join(path, filename))
|
||||
from ..download import download
|
||||
|
||||
|
||||
@plac.annotations(
    force=("Force overwrite", "flag", "f", bool),
)
def main(data_size='all', force=False):
    """Command-line entry point: install the English ('en') model.

    All of the work is delegated to the shared ``download`` helper; running
    the former inline install steps in addition would make that helper
    report the model as already installed.

    data_size -- accepted but not used by this function.
    force -- forwarded to ``download`` as its ``force`` argument.
    """
    download('en', force)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
|
@ -16,8 +16,7 @@ cimport cython
|
|||
|
||||
from . import util
|
||||
from .tokens.doc cimport Doc
|
||||
from .util import read_lang_data
|
||||
from .util import get_package
|
||||
from .util import read_lang_data, get_package
|
||||
|
||||
|
||||
cdef class Tokenizer:
|
||||
|
|
|
@ -14,6 +14,21 @@ from . import about
|
|||
from .attrs import TAG, HEAD, DEP, ENT_IOB, ENT_TYPE
|
||||
|
||||
|
||||
# Registry mapping a language identifier (e.g. 'en') to its language class.
LANGUAGES = {}


def set_lang_class(name, cls):
    """Register ``cls`` as the language class for the identifier ``name``.

    A later registration under the same ``name`` replaces the earlier one.
    """
    # No ``global`` statement needed: the dict is mutated, never rebound.
    LANGUAGES[name] = cls


def get_lang_class(name):
    """Return the language class registered for the model ``name``.

    Only the leading run of letters, digits and underscores is used as the
    language identifier, so a versioned spec such as ``'en>=1.0.0,<1.1.0'``
    resolves to the class registered under ``'en'``.

    Raises RuntimeError when no class is registered for that identifier.
    """
    # ``maxsplit`` is passed by keyword; passing it positionally is
    # deprecated in recent Python versions.
    lang = re.split(r'[^a-zA-Z0-9_]', name, maxsplit=1)[0]
    if lang not in LANGUAGES:
        raise RuntimeError('Language not supported: %s' % lang)
    return LANGUAGES[lang]
|
||||
|
||||
|
||||
def get_package(data_dir):
|
||||
if not isinstance(data_dir, six.string_types):
|
||||
raise RuntimeError('data_dir must be a string')
|
||||
|
@ -21,17 +36,20 @@ def get_package(data_dir):
|
|||
|
||||
|
||||
def get_package_by_name(name=None, via=None):
    """Resolve a model package through sputnik.

    name -- package spec to look up; when empty or None, the spec of the
        default language (``about.__models__[about.__default_lang__]``)
        is used instead.
    via -- optional data path, passed to sputnik as ``data_path``.

    Returns whatever ``sputnik.package`` returns for the resolved spec.
    Raises RuntimeError when the language is not registered, when the
    package is not installed, or when the installed package is not
    compatible. The last two messages name the download command to run.
    """
    package_name = name or about.__models__[about.__default_lang__]
    # Resolved up front so the error messages can point at the download
    # module of the right language.
    # NOTE(review): assumes the class's ``__module__`` is the language
    # package (e.g. 'spacy.en') -- confirm where the classes are defined.
    lang = get_lang_class(package_name)
    try:
        return sputnik.package(about.__title__, about.__version__,
                               package_name, data_path=via)
    except PackageNotFoundException:
        # Report the spec that was actually looked up; ``name`` is None
        # whenever the default model was requested.
        raise RuntimeError("Model '%s' not installed. Please run 'python -m "
                           "%s.download' to install latest compatible "
                           "model." % (package_name, lang.__module__))
    except CompatiblePackageNotFoundException:
        raise RuntimeError("Installed model is not compatible with spaCy "
                           "version. Please run 'python -m %s.download "
                           "--force' to install latest compatible model." %
                           lang.__module__)
|
||||
|
||||
|
||||
def normalize_slice(length, start, stop, step=None):
|
||||
|
|
|
@ -25,7 +25,6 @@ from . import attrs
|
|||
from . import symbols
|
||||
|
||||
from cymem.cymem cimport Address
|
||||
from . import util
|
||||
from .serialize.packer cimport Packer
|
||||
from .attrs cimport PROB
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user