From 4112a991ec012b175a1a97add51ce04d09351886 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Mon, 30 Oct 2017 19:44:40 +0100 Subject: [PATCH] Fix vector pruning --- spacy/vectors.pyx | 26 ++++++++++++++------------ spacy/vocab.pyx | 17 +++++++++++++---- 2 files changed, 27 insertions(+), 16 deletions(-) diff --git a/spacy/vectors.pyx b/spacy/vectors.pyx index 368b73866..552a6bcf3 100644 --- a/spacy/vectors.pyx +++ b/spacy/vectors.pyx @@ -30,7 +30,8 @@ cdef class Vectors: cdef readonly StringStore strings cdef public object key2row cdef public object keys - cdef public int i + cdef public int _i_key + cdef public int _i_vec def __init__(self, strings, width=0, data=None): """Create a new vector store. To keep the vector table empty, pass @@ -53,7 +54,8 @@ cdef class Vectors: self.data = numpy.asarray(data, dtype='f') else: self.data = numpy.zeros((len(self.strings), width), dtype='f') - self.i = 0 + self._i_key = 0 + self._i_vec = 0 self.key2row = {} self.keys = numpy.zeros((self.data.shape[0],), dtype='uint64') if data is not None: @@ -105,7 +107,7 @@ cdef class Vectors: RETURNS (int): The number of vectors in the data. """ - return self.i + return self._i_vec def __contains__(self, key): """Check whether a key has a vector entry in the table. @@ -127,20 +129,20 @@ cdef class Vectors: """ if isinstance(key, basestring_): key = self.strings.add(key) - if key in self.key2row and row is None: + if row is None and key in self.key2row: row = self.key2row[key] - elif key in self.key2row and row is not None: - self.key2row[key] = row elif row is None: - row = self.i - self.i += 1 - if row >= self.keys.shape[0]: - self.keys.resize((row*2,)) + row = self._i_vec + self._i_vec += 1 + if row >= self.data.shape[0]: self.data.resize((row*2, self.data.shape[1])) - self.keys[row] = key + if key not in self.key2row: + if self._i_key >= self.keys.shape[0]: + self.keys.resize((self._i_key*2,)) + self.keys[self._i_key] = key + self._i_key += 1 self.key2row[key] = row - self.keys[row] = key if vector is not None: self.data[row] = vector return row diff --git a/spacy/vocab.pyx b/spacy/vocab.pyx index ff6c5b844..ecf1ad9d9 100644 --- a/spacy/vocab.pyx +++ b/spacy/vocab.pyx @@ -248,7 +248,7 @@ cdef class Vocab: width = self.vectors.data.shape[1] self.vectors = Vectors(self.strings, width=width) - def prune_vectors(self, nr_row, batch_size=1024): + def prune_vectors(self, nr_row, batch_size=8): """Reduce the current vector table to `nr_row` unique entries. Words mapped to the discarded vectors will be remapped to the closest vector among those remaining. @@ -267,22 +267,31 @@ cdef class Vocab: xp = get_array_module(self.vectors.data) # Work in batches, to avoid memory problems. keep = self.vectors.data[:nr_row] + keep_keys = [key for key, row in self.vectors.key2row.items() if row < nr_row] toss = self.vectors.data[nr_row:] # Normalize the vectors, so cosine similarity is just dot product. # Note we can't modify the ones we're keeping in-place... - keep = keep / (xp.linalg.norm(keep)+1e-8) + keep = keep / (xp.linalg.norm(keep, axis=1, keepdims=True)+1e-8) keep = xp.ascontiguousarray(keep.T) neighbours = xp.zeros((toss.shape[0],), dtype='i') + scores = xp.zeros((toss.shape[0],), dtype='f') for i in range(0, toss.shape[0], batch_size): batch = toss[i : i+batch_size] - batch /= xp.linalg.norm(batch)+1e-8 - neighbours[i:i+batch_size] = xp.dot(batch, keep).argmax(axis=1) + batch /= xp.linalg.norm(batch, axis=1, keepdims=True)+1e-8 + sims = xp.dot(batch, keep) + matches = sims.argmax(axis=1) + neighbours[i:i+batch_size] = matches + scores[i:i+batch_size] = sims.max(axis=1) for lex in self: # If we're losing the vector for this word, map it to the nearest # vector we're keeping. if lex.rank >= nr_row: lex.rank = neighbours[lex.rank-nr_row] self.vectors.add(lex.orth, row=lex.rank) + for key in self.vectors.keys: + row = self.vectors.key2row[key] + if row >= nr_row: + self.vectors.key2row[key] = neighbours[row-nr_row] # Make copy, to encourage the original table to be garbage collected. self.vectors.data = xp.ascontiguousarray(self.vectors.data[:nr_row])