Fix vector pruning

This commit is contained in:
Matthew Honnibal 2017-10-30 19:44:40 +01:00
parent e98451b5f7
commit 4112a991ec
2 changed files with 27 additions and 16 deletions

View File

@ -30,7 +30,8 @@ cdef class Vectors:
cdef readonly StringStore strings cdef readonly StringStore strings
cdef public object key2row cdef public object key2row
cdef public object keys cdef public object keys
cdef public int i cdef public int _i_key
cdef public int _i_vec
def __init__(self, strings, width=0, data=None): def __init__(self, strings, width=0, data=None):
"""Create a new vector store. To keep the vector table empty, pass """Create a new vector store. To keep the vector table empty, pass
@ -53,7 +54,8 @@ cdef class Vectors:
self.data = numpy.asarray(data, dtype='f') self.data = numpy.asarray(data, dtype='f')
else: else:
self.data = numpy.zeros((len(self.strings), width), dtype='f') self.data = numpy.zeros((len(self.strings), width), dtype='f')
self.i = 0 self._i_key = 0
self._i_vec = 0
self.key2row = {} self.key2row = {}
self.keys = numpy.zeros((self.data.shape[0],), dtype='uint64') self.keys = numpy.zeros((self.data.shape[0],), dtype='uint64')
if data is not None: if data is not None:
@ -105,7 +107,7 @@ cdef class Vectors:
RETURNS (int): The number of vectors in the data. RETURNS (int): The number of vectors in the data.
""" """
return self.i return self._i_vec
def __contains__(self, key): def __contains__(self, key):
"""Check whether a key has a vector entry in the table. """Check whether a key has a vector entry in the table.
@ -127,20 +129,20 @@ cdef class Vectors:
""" """
if isinstance(key, basestring_): if isinstance(key, basestring_):
key = self.strings.add(key) key = self.strings.add(key)
if key in self.key2row and row is None: if row is None and key in self.key2row:
row = self.key2row[key] row = self.key2row[key]
elif key in self.key2row and row is not None:
self.key2row[key] = row
elif row is None: elif row is None:
row = self.i row = self._i_vec
self.i += 1 self._i_vec += 1
if row >= self.keys.shape[0]: if row >= self.data.shape[0]:
self.keys.resize((row*2,))
self.data.resize((row*2, self.data.shape[1])) self.data.resize((row*2, self.data.shape[1]))
self.keys[row] = key if key not in self.key2row:
if self._i_key >= self.keys.shape[0]:
self.keys.resize((self._i_key*2,))
self.keys[self._i_key] = key
self._i_key += 1
self.key2row[key] = row self.key2row[key] = row
self.keys[row] = key
if vector is not None: if vector is not None:
self.data[row] = vector self.data[row] = vector
return row return row

View File

@ -248,7 +248,7 @@ cdef class Vocab:
width = self.vectors.data.shape[1] width = self.vectors.data.shape[1]
self.vectors = Vectors(self.strings, width=width) self.vectors = Vectors(self.strings, width=width)
def prune_vectors(self, nr_row, batch_size=1024): def prune_vectors(self, nr_row, batch_size=8):
"""Reduce the current vector table to `nr_row` unique entries. Words """Reduce the current vector table to `nr_row` unique entries. Words
mapped to the discarded vectors will be remapped to the closest vector mapped to the discarded vectors will be remapped to the closest vector
among those remaining. among those remaining.
@ -267,22 +267,31 @@ cdef class Vocab:
xp = get_array_module(self.vectors.data) xp = get_array_module(self.vectors.data)
# Work in batches, to avoid memory problems. # Work in batches, to avoid memory problems.
keep = self.vectors.data[:nr_row] keep = self.vectors.data[:nr_row]
keep_keys = [key for key, row in self.vectors.key2row.items() if row < nr_row]
toss = self.vectors.data[nr_row:] toss = self.vectors.data[nr_row:]
# Normalize the vectors, so cosine similarity is just dot product. # Normalize the vectors, so cosine similarity is just dot product.
# Note we can't modify the ones we're keeping in-place... # Note we can't modify the ones we're keeping in-place...
keep = keep / (xp.linalg.norm(keep)+1e-8) keep = keep / (xp.linalg.norm(keep, axis=1, keepdims=True)+1e-8)
keep = xp.ascontiguousarray(keep.T) keep = xp.ascontiguousarray(keep.T)
neighbours = xp.zeros((toss.shape[0],), dtype='i') neighbours = xp.zeros((toss.shape[0],), dtype='i')
scores = xp.zeros((toss.shape[0],), dtype='f')
for i in range(0, toss.shape[0], batch_size): for i in range(0, toss.shape[0], batch_size):
batch = toss[i : i+batch_size] batch = toss[i : i+batch_size]
batch /= xp.linalg.norm(batch)+1e-8 batch /= xp.linalg.norm(batch, axis=1, keepdims=True)+1e-8
neighbours[i:i+batch_size] = xp.dot(batch, keep).argmax(axis=1) sims = xp.dot(batch, keep)
matches = sims.argmax(axis=1)
neighbours[i:i+batch_size] = matches
scores[i:i+batch_size] = sims.max(axis=1)
for lex in self: for lex in self:
# If we're losing the vector for this word, map it to the nearest # If we're losing the vector for this word, map it to the nearest
# vector we're keeping. # vector we're keeping.
if lex.rank >= nr_row: if lex.rank >= nr_row:
lex.rank = neighbours[lex.rank-nr_row] lex.rank = neighbours[lex.rank-nr_row]
self.vectors.add(lex.orth, row=lex.rank) self.vectors.add(lex.orth, row=lex.rank)
for key in self.vectors.keys:
row = self.vectors.key2row[key]
if row >= nr_row:
self.vectors.key2row[key] = neighbours[row-nr_row]
# Make copy, to encourage the original table to be garbage collected. # Make copy, to encourage the original table to be garbage collected.
self.vectors.data = xp.ascontiguousarray(self.vectors.data[:nr_row]) self.vectors.data = xp.ascontiguousarray(self.vectors.data[:nr_row])