mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 01:04:34 +03:00
* Fix vector length error reporting, and ensure vec_len is returned
This commit is contained in:
parent
ba4e563701
commit
ac459278d1
|
@ -263,9 +263,10 @@ cdef class Vocab:
|
||||||
|
|
||||||
def load_vectors(self, loc):
|
def load_vectors(self, loc):
|
||||||
if loc.endswith('bz2'):
|
if loc.endswith('bz2'):
|
||||||
self.load_vectors_bz2(loc)
|
vec_len = self.load_vectors_bz2(loc)
|
||||||
else:
|
else:
|
||||||
self.load_vectors_bin(loc)
|
vec_len = self.load_vectors_bin(loc)
|
||||||
|
return vec_len
|
||||||
|
|
||||||
def load_vectors_bz2(self, loc):
|
def load_vectors_bz2(self, loc):
|
||||||
cdef LexemeC* lexeme
|
cdef LexemeC* lexeme
|
||||||
|
@ -278,10 +279,8 @@ cdef class Vocab:
|
||||||
if vec_len == -1:
|
if vec_len == -1:
|
||||||
vec_len = len(pieces)
|
vec_len = len(pieces)
|
||||||
elif vec_len != len(pieces):
|
elif vec_len != len(pieces):
|
||||||
raise IOError(
|
raise VectorReadError.mismatched_sizes(loc, line_num,
|
||||||
"Error loading word vectors: all vectors must be same "
|
vec_len, len(pieces))
|
||||||
"length. Previous vector was length %d, vector on line %d "
|
|
||||||
"was length %d." % (vec_len, line_num, len(pieces)))
|
|
||||||
orth = self.strings[word_str]
|
orth = self.strings[word_str]
|
||||||
lexeme = <LexemeC*><void*>self.get_by_orth(self.mem, orth)
|
lexeme = <LexemeC*><void*>self.get_by_orth(self.mem, orth)
|
||||||
lexeme.repvec = <float*>self.mem.alloc(len(pieces), sizeof(float))
|
lexeme.repvec = <float*>self.mem.alloc(len(pieces), sizeof(float))
|
||||||
|
@ -293,14 +292,14 @@ cdef class Vocab:
|
||||||
def load_vectors_bin(self, loc):
|
def load_vectors_bin(self, loc):
|
||||||
cdef CFile file_ = CFile(loc, b'rb')
|
cdef CFile file_ = CFile(loc, b'rb')
|
||||||
cdef int32_t word_len
|
cdef int32_t word_len
|
||||||
cdef int32_t vec_len
|
cdef int32_t vec_len = 0
|
||||||
cdef int32_t prev_vec_len = 0
|
cdef int32_t prev_vec_len = 0
|
||||||
cdef float* vec
|
cdef float* vec
|
||||||
cdef Address mem
|
cdef Address mem
|
||||||
cdef attr_t string_id
|
cdef attr_t string_id
|
||||||
cdef bytes py_word
|
cdef bytes py_word
|
||||||
cdef vector[float*] vectors
|
cdef vector[float*] vectors
|
||||||
cdef int i
|
cdef int line_num = 0
|
||||||
cdef Pool tmp_mem = Pool()
|
cdef Pool tmp_mem = Pool()
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
|
@ -309,7 +308,8 @@ cdef class Vocab:
|
||||||
break
|
break
|
||||||
file_.read_into(&vec_len, sizeof(vec_len), 1)
|
file_.read_into(&vec_len, sizeof(vec_len), 1)
|
||||||
if prev_vec_len != 0 and vec_len != prev_vec_len:
|
if prev_vec_len != 0 and vec_len != prev_vec_len:
|
||||||
raise VectorReadError.mismatched_sizes(loc, vec_len, prev_vec_len)
|
raise VectorReadError.mismatched_sizes(loc, line_num,
|
||||||
|
vec_len, prev_vec_len)
|
||||||
if 0 >= vec_len >= MAX_VEC_SIZE:
|
if 0 >= vec_len >= MAX_VEC_SIZE:
|
||||||
raise VectorReadError.bad_size(loc, vec_len)
|
raise VectorReadError.bad_size(loc, vec_len)
|
||||||
|
|
||||||
|
@ -321,8 +321,10 @@ cdef class Vocab:
|
||||||
vectors.push_back(EMPTY_VEC)
|
vectors.push_back(EMPTY_VEC)
|
||||||
assert vec != NULL
|
assert vec != NULL
|
||||||
vectors[string_id] = vec
|
vectors[string_id] = vec
|
||||||
|
line_num += 1
|
||||||
cdef LexemeC* lex
|
cdef LexemeC* lex
|
||||||
cdef size_t lex_addr
|
cdef size_t lex_addr
|
||||||
|
cdef int i
|
||||||
for orth, lex_addr in self._by_orth.items():
|
for orth, lex_addr in self._by_orth.items():
|
||||||
lex = <LexemeC*>lex_addr
|
lex = <LexemeC*>lex_addr
|
||||||
if lex.lower < vectors.size():
|
if lex.lower < vectors.size():
|
||||||
|
@ -363,12 +365,12 @@ def write_binary_vectors(in_loc, out_loc):
|
||||||
|
|
||||||
class VectorReadError(Exception):
|
class VectorReadError(Exception):
|
||||||
@classmethod
|
@classmethod
|
||||||
def mismatched_sizes(cls, loc, prev_size, curr_size):
|
def mismatched_sizes(cls, loc, line_num, prev_size, curr_size):
|
||||||
return cls(
|
return cls(
|
||||||
"Error reading word vectors from %s.\n"
|
"Error reading word vectors from %s on line %d.\n"
|
||||||
"All vectors must be the same size.\n"
|
"All vectors must be the same size.\n"
|
||||||
"Prev size: %d\n"
|
"Prev size: %d\n"
|
||||||
"Curr size: %d" % (loc, prev_size, curr_size))
|
"Curr size: %d" % (loc, line_num, prev_size, curr_size))
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def bad_size(cls, loc, size):
|
def bad_size(cls, loc, size):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user