Remove vectors from vocab

This commit is contained in:
Matthew Honnibal 2017-05-28 11:45:32 +02:00
parent 48eef94f92
commit 15f6efc127

View File

@ -26,15 +26,6 @@ from . import attrs
from . import symbols
DEF MAX_VEC_SIZE = 100000
cdef float[MAX_VEC_SIZE] EMPTY_VEC
memset(EMPTY_VEC, 0, sizeof(EMPTY_VEC))
memset(&EMPTY_LEXEME, 0, sizeof(LexemeC))
EMPTY_LEXEME.vector = EMPTY_VEC
cdef class Vocab:
"""A look-up table that allows you to access `Lexeme` objects. The `Vocab`
instance also provides access to the `StringStore`, and owns underlying
@ -179,7 +170,6 @@ cdef class Vocab:
lex.orth = self.strings[string]
lex.length = len(string)
lex.id = self.length
lex.vector = <float*>mem.alloc(self.vectors_length, sizeof(float))
if self.lex_attr_getters is not None:
for attr, func in self.lex_attr_getters.items():
value = func(string)
@ -258,6 +248,26 @@ cdef class Vocab:
Token.set_struct_attr(token, attr_id, value)
return tokens
def get_vector(self, orth):
"""Retrieve a vector for a word in the vocabulary.
Words can be looked up by string or int ID.
RETURNS:
A word vector. Size and shape determed by the
vocab.vectors instance. Usually, a numpy ndarray
of shape (300,) and dtype float32.
RAISES: If no vectors data is loaded, ValueError is raised.
"""
raise NotImplementedError
def has_vector(self, orth):
"""Check whether a word has a vector. Returns False if no
vectors have been loaded. Words can be looked up by string
or int ID."""
raise NotImplementedError
def to_disk(self, path):
"""Save the current state to a directory.
@ -271,9 +281,6 @@ cdef class Vocab:
with strings_loc.open('w', encoding='utf8') as file_:
self.strings.dump(file_)
# TODO: pickle
# self.dump(path / 'lexemes.bin')
def from_disk(self, path):
"""Loads state from a directory. Modifies the object in place and
returns it.
@ -346,7 +353,6 @@ cdef class Vocab:
lex_data.data[j] = bytes_ptr[i+j]
Lexeme.c_from_bytes(lexeme, lex_data)
lexeme.vector = EMPTY_VEC
py_str = self.strings[lexeme.orth]
assert self.strings[py_str] == lexeme.orth, (py_str, lexeme.orth)
key = hash_string(py_str)
@ -354,172 +360,6 @@ cdef class Vocab:
self._by_orth.set(lexeme.orth, lexeme)
self.length += 1
# Deprecated --- delete these once stable
def dump_vectors(self, out_loc):
"""Save the word vectors to a binary file.
loc (Path): The path to save to.
"""
cdef int32_t vec_len = self.vectors_length
cdef int32_t word_len
cdef bytes word_str
cdef char* chars
cdef Lexeme lexeme
cdef CFile out_file = CFile(out_loc, 'wb')
for lexeme in self:
word_str = lexeme.orth_.encode('utf8')
vec = lexeme.c.vector
word_len = len(word_str)
out_file.write_from(&word_len, 1, sizeof(word_len))
out_file.write_from(&vec_len, 1, sizeof(vec_len))
chars = <char*>word_str
out_file.write_from(chars, word_len, sizeof(char))
out_file.write_from(vec, vec_len, sizeof(float))
out_file.close()
def load_vectors(self, file_):
"""Load vectors from a text-based file.
file_ (buffer): The file to read from. Entries should be separated by
newlines, and each entry should be whitespace delimited. The first value of the entry
should be the word string, and subsequent entries should be the values of the
vector.
RETURNS (int): The length of the vectors loaded.
"""
cdef LexemeC* lexeme
cdef attr_t orth
cdef int32_t vec_len = -1
cdef double norm = 0.0
whitespace_pattern = re.compile(r'\s', re.UNICODE)
for line_num, line in enumerate(file_):
pieces = line.split()
word_str = " " if whitespace_pattern.match(line) else pieces.pop(0)
if vec_len == -1:
vec_len = len(pieces)
elif vec_len != len(pieces):
raise VectorReadError.mismatched_sizes(file_, line_num,
vec_len, len(pieces))
orth = self.strings[word_str]
lexeme = <LexemeC*><void*>self.get_by_orth(self.mem, orth)
lexeme.vector = <float*>self.mem.alloc(vec_len, sizeof(float))
for i, val_str in enumerate(pieces):
lexeme.vector[i] = float(val_str)
norm = 0.0
for i in range(vec_len):
norm += lexeme.vector[i] * lexeme.vector[i]
lexeme.l2_norm = sqrt(norm)
self.vectors_length = vec_len
return vec_len
def load_vectors_from_bin_loc(self, loc):
"""Load vectors from the location of a binary file.
loc (unicode): The path of the binary file to load from.
RETURNS (int): The length of the vectors loaded.
"""
cdef CFile file_ = CFile(loc, b'rb')
cdef int32_t word_len
cdef int32_t vec_len = 0
cdef int32_t prev_vec_len = 0
cdef float* vec
cdef Address mem
cdef attr_t string_id
cdef bytes py_word
cdef vector[float*] vectors
cdef int line_num = 0
cdef Pool tmp_mem = Pool()
while True:
try:
file_.read_into(&word_len, sizeof(word_len), 1)
except IOError:
break
file_.read_into(&vec_len, sizeof(vec_len), 1)
if prev_vec_len != 0 and vec_len != prev_vec_len:
raise VectorReadError.mismatched_sizes(loc, line_num,
vec_len, prev_vec_len)
if 0 >= vec_len >= MAX_VEC_SIZE:
raise VectorReadError.bad_size(loc, vec_len)
chars = <char*>file_.alloc_read(tmp_mem, word_len, sizeof(char))
vec = <float*>file_.alloc_read(self.mem, vec_len, sizeof(float))
string_id = self.strings[chars[:word_len]]
# Insert words into vocab to add vector.
self.get_by_orth(self.mem, string_id)
while string_id >= vectors.size():
vectors.push_back(EMPTY_VEC)
assert vec != NULL
vectors[string_id] = vec
line_num += 1
cdef LexemeC* lex
cdef size_t lex_addr
cdef double norm = 0.0
cdef int i
for orth, lex_addr in self._by_orth.items():
lex = <LexemeC*>lex_addr
if lex.lower < vectors.size():
lex.vector = vectors[lex.lower]
norm = 0.0
for i in range(vec_len):
norm += lex.vector[i] * lex.vector[i]
lex.l2_norm = sqrt(norm)
else:
lex.vector = EMPTY_VEC
self.vectors_length = vec_len
return vec_len
def resize_vectors(self, int new_size):
"""Set vectors_length to a new size, and allocate more memory for the
`Lexeme` vectors if necessary. The memory will be zeroed.
new_size (int): The new size of the vectors.
"""
cdef hash_t key
cdef size_t addr
if new_size > self.vectors_length:
for key, addr in self._by_hash.items():
lex = <LexemeC*>addr
lex.vector = <float*>self.mem.realloc(lex.vector,
new_size * sizeof(lex.vector[0]))
self.vectors_length = new_size
def write_binary_vectors(in_loc, out_loc):
cdef CFile out_file = CFile(out_loc, 'wb')
cdef Address mem
cdef int32_t word_len
cdef int32_t vec_len
cdef char* chars
with bz2.BZ2File(in_loc, 'r') as file_:
for line in file_:
pieces = line.split()
word = pieces.pop(0)
mem = Address(len(pieces), sizeof(float))
vec = <float*>mem.ptr
for i, val_str in enumerate(pieces):
vec[i] = float(val_str)
word_len = len(word)
vec_len = len(pieces)
out_file.write_from(&word_len, 1, sizeof(word_len))
out_file.write_from(&vec_len, 1, sizeof(vec_len))
chars = <char*>word
out_file.write_from(chars, len(word), sizeof(char))
out_file.write_from(vec, vec_len, sizeof(float))
def pickle_vocab(vocab):
sstore = vocab.strings
@ -567,21 +407,3 @@ class LookupError(Exception):
"ID of orth: {orth_id}".format(
query=repr(original_string), orth_str=repr(id_string), orth_id=id_)
)
class VectorReadError(Exception):
@classmethod
def mismatched_sizes(cls, loc, line_num, prev_size, curr_size):
return cls(
"Error reading word vectors from %s on line %d.\n"
"All vectors must be the same size.\n"
"Prev size: %d\n"
"Curr size: %d" % (loc, line_num, prev_size, curr_size))
@classmethod
def bad_size(cls, loc, size):
return cls(
"Error reading word vectors from %s.\n"
"Vector size: %d\n"
"Max size: %d\n"
"Min size: 1\n" % (loc, size, MAX_VEC_SIZE))