mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 01:16:28 +03:00
Modify Vector.resize to work with cupy and improve resizing (#5216)
* Modify Vector.resize to work with cupy Modify `Vectors.resize` to work with cupy. Modify behavior when resizing to a different vector dimension so that individual vectors are truncated or extended with zeros instead of having the original values filled into the new shape without regard for the original axes. * Update spacy/tests/vocab_vectors/test_vectors.py Co-Authored-By: Matthew Honnibal <honnibal+gh@gmail.com> Co-authored-by: Matthew Honnibal <honnibal+gh@gmail.com>
This commit is contained in:
parent
e53232533b
commit
963bd890c1
|
@ -551,6 +551,7 @@ class Errors(object):
|
||||||
"array.")
|
"array.")
|
||||||
E191 = ("Invalid head: the head token must be from the same doc as the "
|
E191 = ("Invalid head: the head token must be from the same doc as the "
|
||||||
"token itself.")
|
"token itself.")
|
||||||
|
E192 = ("Unable to resize vectors in place with cupy.")
|
||||||
|
|
||||||
|
|
||||||
@add_codes
|
@add_codes
|
||||||
|
|
|
@ -89,17 +89,28 @@ def test_init_vectors_with_resize_data(data, resize_data):
|
||||||
assert v.shape != data.shape
|
assert v.shape != data.shape
|
||||||
|
|
||||||
|
|
||||||
def test_get_vector_resize(strings, data, resize_data):
|
def test_get_vector_resize(strings, data):
|
||||||
v = Vectors(data=data)
|
|
||||||
v.resize(shape=resize_data.shape)
|
|
||||||
strings = [hash_string(s) for s in strings]
|
strings = [hash_string(s) for s in strings]
|
||||||
|
|
||||||
|
# decrease vector dimension (truncate)
|
||||||
|
v = Vectors(data=data)
|
||||||
|
resized_dim = v.shape[1] - 1
|
||||||
|
v.resize(shape=(v.shape[0], resized_dim))
|
||||||
for i, string in enumerate(strings):
|
for i, string in enumerate(strings):
|
||||||
v.add(string, row=i)
|
v.add(string, row=i)
|
||||||
|
|
||||||
assert list(v[strings[0]]) == list(resize_data[0])
|
assert list(v[strings[0]]) == list(data[0, :resized_dim])
|
||||||
assert list(v[strings[0]]) != list(resize_data[1])
|
assert list(v[strings[1]]) == list(data[1, :resized_dim])
|
||||||
assert list(v[strings[1]]) != list(resize_data[0])
|
|
||||||
assert list(v[strings[1]]) == list(resize_data[1])
|
# increase vector dimension (pad with zeros)
|
||||||
|
v = Vectors(data=data)
|
||||||
|
resized_dim = v.shape[1] + 1
|
||||||
|
v.resize(shape=(v.shape[0], resized_dim))
|
||||||
|
for i, string in enumerate(strings):
|
||||||
|
v.add(string, row=i)
|
||||||
|
|
||||||
|
assert list(v[strings[0]]) == list(data[0]) + [0]
|
||||||
|
assert list(v[strings[1]]) == list(data[1]) + [0]
|
||||||
|
|
||||||
|
|
||||||
def test_init_vectors_with_data(strings, data):
|
def test_init_vectors_with_data(strings, data):
|
||||||
|
|
|
@ -198,11 +198,17 @@ cdef class Vectors:
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/vectors#resize
|
DOCS: https://spacy.io/api/vectors#resize
|
||||||
"""
|
"""
|
||||||
|
xp = get_array_module(self.data)
|
||||||
if inplace:
|
if inplace:
|
||||||
|
if xp == numpy:
|
||||||
self.data.resize(shape, refcheck=False)
|
self.data.resize(shape, refcheck=False)
|
||||||
else:
|
else:
|
||||||
xp = get_array_module(self.data)
|
raise ValueError(Errors.E192)
|
||||||
self.data = xp.resize(self.data, shape)
|
else:
|
||||||
|
resized_array = xp.zeros(shape, dtype=self.data.dtype)
|
||||||
|
copy_shape = (min(shape[0], self.data.shape[0]), min(shape[1], self.data.shape[1]))
|
||||||
|
resized_array[:copy_shape[0], :copy_shape[1]] = self.data[:copy_shape[0], :copy_shape[1]]
|
||||||
|
self.data = resized_array
|
||||||
filled = {row for row in self.key2row.values()}
|
filled = {row for row in self.key2row.values()}
|
||||||
self._unset = cppset[int]({row for row in range(shape[0]) if row not in filled})
|
self._unset = cppset[int]({row for row in range(shape[0]) if row not in filled})
|
||||||
removed_items = []
|
removed_items = []
|
||||||
|
|
Loading…
Reference in New Issue
Block a user