diff --git a/spacy/__init__.py b/spacy/__init__.py
index 70e72b7a1..676659fdd 100644
--- a/spacy/__init__.py
+++ b/spacy/__init__.py
@@ -1,8 +1,15 @@
-from . import util
+from .util import set_lang_class, get_lang_class, get_package, get_package_by_name
+
 from .en import English
+from .de import German
+
+
+set_lang_class(English.lang, English)
+set_lang_class(German.lang, German)
 
 
 def load(name, vectors=None, via=None):
-    return English(
-        package=util.get_package_by_name(name, via=via),
-        vectors_package=util.get_package_by_name(vectors, via=via))
+    package = get_package_by_name(name, via=via)
+    vectors_package = get_package_by_name(vectors, via=via)
+    cls = get_lang_class(name)
+    return cls(package=package, vectors_package=vectors_package)
diff --git a/spacy/about.py b/spacy/about.py
index 3814b8d61..7f889cad8 100644
--- a/spacy/about.py
+++ b/spacy/about.py
@@ -10,4 +10,8 @@ __uri__ = 'https://spacy.io'
 __author__ = 'Matthew Honnibal'
 __email__ = 'matt@spacy.io'
 __license__ = 'MIT'
-__default_model__ = 'en>=1.0.0,<1.1.0'
+__models__ = {
+    'en': 'en>=1.0.0,<1.1.0',
+    'de': 'de>=1.0.0,<1.1.0',
+}
+__default_lang__ = 'en'
diff --git a/spacy/de/download.py b/spacy/de/download.py
new file mode 100644
index 000000000..ba57c1d31
--- /dev/null
+++ b/spacy/de/download.py
@@ -0,0 +1,13 @@
+import plac
+from ..download import download
+
+
+@plac.annotations(
+    force=("Force overwrite", "flag", "f", bool),
+)
+def main(data_size='all', force=False):
+    download('de', force)
+
+
+if __name__ == '__main__':
+    plac.call(main)
diff --git a/spacy/download.py b/spacy/download.py
new file mode 100644
index 000000000..f7fc798ae
--- /dev/null
+++ b/spacy/download.py
@@ -0,0 +1,33 @@
+from __future__ import print_function
+
+import sys
+
+import sputnik
+from sputnik.package_list import (PackageNotFoundException,
+                                  CompatiblePackageNotFoundException)
+
+from . import about
+
+
+def download(lang, force=False):
+    if force:
+        sputnik.purge(about.__title__, about.__version__)
+
+    try:
+        sputnik.package(about.__title__, about.__version__, about.__models__[lang])
+        print("Model already installed. Please run 'python -m "
+              "spacy.%s.download --force' to reinstall." % lang, file=sys.stderr)
+        sys.exit(1)
+    except (PackageNotFoundException, CompatiblePackageNotFoundException):
+        pass
+
+    package = sputnik.install(about.__title__, about.__version__, about.__models__[lang])
+
+    try:
+        sputnik.package(about.__title__, about.__version__, about.__models__[lang])
+    except (PackageNotFoundException, CompatiblePackageNotFoundException):
+        print("Model failed to install. Please run 'python -m "
+              "spacy.%s.download --force'." % lang, file=sys.stderr)
+        sys.exit(1)
+
+    print("Model successfully installed.", file=sys.stderr)
diff --git a/spacy/en/download.py b/spacy/en/download.py
index 993b8b16d..f0c23b088 100644
--- a/spacy/en/download.py
+++ b/spacy/en/download.py
@@ -1,57 +1,12 @@
-from __future__ import print_function
-
-import sys
-import os
-import shutil
-
 import plac
-import sputnik
-from sputnik.package_list import (PackageNotFoundException,
-                                  CompatiblePackageNotFoundException)
-
-from .. import about
-
-
-def migrate(path):
-    data_path = os.path.join(path, 'data')
-    if os.path.isdir(data_path):
-        if os.path.islink(data_path):
-            os.unlink(data_path)
-        else:
-            shutil.rmtree(data_path)
-    for filename in os.listdir(path):
-        if filename.endswith('.tgz'):
-            os.unlink(os.path.join(path, filename))
+from ..download import download
 
 
 @plac.annotations(
     force=("Force overwrite", "flag", "f", bool),
 )
 def main(data_size='all', force=False):
