Add function to resolve model names and link them

This commit is contained in:
ines 2017-03-17 18:47:05 +01:00
parent 76c0ea6cc6
commit aedefef49d
2 changed files with 25 additions and 2 deletions

View File

@ -4,6 +4,7 @@ from __future__ import unicode_literals, print_function
import json import json
from pathlib import Path from pathlib import Path
from .util import set_lang_class, get_lang_class, parse_package_meta from .util import set_lang_class, get_lang_class, parse_package_meta
from .deprecated import resolve_model_name
from . import en from . import en
from . import de from . import de
@ -35,11 +36,12 @@ set_lang_class(bn.Bengali.lang, bn.Bengali)
def load(name, **overrides): def load(name, **overrides):
data_path = overrides.get('path', util.get_data_path()) data_path = overrides.get('path', util.get_data_path())
meta = parse_package_meta(data_path, name) model_name = resolve_model_name(name)
meta = parse_package_meta(data_path, model_name)
lang = meta['lang'] if meta and 'lang' in meta else 'en' lang = meta['lang'] if meta and 'lang' in meta else 'en'
cls = get_lang_class(lang) cls = get_lang_class(lang)
overrides['meta'] = meta overrides['meta'] = meta
overrides['path'] = Path(data_path / name) overrides['path'] = Path(data_path / model_name)
return cls(**overrides) return cls(**overrides)

View File

@ -2,6 +2,7 @@ from pathlib import Path
from . import about from . import about
from . import util from . import util
from .download import download from .download import download
from .link import link
try: try:
@ -86,6 +87,26 @@ def fix_glove_vectors_loading(overrides):
return overrides return overrides
def resolve_model_name(name):
"""If spaCy is loaded with 'en' or 'de', check if symlink already exists. If
not, user have upgraded from older version and have old models installed.
Check if old model directory exists and if so, return that instead and create
shortcut link.
"""
if name == 'en' or name == 'de':
versions = ['1.0.0', '1.1.0']
data_path = Path(util.get_data_path())
model_path = data_path / name
v_model_paths = [data_path / Path(name + '-' + v) for v in versions]
if not model_path.exists():
for v_path in v_model_paths:
if v_path.exists():
link(v_path, name)
return name
return name
class ModelDownload(): class ModelDownload():
"""Replace download modules within en and de with deprecation warning and """Replace download modules within en and de with deprecation warning and
download default language model (using shortcut). Use classmethods to allow download default language model (using shortcut). Use classmethods to allow