mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 09:57:26 +03:00 
			
		
		
		
	Untested fix for issue #684: GloVe vectors hack should be inserted in English, not in spacy.load.
This commit is contained in:
		
							parent
							
								
									c065359459
								
							
						
					
					
						commit
						2ef9d53117
					
				| 
						 | 
				
			
			@ -32,18 +32,6 @@ set_lang_class(nl.Dutch.lang, nl.Dutch)
 | 
			
		|||
def load(name, **overrides):
 | 
			
		||||
    target_name, target_version = util.split_data_name(name)
 | 
			
		||||
    data_path = overrides.get('path', util.get_data_path())
 | 
			
		||||
    if target_name == 'en' and 'add_vectors' not in overrides:
 | 
			
		||||
        if 'vectors' in overrides:
 | 
			
		||||
            vec_path = util.match_best_version(overrides['vectors'], None, data_path)
 | 
			
		||||
            if vec_path is None:
 | 
			
		||||
                raise IOError(
 | 
			
		||||
                    'Could not load data pack %s from %s' % (overrides['vectors'], data_path))
 | 
			
		||||
 | 
			
		||||
        else:
 | 
			
		||||
            vec_path = util.match_best_version('en_glove_cc_300_1m_vectors', None, data_path)
 | 
			
		||||
        if vec_path is not None:
 | 
			
		||||
            vec_path = vec_path / 'vocab' / 'vec.bin'
 | 
			
		||||
            overrides['add_vectors'] = lambda vocab: vocab.load_vectors_from_bin_loc(vec_path)
 | 
			
		||||
    path = util.match_best_version(target_name, target_version, data_path)
 | 
			
		||||
    cls = get_lang_class(target_name)
 | 
			
		||||
    overrides['path'] = path
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -3,6 +3,7 @@ from __future__ import unicode_literals, print_function
 | 
			
		|||
 | 
			
		||||
from os import path
 | 
			
		||||
 | 
			
		||||
from ..util import match_best_version
 | 
			
		||||
from ..language import Language
 | 
			
		||||
from ..lemmatizer import Lemmatizer
 | 
			
		||||
from ..vocab import Vocab
 | 
			
		||||
| 
						 | 
				
			
			@ -23,3 +24,31 @@ class English(Language):
 | 
			
		|||
        tag_map = TAG_MAP
 | 
			
		||||
        stop_words = STOP_WORDS
 | 
			
		||||
        lemma_rules = LEMMA_RULES
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    def __init__(self, **overrides):
        """Initialise the English language class from keyword overrides.

        The overrides are first passed through
        ``_fix_deprecated_glove_vectors_loading`` and the resulting mapping
        is then forwarded unchanged to ``Language.__init__``.
        """
        # Special-case hack for loading the GloVe vectors, kept to support
        # deprecated <1.0 behaviour. Phase this out once the data is fixed.
        patched_overrides = _fix_deprecated_glove_vectors_loading(overrides)
        Language.__init__(self, **patched_overrides)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _fix_deprecated_glove_vectors_loading(overrides):
 | 
			
		||||
    if 'data_dir' in overrides and 'path' not in overrides:
 | 
			
		||||
        raise ValueError("The argument 'data_dir' has been renamed to 'path'")
 | 
			
		||||
    if overrides.get('path') is None:
 | 
			
		||||
        return overrides
 | 
			
		||||
    path = overrides['path']
 | 
			
		||||
    if 'add_vectors' not in overrides:
 | 
			
		||||
        data_path = path.parent
 | 
			
		||||
        if 'vectors' in overrides:
 | 
			
		||||
            vec_path = match_best_version(overrides['vectors'], None, data_path)
 | 
			
		||||
            if vec_path is None:
 | 
			
		||||
                raise IOError(
 | 
			
		||||
                    'Could not load data pack %s from %s' % (overrides['vectors'], data_path))
 | 
			
		||||
        else:
 | 
			
		||||
            vec_path = match_best_version('en_glove_cc_300_1m_vectors', None, data_path)
 | 
			
		||||
        if vec_path is not None:
 | 
			
		||||
            vec_path = vec_path / 'vocab' / 'vec.bin'
 | 
			
		||||
    overrides['add_vectors'] = lambda vocab: vocab.load_vectors_from_bin_loc(vec_path)
 | 
			
		||||
    return overrides
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user