mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 01:04:34 +03:00
Fix loading of GloVe vectors, to address Issue #541
This commit is contained in:
parent
06b83d8f40
commit
5ec32f5d97
|
@ -13,7 +13,6 @@ except NameError:
|
||||||
basestring = str
|
basestring = str
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
set_lang_class(en.English.lang, en.English)
|
set_lang_class(en.English.lang, en.English)
|
||||||
set_lang_class(de.German.lang, de.German)
|
set_lang_class(de.German.lang, de.German)
|
||||||
set_lang_class(zh.Chinese.lang, zh.Chinese)
|
set_lang_class(zh.Chinese.lang, zh.Chinese)
|
||||||
|
@ -21,13 +20,19 @@ set_lang_class(zh.Chinese.lang, zh.Chinese)
|
||||||
|
|
||||||
def load(name, **overrides):
|
def load(name, **overrides):
|
||||||
target_name, target_version = util.split_data_name(name)
|
target_name, target_version = util.split_data_name(name)
|
||||||
path = overrides.get('path', util.get_data_path())
|
data_path = overrides.get('path', util.get_data_path())
|
||||||
path = util.match_best_version(target_name, target_version, path)
|
if target_name == 'en' and 'add_vectors' not in overrides:
|
||||||
|
if 'vectors' in overrides:
|
||||||
|
vec_path = util.match_best_version(overrides['vectors'], None, data_path)
|
||||||
|
if vec_path is None:
|
||||||
|
raise IOError(
|
||||||
|
'Could not load data pack %s from %s' % (overrides['vectors'], data_path))
|
||||||
|
|
||||||
if isinstance(overrides.get('vectors'), basestring):
|
else:
|
||||||
vectors_path = util.match_best_version(overrides.get('vectors'), None, path)
|
vec_path = util.match_best_version('en_glove_cc_300_1m_vectors', None, data_path)
|
||||||
overrides['vectors'] = lambda nlp: nlp.vocab.load_vectors_from_bin_loc(
|
if vec_path is not None:
|
||||||
vectors_path / 'vocab' / 'vec.bin')
|
vec_path = vec_path / 'vocab' / 'vec.bin'
|
||||||
|
overrides['add_vectors'] = lambda vocab: vocab.load_vectors_from_bin_loc(vec_path)
|
||||||
|
path = util.match_best_version(target_name, target_version, data_path)
|
||||||
cls = get_lang_class(target_name)
|
cls = get_lang_class(target_name)
|
||||||
return cls(path=path, **overrides)
|
return cls(path=path, **overrides)
|
||||||
|
|
|
@ -53,7 +53,11 @@ class BaseDefaults(object):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def add_vectors(cls, nlp=None):
|
def add_vectors(cls, nlp=None):
|
||||||
return True
|
if nlp is None or nlp.path is None:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
vec_path = nlp.path / 'vocab' / 'vec.bin'
|
||||||
|
return lambda vocab: vocab.load_vectors_from_bin_loc(vec_path)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create_tokenizer(cls, nlp=None):
|
def create_tokenizer(cls, nlp=None):
|
||||||
|
@ -246,6 +250,11 @@ class Language(object):
|
||||||
self.vocab = self.Defaults.create_vocab(self) \
|
self.vocab = self.Defaults.create_vocab(self) \
|
||||||
if 'vocab' not in overrides \
|
if 'vocab' not in overrides \
|
||||||
else overrides['vocab']
|
else overrides['vocab']
|
||||||
|
add_vectors = self.Defaults.add_vectors(self) \
|
||||||
|
if 'add_vectors' not in overrides \
|
||||||
|
else overrides['add_vectors']
|
||||||
|
if add_vectors:
|
||||||
|
add_vectors(self.vocab)
|
||||||
self.tokenizer = self.Defaults.create_tokenizer(self) \
|
self.tokenizer = self.Defaults.create_tokenizer(self) \
|
||||||
if 'tokenizer' not in overrides \
|
if 'tokenizer' not in overrides \
|
||||||
else overrides['tokenizer']
|
else overrides['tokenizer']
|
||||||
|
|
|
@ -49,9 +49,13 @@ cdef class Vocab:
|
||||||
'''A map container for a language's LexemeC structs.
|
'''A map container for a language's LexemeC structs.
|
||||||
'''
|
'''
|
||||||
@classmethod
|
@classmethod
|
||||||
def load(cls, path, lex_attr_getters=None, vectors=True, lemmatizer=True,
|
def load(cls, path, lex_attr_getters=None, lemmatizer=True,
|
||||||
tag_map=True, serializer_freqs=True, oov_prob=True, **deprecated_kwargs):
|
tag_map=True, serializer_freqs=True, oov_prob=True, **deprecated_kwargs):
|
||||||
util.check_renamed_kwargs({'get_lex_attr': 'lex_attr_getters'}, deprecated_kwargs)
|
util.check_renamed_kwargs({'get_lex_attr': 'lex_attr_getters'}, deprecated_kwargs)
|
||||||
|
if 'vectors' in deprecated_kwargs:
|
||||||
|
raise AttributeError(
|
||||||
|
"vectors argument to Vocab.load() deprecated. "
|
||||||
|
"Install vectors after loading.")
|
||||||
if tag_map is True and (path / 'vocab' / 'tag_map.json').exists():
|
if tag_map is True and (path / 'vocab' / 'tag_map.json').exists():
|
||||||
with (path / 'vocab' / 'tag_map.json').open() as file_:
|
with (path / 'vocab' / 'tag_map.json').open() as file_:
|
||||||
tag_map = json.load(file_)
|
tag_map = json.load(file_)
|
||||||
|
@ -73,15 +77,6 @@ cdef class Vocab:
|
||||||
with (path / 'vocab' / 'strings.json').open() as file_:
|
with (path / 'vocab' / 'strings.json').open() as file_:
|
||||||
self.strings.load(file_)
|
self.strings.load(file_)
|
||||||
self.load_lexemes(path / 'vocab' / 'lexemes.bin')
|
self.load_lexemes(path / 'vocab' / 'lexemes.bin')
|
||||||
|
|
||||||
if vectors is True:
|
|
||||||
vec_path = path / 'vocab' / 'vec.bin'
|
|
||||||
if vec_path.exists():
|
|
||||||
vectors = lambda self_: self_.load_vectors_from_bin_loc(vec_path)
|
|
||||||
else:
|
|
||||||
vectors = lambda self_: 0
|
|
||||||
if vectors:
|
|
||||||
self.vectors_length = vectors(self)
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __init__(self, lex_attr_getters=None, tag_map=None, lemmatizer=None,
|
def __init__(self, lex_attr_getters=None, tag_map=None, lemmatizer=None,
|
||||||
|
@ -387,10 +382,11 @@ cdef class Vocab:
|
||||||
vec_len, len(pieces))
|
vec_len, len(pieces))
|
||||||
orth = self.strings[word_str]
|
orth = self.strings[word_str]
|
||||||
lexeme = <LexemeC*><void*>self.get_by_orth(self.mem, orth)
|
lexeme = <LexemeC*><void*>self.get_by_orth(self.mem, orth)
|
||||||
lexeme.vector = <float*>self.mem.alloc(self.vectors_length, sizeof(float))
|
lexeme.vector = <float*>self.mem.alloc(vec_len, sizeof(float))
|
||||||
|
|
||||||
for i, val_str in enumerate(pieces):
|
for i, val_str in enumerate(pieces):
|
||||||
lexeme.vector[i] = float(val_str)
|
lexeme.vector[i] = float(val_str)
|
||||||
|
self.vectors_length = vec_len
|
||||||
return vec_len
|
return vec_len
|
||||||
|
|
||||||
def load_vectors_from_bin_loc(self, loc):
|
def load_vectors_from_bin_loc(self, loc):
|
||||||
|
@ -438,6 +434,7 @@ cdef class Vocab:
|
||||||
lex.l2_norm = math.sqrt(lex.l2_norm)
|
lex.l2_norm = math.sqrt(lex.l2_norm)
|
||||||
else:
|
else:
|
||||||
lex.vector = EMPTY_VEC
|
lex.vector = EMPTY_VEC
|
||||||
|
self.vectors_length = vec_len
|
||||||
return vec_len
|
return vec_len
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user