* Fix API for loading word vectors from a file.

This commit is contained in:
Matthew Honnibal 2015-09-23 23:51:08 +10:00
parent 46caf15bca
commit abf0d930af

View File

@@ -55,7 +55,7 @@ cdef class Vocab:
self.load_lexemes(path.join(data_dir, 'strings.txt'), path.join(data_dir, 'lexemes.bin')) self.load_lexemes(path.join(data_dir, 'strings.txt'), path.join(data_dir, 'lexemes.bin'))
if path.exists(path.join(data_dir, 'vec.bin')): if path.exists(path.join(data_dir, 'vec.bin')):
self.vectors_length = self.load_vectors(path.join(data_dir, 'vec.bin')) self.vectors_length = self.load_vectors_from_bin_loc(path.join(data_dir, 'vec.bin'))
return self return self
def __init__(self, get_lex_attr=None, tag_map=None, lemmatizer=None, serializer_freqs=None): def __init__(self, get_lex_attr=None, tag_map=None, lemmatizer=None, serializer_freqs=None):
@@ -258,35 +258,27 @@ cdef class Vocab:
i += 1 i += 1
fp.close() fp.close()
def load_vectors(self, loc): def load_vectors(self, loc_or_file):
if loc.endswith('bz2'):
vec_len = self.load_vectors_bz2(loc)
else:
vec_len = self.load_vectors_bin(loc)
return vec_len
def load_vectors_bz2(self, loc):
cdef LexemeC* lexeme cdef LexemeC* lexeme
cdef attr_t orth cdef attr_t orth
cdef int32_t vec_len = -1 cdef int32_t vec_len = -1
with bz2.BZ2File(loc, 'r') as file_: for line_num, line in enumerate(loc_or_file):
for line_num, line in enumerate(file_): pieces = line.split()
pieces = line.split() word_str = pieces.pop(0)
word_str = pieces.pop(0) if vec_len == -1:
if vec_len == -1: vec_len = len(pieces)
vec_len = len(pieces) elif vec_len != len(pieces):
elif vec_len != len(pieces): raise VectorReadError.mismatched_sizes(loc_or_file, line_num,
raise VectorReadError.mismatched_sizes(loc, line_num, vec_len, len(pieces))
vec_len, len(pieces)) orth = self.strings[word_str]
orth = self.strings[word_str] lexeme = <LexemeC*><void*>self.get_by_orth(self.mem, orth)
lexeme = <LexemeC*><void*>self.get_by_orth(self.mem, orth) lexeme.repvec = <float*>self.mem.alloc(self.vectors_length, sizeof(float))
lexeme.repvec = <float*>self.mem.alloc(self.vectors_length, sizeof(float))
for i, val_str in enumerate(pieces): for i, val_str in enumerate(pieces):
lexeme.repvec[i] = float(val_str) lexeme.repvec[i] = float(val_str)
return vec_len return vec_len
def load_vectors_bin(self, loc): def load_vectors_from_bin_loc(self, loc):
cdef CFile file_ = CFile(loc, b'rb') cdef CFile file_ = CFile(loc, b'rb')
cdef int32_t word_len cdef int32_t word_len
cdef int32_t vec_len = 0 cdef int32_t vec_len = 0