diff --git a/spacy/__init__.py b/spacy/__init__.py index 8bb1c2129..05e732a50 100644 --- a/spacy/__init__.py +++ b/spacy/__init__.py @@ -32,18 +32,6 @@ set_lang_class(nl.Dutch.lang, nl.Dutch) def load(name, **overrides): target_name, target_version = util.split_data_name(name) data_path = overrides.get('path', util.get_data_path()) - if target_name == 'en' and 'add_vectors' not in overrides: - if 'vectors' in overrides: - vec_path = util.match_best_version(overrides['vectors'], None, data_path) - if vec_path is None: - raise IOError( - 'Could not load data pack %s from %s' % (overrides['vectors'], data_path)) - - else: - vec_path = util.match_best_version('en_glove_cc_300_1m_vectors', None, data_path) - if vec_path is not None: - vec_path = vec_path / 'vocab' / 'vec.bin' - overrides['add_vectors'] = lambda vocab: vocab.load_vectors_from_bin_loc(vec_path) path = util.match_best_version(target_name, target_version, data_path) cls = get_lang_class(target_name) overrides['path'] = path diff --git a/spacy/en/__init__.py b/spacy/en/__init__.py index 24506c145..55832543d 100644 --- a/spacy/en/__init__.py +++ b/spacy/en/__init__.py @@ -3,6 +3,7 @@ from __future__ import unicode_literals, print_function from os import path +from ..util import match_best_version from ..language import Language from ..lemmatizer import Lemmatizer from ..vocab import Vocab @@ -23,3 +24,31 @@ class English(Language): tag_map = TAG_MAP stop_words = STOP_WORDS lemma_rules = LEMMA_RULES + + + def __init__(self, **overrides): + # Make a special-case hack for loading the GloVe vectors, to support + # deprecated <1.0 stuff. Phase this out once the data is fixed. + overrides = _fix_deprecated_glove_vectors_loading(overrides) + Language.__init__(self, **overrides) + + +def _fix_deprecated_glove_vectors_loading(overrides): + if 'data_dir' in overrides and 'path' not in overrides: + raise ValueError("The argument 'data_dir' has been renamed to 'path'") + if overrides.get('path') is None: + return overrides + path = overrides['path'] + if 'add_vectors' not in overrides: + data_path = path.parent + if 'vectors' in overrides: + vec_path = match_best_version(overrides['vectors'], None, data_path) + if vec_path is None: + raise IOError( + 'Could not load data pack %s from %s' % (overrides['vectors'], data_path)) + else: + vec_path = match_best_version('en_glove_cc_300_1m_vectors', None, data_path) + if vec_path is not None: + vec_path = vec_path / 'vocab' / 'vec.bin' + overrides['add_vectors'] = lambda vocab: vocab.load_vectors_from_bin_loc(vec_path) + return overrides