allow to specify version constraint within model name

This commit is contained in:
Henning Peters 2015-12-18 19:12:08 +01:00
parent d1f46528ca
commit d8d348bb55
2 changed files with 5 additions and 15 deletions

View File

@ -167,15 +167,14 @@ class Language(object):
3) by a model name/version (and optionally a package root dir) 3) by a model name/version (and optionally a package root dir)
- Language(model='en_default') - Language(model='en_default')
- Language(model='en_default', version='1.0.0') - Language(model='en_default ==1.0.0')
- Language(model='en_default', version='1.0.0', data_dir='spacy/data') - Language(model='en_default <1.1.0, data_dir='spacy/data')
""" """
data_dir = kwargs.pop('data_dir', None) data_dir = kwargs.pop('data_dir', None)
lang = kwargs.pop('lang', None) lang = kwargs.pop('lang', None)
model = kwargs.pop('model', None) model = kwargs.pop('model', None)
version = kwargs.pop('version', None)
vocab = kwargs.pop('vocab', None) vocab = kwargs.pop('vocab', None)
tokenizer = kwargs.pop('tokenizer', None) tokenizer = kwargs.pop('tokenizer', None)
@ -210,11 +209,7 @@ class Language(object):
warn("using non-package data_dir", DeprecationWarning) warn("using non-package data_dir", DeprecationWarning)
package = Package(data_dir) package = Package(data_dir)
else: else:
if model is None: package = get_package(name=model, data_path=data_dir)
model = '%s_default' % (lang or 'en')
version = None
package = get_package(name=model, version=version,
data_path=data_dir)
if load_vectors is not True: if load_vectors is not True:
warn("load_vectors is deprecated", DeprecationWarning) warn("load_vectors is deprecated", DeprecationWarning)

View File

@ -8,7 +8,7 @@ from sputnik import Sputnik
from .attrs import TAG, HEAD, DEP, ENT_IOB, ENT_TYPE from .attrs import TAG, HEAD, DEP, ENT_IOB, ENT_TYPE
def get_package(name=None, version=None, data_path=None): def get_package(name=None, data_path=None):
if data_path is None: if data_path is None:
if os.environ.get('SPACY_DATA'): if os.environ.get('SPACY_DATA'):
data_path = os.environ.get('SPACY_DATA') data_path = os.environ.get('SPACY_DATA')
@ -18,12 +18,7 @@ def get_package(name=None, version=None, data_path=None):
sputnik = Sputnik('spacy', '0.100.0') # TODO: retrieve version sputnik = Sputnik('spacy', '0.100.0') # TODO: retrieve version
pool = sputnik.pool(data_path) pool = sputnik.pool(data_path)
return pool.get(name or 'en_default')
if name is None:
name = 'en_default'
if version:
name += ' ==%s' % version
return pool.get(name)
def normalize_slice(length, start, stop, step=None): def normalize_slice(length, start, stop, step=None):