diff --git a/spacy/errors.py b/spacy/errors.py index c751ad65a..b124fc88c 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -551,6 +551,7 @@ class Errors(object): "array.") E191 = ("Invalid head: the head token must be from the same doc as the " "token itself.") + E192 = ("Unable to resize vectors in place with cupy.") @add_codes diff --git a/spacy/tests/vocab_vectors/test_vectors.py b/spacy/tests/vocab_vectors/test_vectors.py index b688ab9dd..8987b7c89 100644 --- a/spacy/tests/vocab_vectors/test_vectors.py +++ b/spacy/tests/vocab_vectors/test_vectors.py @@ -89,17 +89,28 @@ def test_init_vectors_with_resize_data(data, resize_data): assert v.shape != data.shape -def test_get_vector_resize(strings, data, resize_data): - v = Vectors(data=data) - v.resize(shape=resize_data.shape) +def test_get_vector_resize(strings, data): strings = [hash_string(s) for s in strings] + + # decrease vector dimension (truncate) + v = Vectors(data=data) + resized_dim = v.shape[1] - 1 + v.resize(shape=(v.shape[0], resized_dim)) for i, string in enumerate(strings): v.add(string, row=i) - assert list(v[strings[0]]) == list(resize_data[0]) - assert list(v[strings[0]]) != list(resize_data[1]) - assert list(v[strings[1]]) != list(resize_data[0]) - assert list(v[strings[1]]) == list(resize_data[1]) + assert list(v[strings[0]]) == list(data[0, :resized_dim]) + assert list(v[strings[1]]) == list(data[1, :resized_dim]) + + # increase vector dimension (pad with zeros) + v = Vectors(data=data) + resized_dim = v.shape[1] + 1 + v.resize(shape=(v.shape[0], resized_dim)) + for i, string in enumerate(strings): + v.add(string, row=i) + + assert list(v[strings[0]]) == list(data[0]) + [0] + assert list(v[strings[1]]) == list(data[1]) + [0] def test_init_vectors_with_data(strings, data): diff --git a/spacy/vectors.pyx b/spacy/vectors.pyx index f8643640a..5b8512970 100644 --- a/spacy/vectors.pyx +++ b/spacy/vectors.pyx @@ -198,11 +198,17 @@ cdef class Vectors: DOCS: https://spacy.io/api/vectors#resize """ + xp = get_array_module(self.data) if inplace: - self.data.resize(shape, refcheck=False) + if xp == numpy: + self.data.resize(shape, refcheck=False) + else: + raise ValueError(Errors.E192) else: - xp = get_array_module(self.data) - self.data = xp.resize(self.data, shape) + resized_array = xp.zeros(shape, dtype=self.data.dtype) + copy_shape = (min(shape[0], self.data.shape[0]), min(shape[1], self.data.shape[1])) + resized_array[:copy_shape[0], :copy_shape[1]] = self.data[:copy_shape[0], :copy_shape[1]] + self.data = resized_array filled = {row for row in self.key2row.values()} self._unset = cppset[int]({row for row in range(shape[0]) if row not in filled}) removed_items = []