Fix vector pruning

This commit is contained in:
Matthew Honnibal 2017-11-01 02:06:58 +01:00
parent 86eba61fae
commit c48dd0e1d3

View File

@ -275,7 +275,10 @@ cdef class Vectors:
sims = xp.dot(batch, vectors.T)
best_rows[i:i+batch_size] = sims.argmax(axis=1)
scores[i:i+batch_size] = sims.max(axis=1)
keys = self.find(rows=best_rows)
xp = get_array_module(self.data)
row2key = {row: key for key, row in self.key2row.items()}
keys = xp.asarray([row2key[row] for row in best_rows], dtype='uint64')
return (keys, best_rows, scores)
def from_glove(self, path):