mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
* Fix API for loading word vectors from a file.
This commit is contained in:
parent
46caf15bca
commit
abf0d930af
|
@ -55,7 +55,7 @@ cdef class Vocab:
|
||||||
|
|
||||||
self.load_lexemes(path.join(data_dir, 'strings.txt'), path.join(data_dir, 'lexemes.bin'))
|
self.load_lexemes(path.join(data_dir, 'strings.txt'), path.join(data_dir, 'lexemes.bin'))
|
||||||
if path.exists(path.join(data_dir, 'vec.bin')):
|
if path.exists(path.join(data_dir, 'vec.bin')):
|
||||||
self.vectors_length = self.load_vectors(path.join(data_dir, 'vec.bin'))
|
self.vectors_length = self.load_vectors_from_bin_loc(path.join(data_dir, 'vec.bin'))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __init__(self, get_lex_attr=None, tag_map=None, lemmatizer=None, serializer_freqs=None):
|
def __init__(self, get_lex_attr=None, tag_map=None, lemmatizer=None, serializer_freqs=None):
|
||||||
|
@ -258,35 +258,27 @@ cdef class Vocab:
|
||||||
i += 1
|
i += 1
|
||||||
fp.close()
|
fp.close()
|
||||||
|
|
||||||
def load_vectors(self, loc):
|
def load_vectors(self, loc_or_file):
|
||||||
if loc.endswith('bz2'):
|
|
||||||
vec_len = self.load_vectors_bz2(loc)
|
|
||||||
else:
|
|
||||||
vec_len = self.load_vectors_bin(loc)
|
|
||||||
return vec_len
|
|
||||||
|
|
||||||
def load_vectors_bz2(self, loc):
|
|
||||||
cdef LexemeC* lexeme
|
cdef LexemeC* lexeme
|
||||||
cdef attr_t orth
|
cdef attr_t orth
|
||||||
cdef int32_t vec_len = -1
|
cdef int32_t vec_len = -1
|
||||||
with bz2.BZ2File(loc, 'r') as file_:
|
for line_num, line in enumerate(loc_or_file):
|
||||||
for line_num, line in enumerate(file_):
|
pieces = line.split()
|
||||||
pieces = line.split()
|
word_str = pieces.pop(0)
|
||||||
word_str = pieces.pop(0)
|
if vec_len == -1:
|
||||||
if vec_len == -1:
|
vec_len = len(pieces)
|
||||||
vec_len = len(pieces)
|
elif vec_len != len(pieces):
|
||||||
elif vec_len != len(pieces):
|
raise VectorReadError.mismatched_sizes(loc_or_file, line_num,
|
||||||
raise VectorReadError.mismatched_sizes(loc, line_num,
|
vec_len, len(pieces))
|
||||||
vec_len, len(pieces))
|
orth = self.strings[word_str]
|
||||||
orth = self.strings[word_str]
|
lexeme = <LexemeC*><void*>self.get_by_orth(self.mem, orth)
|
||||||
lexeme = <LexemeC*><void*>self.get_by_orth(self.mem, orth)
|
lexeme.repvec = <float*>self.mem.alloc(self.vectors_length, sizeof(float))
|
||||||
lexeme.repvec = <float*>self.mem.alloc(self.vectors_length, sizeof(float))
|
|
||||||
|
|
||||||
for i, val_str in enumerate(pieces):
|
for i, val_str in enumerate(pieces):
|
||||||
lexeme.repvec[i] = float(val_str)
|
lexeme.repvec[i] = float(val_str)
|
||||||
return vec_len
|
return vec_len
|
||||||
|
|
||||||
def load_vectors_bin(self, loc):
|
def load_vectors_from_bin_loc(self, loc):
|
||||||
cdef CFile file_ = CFile(loc, b'rb')
|
cdef CFile file_ = CFile(loc, b'rb')
|
||||||
cdef int32_t word_len
|
cdef int32_t word_len
|
||||||
cdef int32_t vec_len = 0
|
cdef int32_t vec_len = 0
|
||||||
|
|
Loading…
Reference in New Issue
Block a user