Fix vectors data on GPU (#7626)

* ensure vectors data is stored on right device

* ensure the added vector is on the right device

* move vector to numpy before iterating

* move best_rows to numpy before iterating
This commit is contained in:
Sofie Van Landeghem 2021-04-19 10:30:03 +02:00 committed by GitHub
parent ed561cf428
commit 05bdbe28bb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 19 additions and 10 deletions

View File

@ -55,7 +55,7 @@ cdef class Vectors:
"""Create a new vector store.
shape (tuple): Size of the table, as (# entries, # columns)
data (numpy.ndarray): The vector data.
data (numpy.ndarray or cupy.ndarray): The vector data.
keys (iterable): A sequence of keys, aligned with the data.
name (str): A name to identify the vectors table.
@ -65,7 +65,8 @@ cdef class Vectors:
if data is None:
if shape is None:
shape = (0,0)
data = numpy.zeros(shape, dtype="f")
ops = get_current_ops()
data = ops.xp.zeros(shape, dtype="f")
self.data = data
self.key2row = {}
if self.data is not None:
@ -300,6 +301,8 @@ cdef class Vectors:
else:
raise ValueError(Errors.E197.format(row=row, key=key))
if vector is not None:
xp = get_array_module(self.data)
vector = xp.asarray(vector)
self.data[row] = vector
if self._unset.count(row):
self._unset.erase(self._unset.find(row))
@ -321,10 +324,11 @@ cdef class Vectors:
RETURNS (tuple): The most similar entries as a `(keys, best_rows, scores)`
tuple.
"""
xp = get_array_module(self.data)
filled = sorted(list({row for row in self.key2row.values()}))
if len(filled) < n:
raise ValueError(Errors.E198.format(n=n, n_rows=len(filled)))
xp = get_array_module(self.data)
filled = xp.asarray(filled)
norms = xp.linalg.norm(self.data[filled], axis=1, keepdims=True)
norms[norms == 0] = 1
@ -357,8 +361,10 @@ cdef class Vectors:
# Account for numerical error we want to return in range -1, 1
scores = xp.clip(scores, a_min=-1, a_max=1, out=scores)
row2key = {row: key for key, row in self.key2row.items()}
numpy_rows = get_current_ops().to_numpy(best_rows)
keys = xp.asarray(
[[row2key[row] for row in best_rows[i] if row in row2key]
[[row2key[row] for row in numpy_rows[i] if row in row2key]
for i in range(len(queries)) ], dtype="uint64")
return (keys, best_rows, scores)
@ -459,7 +465,8 @@ cdef class Vectors:
if hasattr(self.data, "from_bytes"):
self.data.from_bytes()
else:
self.data = srsly.msgpack_loads(b)
xp = get_array_module(self.data)
self.data = xp.asarray(srsly.msgpack_loads(b))
deserializers = {
"key2row": lambda b: self.key2row.update(srsly.msgpack_loads(b)),

View File

@ -2,7 +2,7 @@
from libc.string cimport memcpy
import srsly
from thinc.api import get_array_module
from thinc.api import get_array_module, get_current_ops
import functools
from .lexeme cimport EMPTY_LEXEME, OOV_RANK
@ -293,7 +293,7 @@ cdef class Vocab:
among those remaining.
For example, suppose the original table had vectors for the words:
['sat', 'cat', 'feline', 'reclined']. If we prune the vector table to,
['sat', 'cat', 'feline', 'reclined']. If we prune the vector table to
two rows, we would discard the vectors for 'feline' and 'reclined'.
These words would then be remapped to the closest remaining vector
-- so "feline" would have the same vector as "cat", and "reclined"
@ -314,6 +314,7 @@ cdef class Vocab:
DOCS: https://spacy.io/api/vocab#prune_vectors
"""
ops = get_current_ops()
xp = get_array_module(self.vectors.data)
# Make sure all vectors are in the vocab
for orth in self.vectors:
@ -329,8 +330,9 @@ cdef class Vocab:
toss = xp.ascontiguousarray(self.vectors.data[indices[nr_row:]])
self.vectors = Vectors(data=keep, keys=keys[:nr_row], name=self.vectors.name)
syn_keys, syn_rows, scores = self.vectors.most_similar(toss, batch_size=batch_size)
syn_keys = ops.to_numpy(syn_keys)
remap = {}
for i, key in enumerate(keys[nr_row:]):
for i, key in enumerate(ops.to_numpy(keys[nr_row:])):
self.vectors.add(key, row=syn_rows[i][0])
word = self.strings[key]
synonym = self.strings[syn_keys[i][0]]
@ -351,7 +353,7 @@ cdef class Vocab:
Defaults to the length of `orth`.
maxn (int): Maximum n-gram length used for Fasttext's ngram computation.
Defaults to the length of `orth`.
RETURNS (numpy.ndarray): A word vector. Size
RETURNS (numpy.ndarray or cupy.ndarray): A word vector. Size
and shape determined by the `vocab.vectors` instance. Usually, a
numpy ndarray of shape (300,) and dtype float32.
@ -400,7 +402,7 @@ cdef class Vocab:
by string or int ID.
orth (int / unicode): The word.
vector (numpy.ndarray[ndim=1, dtype='float32']): The vector to set.
vector (numpy.ndarray or cupy.nadarry[ndim=1, dtype='float32']): The vector to set.
DOCS: https://spacy.io/api/vocab#set_vector
"""