Modify Vector.resize to work with cupy and improve resizing (#5216)

* Modify Vector.resize to work with cupy

Modify `Vectors.resize` to work with cupy. Modify behavior when resizing
to a different vector dimension so that individual vectors are truncated
or extended with zeros instead of having the original values filled into
the new shape without regard for the original axes.

* Update spacy/tests/vocab_vectors/test_vectors.py

Co-Authored-By: Matthew Honnibal <honnibal+gh@gmail.com>

Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com>
This commit is contained in:
adrianeboyd 2020-03-29 13:51:20 +02:00 committed by GitHub
parent e53232533b
commit 963bd890c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 28 additions and 10 deletions

View File

@ -551,6 +551,7 @@ class Errors(object):
"array.")
E191 = ("Invalid head: the head token must be from the same doc as the "
"token itself.")
E192 = ("Unable to resize vectors in place with cupy.")
@add_codes

View File

@ -89,17 +89,28 @@ def test_init_vectors_with_resize_data(data, resize_data):
assert v.shape != data.shape
def test_get_vector_resize(strings, data, resize_data):
v = Vectors(data=data)
v.resize(shape=resize_data.shape)
def test_get_vector_resize(strings, data):
strings = [hash_string(s) for s in strings]
# decrease vector dimension (truncate)
v = Vectors(data=data)
resized_dim = v.shape[1] - 1
v.resize(shape=(v.shape[0], resized_dim))
for i, string in enumerate(strings):
v.add(string, row=i)
assert list(v[strings[0]]) == list(resize_data[0])
assert list(v[strings[0]]) != list(resize_data[1])
assert list(v[strings[1]]) != list(resize_data[0])
assert list(v[strings[1]]) == list(resize_data[1])
assert list(v[strings[0]]) == list(data[0, :resized_dim])
assert list(v[strings[1]]) == list(data[1, :resized_dim])
# increase vector dimension (pad with zeros)
v = Vectors(data=data)
resized_dim = v.shape[1] + 1
v.resize(shape=(v.shape[0], resized_dim))
for i, string in enumerate(strings):
v.add(string, row=i)
assert list(v[strings[0]]) == list(data[0]) + [0]
assert list(v[strings[1]]) == list(data[1]) + [0]
def test_init_vectors_with_data(strings, data):

View File

@ -198,11 +198,17 @@ cdef class Vectors:
DOCS: https://spacy.io/api/vectors#resize
"""
xp = get_array_module(self.data)
if inplace:
if xp == numpy:
self.data.resize(shape, refcheck=False)
else:
xp = get_array_module(self.data)
self.data = xp.resize(self.data, shape)
raise ValueError(Errors.E192)
else:
resized_array = xp.zeros(shape, dtype=self.data.dtype)
copy_shape = (min(shape[0], self.data.shape[0]), min(shape[1], self.data.shape[1]))
resized_array[:copy_shape[0], :copy_shape[1]] = self.data[:copy_shape[0], :copy_shape[1]]
self.data = resized_array
filled = {row for row in self.key2row.values()}
self._unset = cppset[int]({row for row in range(shape[0]) if row not in filled})
removed_items = []