mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Modify Vector.resize to work with cupy and improve resizing (#5216)
* Modify Vector.resize to work with cupy Modify `Vectors.resize` to work with cupy. Modify behavior when resizing to a different vector dimension so that individual vectors are truncated or extended with zeros instead of having the original values filled into the new shape without regard for the original axes. * Update spacy/tests/vocab_vectors/test_vectors.py Co-Authored-By: Matthew Honnibal <honnibal+gh@gmail.com> Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com>
This commit is contained in:
parent
e53232533b
commit
963bd890c1
|
@ -551,6 +551,7 @@ class Errors(object):
|
|||
"array.")
|
||||
E191 = ("Invalid head: the head token must be from the same doc as the "
|
||||
"token itself.")
|
||||
E192 = ("Unable to resize vectors in place with cupy.")
|
||||
|
||||
|
||||
@add_codes
|
||||
|
|
|
@ -89,17 +89,28 @@ def test_init_vectors_with_resize_data(data, resize_data):
|
|||
assert v.shape != data.shape
|
||||
|
||||
|
||||
def test_get_vector_resize(strings, data, resize_data):
|
||||
v = Vectors(data=data)
|
||||
v.resize(shape=resize_data.shape)
|
||||
def test_get_vector_resize(strings, data):
|
||||
strings = [hash_string(s) for s in strings]
|
||||
|
||||
# decrease vector dimension (truncate)
|
||||
v = Vectors(data=data)
|
||||
resized_dim = v.shape[1] - 1
|
||||
v.resize(shape=(v.shape[0], resized_dim))
|
||||
for i, string in enumerate(strings):
|
||||
v.add(string, row=i)
|
||||
|
||||
assert list(v[strings[0]]) == list(resize_data[0])
|
||||
assert list(v[strings[0]]) != list(resize_data[1])
|
||||
assert list(v[strings[1]]) != list(resize_data[0])
|
||||
assert list(v[strings[1]]) == list(resize_data[1])
|
||||
assert list(v[strings[0]]) == list(data[0, :resized_dim])
|
||||
assert list(v[strings[1]]) == list(data[1, :resized_dim])
|
||||
|
||||
# increase vector dimension (pad with zeros)
|
||||
v = Vectors(data=data)
|
||||
resized_dim = v.shape[1] + 1
|
||||
v.resize(shape=(v.shape[0], resized_dim))
|
||||
for i, string in enumerate(strings):
|
||||
v.add(string, row=i)
|
||||
|
||||
assert list(v[strings[0]]) == list(data[0]) + [0]
|
||||
assert list(v[strings[1]]) == list(data[1]) + [0]
|
||||
|
||||
|
||||
def test_init_vectors_with_data(strings, data):
|
||||
|
|
|
@ -198,11 +198,17 @@ cdef class Vectors:
|
|||
|
||||
DOCS: https://spacy.io/api/vectors#resize
|
||||
"""
|
||||
xp = get_array_module(self.data)
|
||||
if inplace:
|
||||
if xp == numpy:
|
||||
self.data.resize(shape, refcheck=False)
|
||||
else:
|
||||
xp = get_array_module(self.data)
|
||||
self.data = xp.resize(self.data, shape)
|
||||
raise ValueError(Errors.E192)
|
||||
else:
|
||||
resized_array = xp.zeros(shape, dtype=self.data.dtype)
|
||||
copy_shape = (min(shape[0], self.data.shape[0]), min(shape[1], self.data.shape[1]))
|
||||
resized_array[:copy_shape[0], :copy_shape[1]] = self.data[:copy_shape[0], :copy_shape[1]]
|
||||
self.data = resized_array
|
||||
filled = {row for row in self.key2row.values()}
|
||||
self._unset = cppset[int]({row for row in range(shape[0]) if row not in filled})
|
||||
removed_items = []
|
||||
|
|
Loading…
Reference in New Issue
Block a user