-    if force:
-        sputnik.purge(about.__title__, about.__version__)
-
-    try:
-        sputnik.package(about.__title__, about.__version__, about.__default_model__)
-        print("Model already installed. Please run 'python -m "
-              "spacy.en.download --force' to reinstall.", file=sys.stderr)
-        sys.exit(1)
-    except (PackageNotFoundException, CompatiblePackageNotFoundException):
-        pass
-
-    package = sputnik.install(about.__title__, about.__version__, about.__default_model__)
-
-    try:
-        sputnik.package(about.__title__, about.__version__, about.__default_model__)
-    except (PackageNotFoundException, CompatiblePackageNotFoundException):
-        print("Model failed to install. Please run 'python -m "
-              "spacy.en.download --force'.", file=sys.stderr)
-        sys.exit(1)
-
-    # FIXME clean up old-style packages
-    migrate(os.path.dirname(os.path.abspath(__file__)))
-
-    print("Model successfully installed.", file=sys.stderr)
+    download('en', force)
 
 
 if __name__ == '__main__':
diff --git a/spacy/tokenizer.pyx b/spacy/tokenizer.pyx
index f8613fce8..44d627505 100644
--- a/spacy/tokenizer.pyx
+++ b/spacy/tokenizer.pyx
@@ -16,8 +16,7 @@ cimport cython
 
 from . import util
 from .tokens.doc cimport Doc
-from .util import read_lang_data
-from .util import get_package
+from .util import read_lang_data, get_package
 
 
 cdef class Tokenizer:
diff --git a/spacy/util.py b/spacy/util.py
index bcc55c656..b1e93d08b 100644
--- a/spacy/util.py
+++ b/spacy/util.py
@@ -14,6 +14,21 @@ from . import about
 from .attrs import TAG, HEAD, DEP, ENT_IOB, ENT_TYPE
 
 
+LANGUAGES = {}
+
+
+def set_lang_class(name, cls):
+    global LANGUAGES
+    LANGUAGES[name] = cls
+
+
+def get_lang_class(name):
+    lang = re.split('[^a-zA-Z0-9_]', name, 1)[0]
+    if lang not in LANGUAGES:
+        raise RuntimeError('Language not supported: %s' % lang)
+    return LANGUAGES[lang]
+
+
 def get_package(data_dir):
     if not isinstance(data_dir, six.string_types):
         raise RuntimeError('data_dir must be a string')
@@ -21,17 +36,20 @@ def get_package(data_dir):
 
 
 def get_package_by_name(name=None, via=None):
+    package_name = name or about.__models__[about.__default_lang__]
+    lang = get_lang_class(package_name)
     try:
         return sputnik.package(about.__title__, about.__version__,
-                               name or about.__default_model__, data_path=via)
+                               package_name, data_path=via)
     except PackageNotFoundException as e:
-        raise RuntimeError("Model %s not installed. Please run 'python -m "
-                           "spacy.en.download' to install latest compatible "
-                           "model." % name)
+        raise RuntimeError("Model '%s' not installed. Please run 'python -m "
+                           "%s.download' to install latest compatible "
+                           "model." % (package_name, lang.__module__))
     except CompatiblePackageNotFoundException as e:
         raise RuntimeError("Installed model is not compatible with spaCy "
-                           "version. Please run 'python -m spacy.en.download "
-                           "--force' to install latest compatible model.")
+                           "version. Please run 'python -m %s.download "
+                           "--force' to install latest compatible model." %
+                           (lang.__module__))
 
 
 def normalize_slice(length, start, stop, step=None):
diff --git a/spacy/vocab.pyx b/spacy/vocab.pyx
index f876bfefb..3712a7383 100644
--- a/spacy/vocab.pyx
+++ b/spacy/vocab.pyx
@@ -25,7 +25,6 @@ from . import attrs
 from . import symbols
 
 from cymem.cymem cimport Address
-from . import util
 from .serialize.packer cimport Packer
 from .attrs cimport PROB
 