mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 17:24:41 +03:00
* Allow user to load different sized vectors.
This commit is contained in:
parent
0aed9c9a33
commit
c04e6ebca6
|
@ -42,9 +42,9 @@ cdef class Lexeme:
|
|||
# Workaround for an apparent bug in the way the decorator is handled ---
|
||||
# TODO: post bug report / patch to Cython.
|
||||
@staticmethod
|
||||
cdef inline Lexeme from_ptr(const LexemeC* ptr, StringStore strings):
|
||||
cdef Lexeme py = Lexeme.__new__(Lexeme, 300)
|
||||
for i in range(300):
|
||||
cdef inline Lexeme from_ptr(const LexemeC* ptr, StringStore strings, int repvec_length):
|
||||
cdef Lexeme py = Lexeme.__new__(Lexeme, repvec_length)
|
||||
for i in range(repvec_length):
|
||||
py.repvec[i] = ptr.repvec[i]
|
||||
py.l2_norm = ptr.l2_norm
|
||||
py.flags = ptr.flags
|
||||
|
|
|
@ -464,7 +464,7 @@ cdef class Token:
|
|||
|
||||
property repvec:
|
||||
def __get__(self):
|
||||
return numpy.asarray(<float[:300,]> self.c.lex.repvec)
|
||||
return numpy.asarray(<float[:self.vocab.repvec_length,]> self.c.lex.repvec)
|
||||
|
||||
property n_lefts:
|
||||
def __get__(self):
|
||||
|
|
|
@ -33,3 +33,4 @@ cdef class Vocab:
|
|||
cdef int _add_lex_to_vocab(self, hash_t key, const LexemeC* lex) except -1
|
||||
|
||||
cdef PreshMap _map
|
||||
cdef readonly int repvec_length
|
||||
|
|
|
@ -36,7 +36,7 @@ cdef class Vocab:
|
|||
self.strings = StringStore()
|
||||
self.lexemes.push_back(&EMPTY_LEXEME)
|
||||
self.lexeme_props_getter = get_lex_props
|
||||
|
||||
self.repvec_length = 0
|
||||
if data_dir is not None:
|
||||
if not path.exists(data_dir):
|
||||
raise IOError("Directory %s not found -- cannot load Vocab." % data_dir)
|
||||
|
@ -46,7 +46,7 @@ cdef class Vocab:
|
|||
self.load_lexemes(path.join(data_dir, 'strings.txt'),
|
||||
path.join(data_dir, 'lexemes.bin'))
|
||||
if load_vectors and path.exists(path.join(data_dir, 'vec.bin')):
|
||||
self.load_rep_vectors(path.join(data_dir, 'vec.bin'))
|
||||
self.repvec_length = self.load_rep_vectors(path.join(data_dir, 'vec.bin'))
|
||||
|
||||
def __len__(self):
|
||||
"""The current number of lexemes stored."""
|
||||
|
@ -107,7 +107,7 @@ cdef class Vocab:
|
|||
raise ValueError("Vocab unable to map type: "
|
||||
"%s. Maps unicode --> Lexeme or "
|
||||
"int --> Lexeme" % str(type(id_or_string)))
|
||||
return Lexeme.from_ptr(lexeme, self.strings)
|
||||
return Lexeme.from_ptr(lexeme, self.strings, self.repvec_length)
|
||||
|
||||
def __setitem__(self, unicode py_str, dict props):
|
||||
cdef UniStr c_str
|
||||
|
@ -180,6 +180,7 @@ cdef class Vocab:
|
|||
file_ = _CFile(loc, b'rb')
|
||||
cdef int32_t word_len
|
||||
cdef int32_t vec_len
|
||||
cdef int32_t prev_vec_len = 0
|
||||
cdef float* vec
|
||||
cdef Address mem
|
||||
cdef id_t string_id
|
||||
|
@ -192,7 +193,10 @@ cdef class Vocab:
|
|||
except IOError:
|
||||
break
|
||||
file_.read(&vec_len, sizeof(vec_len), 1)
|
||||
|
||||
if prev_vec_len != 0 and vec_len != prev_vec_len:
|
||||
raise VectorReadError.mismatched_sizes(loc, vec_len, prev_vec_len)
|
||||
if 0 >= vec_len >= MAX_VEC_SIZE:
|
||||
raise VectorReadError.bad_size(loc, vec_len)
|
||||
mem = Address(word_len, sizeof(char))
|
||||
chars = <char*>mem.ptr
|
||||
vec = <float*>self.mem.alloc(vec_len, sizeof(float))
|
||||
|
@ -216,6 +220,7 @@ cdef class Vocab:
|
|||
lex.l2_norm = math.sqrt(lex.l2_norm)
|
||||
else:
|
||||
lex.repvec = EMPTY_VEC
|
||||
return vec_len
|
||||
|
||||
|
||||
def write_binary_vectors(in_loc, out_loc):
|
||||
|
@ -272,3 +277,21 @@ cdef class _CFile:
|
|||
cdef bytes py_bytes = value.encode('utf8')
|
||||
cdef char* chars = <char*>py_bytes
|
||||
self.write(sizeof(char), len(py_bytes), chars)
|
||||
|
||||
|
||||
class VectorReadError(Exception):
|
||||
@classmethod
|
||||
def mismatched_sizes(cls, loc, prev_size, curr_size):
|
||||
return cls(
|
||||
"Error reading word vectors from %s.\n"
|
||||
"All vectors must be the same size.\n"
|
||||
"Prev size: %d\n"
|
||||
"Curr size: %d" % (loc, prev_size, curr_size))
|
||||
|
||||
@classmethod
|
||||
def bad_size(cls, loc, size):
|
||||
return cls(
|
||||
"Error reading word vectors from %s.\n"
|
||||
"Vector size: %d\n"
|
||||
"Max size: %d\n"
|
||||
"Min size: 1\n" % (loc, size, MAX_VEC_SIZE))
|
||||
|
|
Loading…
Reference in New Issue
Block a user