mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-11 17:56:30 +03:00
Fix vectors data on GPU (#7626)
* ensure vectors data is stored on right device * ensure the added vector is on the right device * move vector to numpy before iterating * move best_rows to numpy before iterating
This commit is contained in:
parent
ed561cf428
commit
05bdbe28bb
|
@ -55,7 +55,7 @@ cdef class Vectors:
|
||||||
"""Create a new vector store.
|
"""Create a new vector store.
|
||||||
|
|
||||||
shape (tuple): Size of the table, as (# entries, # columns)
|
shape (tuple): Size of the table, as (# entries, # columns)
|
||||||
data (numpy.ndarray): The vector data.
|
data (numpy.ndarray or cupy.ndarray): The vector data.
|
||||||
keys (iterable): A sequence of keys, aligned with the data.
|
keys (iterable): A sequence of keys, aligned with the data.
|
||||||
name (str): A name to identify the vectors table.
|
name (str): A name to identify the vectors table.
|
||||||
|
|
||||||
|
@ -65,7 +65,8 @@ cdef class Vectors:
|
||||||
if data is None:
|
if data is None:
|
||||||
if shape is None:
|
if shape is None:
|
||||||
shape = (0,0)
|
shape = (0,0)
|
||||||
data = numpy.zeros(shape, dtype="f")
|
ops = get_current_ops()
|
||||||
|
data = ops.xp.zeros(shape, dtype="f")
|
||||||
self.data = data
|
self.data = data
|
||||||
self.key2row = {}
|
self.key2row = {}
|
||||||
if self.data is not None:
|
if self.data is not None:
|
||||||
|
@ -300,6 +301,8 @@ cdef class Vectors:
|
||||||
else:
|
else:
|
||||||
raise ValueError(Errors.E197.format(row=row, key=key))
|
raise ValueError(Errors.E197.format(row=row, key=key))
|
||||||
if vector is not None:
|
if vector is not None:
|
||||||
|
xp = get_array_module(self.data)
|
||||||
|
vector = xp.asarray(vector)
|
||||||
self.data[row] = vector
|
self.data[row] = vector
|
||||||
if self._unset.count(row):
|
if self._unset.count(row):
|
||||||
self._unset.erase(self._unset.find(row))
|
self._unset.erase(self._unset.find(row))
|
||||||
|
@ -321,10 +324,11 @@ cdef class Vectors:
|
||||||
RETURNS (tuple): The most similar entries as a `(keys, best_rows, scores)`
|
RETURNS (tuple): The most similar entries as a `(keys, best_rows, scores)`
|
||||||
tuple.
|
tuple.
|
||||||
"""
|
"""
|
||||||
|
xp = get_array_module(self.data)
|
||||||
filled = sorted(list({row for row in self.key2row.values()}))
|
filled = sorted(list({row for row in self.key2row.values()}))
|
||||||
if len(filled) < n:
|
if len(filled) < n:
|
||||||
raise ValueError(Errors.E198.format(n=n, n_rows=len(filled)))
|
raise ValueError(Errors.E198.format(n=n, n_rows=len(filled)))
|
||||||
xp = get_array_module(self.data)
|
filled = xp.asarray(filled)
|
||||||
|
|
||||||
norms = xp.linalg.norm(self.data[filled], axis=1, keepdims=True)
|
norms = xp.linalg.norm(self.data[filled], axis=1, keepdims=True)
|
||||||
norms[norms == 0] = 1
|
norms[norms == 0] = 1
|
||||||
|
@ -357,8 +361,10 @@ cdef class Vectors:
|
||||||
# Account for numerical error we want to return in range -1, 1
|
# Account for numerical error we want to return in range -1, 1
|
||||||
scores = xp.clip(scores, a_min=-1, a_max=1, out=scores)
|
scores = xp.clip(scores, a_min=-1, a_max=1, out=scores)
|
||||||
row2key = {row: key for key, row in self.key2row.items()}
|
row2key = {row: key for key, row in self.key2row.items()}
|
||||||
|
|
||||||
|
numpy_rows = get_current_ops().to_numpy(best_rows)
|
||||||
keys = xp.asarray(
|
keys = xp.asarray(
|
||||||
[[row2key[row] for row in best_rows[i] if row in row2key]
|
[[row2key[row] for row in numpy_rows[i] if row in row2key]
|
||||||
for i in range(len(queries)) ], dtype="uint64")
|
for i in range(len(queries)) ], dtype="uint64")
|
||||||
return (keys, best_rows, scores)
|
return (keys, best_rows, scores)
|
||||||
|
|
||||||
|
@ -459,7 +465,8 @@ cdef class Vectors:
|
||||||
if hasattr(self.data, "from_bytes"):
|
if hasattr(self.data, "from_bytes"):
|
||||||
self.data.from_bytes()
|
self.data.from_bytes()
|
||||||
else:
|
else:
|
||||||
self.data = srsly.msgpack_loads(b)
|
xp = get_array_module(self.data)
|
||||||
|
self.data = xp.asarray(srsly.msgpack_loads(b))
|
||||||
|
|
||||||
deserializers = {
|
deserializers = {
|
||||||
"key2row": lambda b: self.key2row.update(srsly.msgpack_loads(b)),
|
"key2row": lambda b: self.key2row.update(srsly.msgpack_loads(b)),
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
from libc.string cimport memcpy
|
from libc.string cimport memcpy
|
||||||
|
|
||||||
import srsly
|
import srsly
|
||||||
from thinc.api import get_array_module
|
from thinc.api import get_array_module, get_current_ops
|
||||||
import functools
|
import functools
|
||||||
|
|
||||||
from .lexeme cimport EMPTY_LEXEME, OOV_RANK
|
from .lexeme cimport EMPTY_LEXEME, OOV_RANK
|
||||||
|
@ -293,7 +293,7 @@ cdef class Vocab:
|
||||||
among those remaining.
|
among those remaining.
|
||||||
|
|
||||||
For example, suppose the original table had vectors for the words:
|
For example, suppose the original table had vectors for the words:
|
||||||
['sat', 'cat', 'feline', 'reclined']. If we prune the vector table to,
|
['sat', 'cat', 'feline', 'reclined']. If we prune the vector table to
|
||||||
two rows, we would discard the vectors for 'feline' and 'reclined'.
|
two rows, we would discard the vectors for 'feline' and 'reclined'.
|
||||||
These words would then be remapped to the closest remaining vector
|
These words would then be remapped to the closest remaining vector
|
||||||
-- so "feline" would have the same vector as "cat", and "reclined"
|
-- so "feline" would have the same vector as "cat", and "reclined"
|
||||||
|
@ -314,6 +314,7 @@ cdef class Vocab:
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/vocab#prune_vectors
|
DOCS: https://spacy.io/api/vocab#prune_vectors
|
||||||
"""
|
"""
|
||||||
|
ops = get_current_ops()
|
||||||
xp = get_array_module(self.vectors.data)
|
xp = get_array_module(self.vectors.data)
|
||||||
# Make sure all vectors are in the vocab
|
# Make sure all vectors are in the vocab
|
||||||
for orth in self.vectors:
|
for orth in self.vectors:
|
||||||
|
@ -329,8 +330,9 @@ cdef class Vocab:
|
||||||
toss = xp.ascontiguousarray(self.vectors.data[indices[nr_row:]])
|
toss = xp.ascontiguousarray(self.vectors.data[indices[nr_row:]])
|
||||||
self.vectors = Vectors(data=keep, keys=keys[:nr_row], name=self.vectors.name)
|
self.vectors = Vectors(data=keep, keys=keys[:nr_row], name=self.vectors.name)
|
||||||
syn_keys, syn_rows, scores = self.vectors.most_similar(toss, batch_size=batch_size)
|
syn_keys, syn_rows, scores = self.vectors.most_similar(toss, batch_size=batch_size)
|
||||||
|
syn_keys = ops.to_numpy(syn_keys)
|
||||||
remap = {}
|
remap = {}
|
||||||
for i, key in enumerate(keys[nr_row:]):
|
for i, key in enumerate(ops.to_numpy(keys[nr_row:])):
|
||||||
self.vectors.add(key, row=syn_rows[i][0])
|
self.vectors.add(key, row=syn_rows[i][0])
|
||||||
word = self.strings[key]
|
word = self.strings[key]
|
||||||
synonym = self.strings[syn_keys[i][0]]
|
synonym = self.strings[syn_keys[i][0]]
|
||||||
|
@ -351,7 +353,7 @@ cdef class Vocab:
|
||||||
Defaults to the length of `orth`.
|
Defaults to the length of `orth`.
|
||||||
maxn (int): Maximum n-gram length used for Fasttext's ngram computation.
|
maxn (int): Maximum n-gram length used for Fasttext's ngram computation.
|
||||||
Defaults to the length of `orth`.
|
Defaults to the length of `orth`.
|
||||||
RETURNS (numpy.ndarray): A word vector. Size
|
RETURNS (numpy.ndarray or cupy.ndarray): A word vector. Size
|
||||||
and shape determined by the `vocab.vectors` instance. Usually, a
|
and shape determined by the `vocab.vectors` instance. Usually, a
|
||||||
numpy ndarray of shape (300,) and dtype float32.
|
numpy ndarray of shape (300,) and dtype float32.
|
||||||
|
|
||||||
|
@ -400,7 +402,7 @@ cdef class Vocab:
|
||||||
by string or int ID.
|
by string or int ID.
|
||||||
|
|
||||||
orth (int / unicode): The word.
|
orth (int / unicode): The word.
|
||||||
vector (numpy.ndarray[ndim=1, dtype='float32']): The vector to set.
|
vector (numpy.ndarray or cupy.nadarry[ndim=1, dtype='float32']): The vector to set.
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/vocab#set_vector
|
DOCS: https://spacy.io/api/vocab#set_vector
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue
Block a user