diff --git a/spacy/vectors.pyx b/spacy/vectors.pyx index b1d17a026..a96913109 100644 --- a/spacy/vectors.pyx +++ b/spacy/vectors.pyx @@ -276,7 +276,10 @@ cdef class Vectors: sims = xp.dot(batch, vectors.T) best_rows[i:i+batch_size] = sims.argmax(axis=1) scores[i:i+batch_size] = sims.max(axis=1) - keys = self.find(rows=best_rows) + + xp = get_array_module(self.data) + row2key = {row: key for key, row in self.key2row.items()} + keys = xp.asarray([row2key[row] for row in best_rows], dtype='uint64') return (keys, best_rows, scores) def from_glove(self, path):