diff --git a/Tests/test_image_array.py b/Tests/test_image_array.py index 4dbbdd218..c4e124ee5 100644 --- a/Tests/test_image_array.py +++ b/Tests/test_image_array.py @@ -14,6 +14,10 @@ def test_toarray(): ai = numpy.array(im.convert(mode)) return ai.shape, ai.dtype.str, ai.nbytes + def test_with_dtype(dtype): + ai = numpy.array(im, dtype=dtype) + assert ai.dtype == dtype + # assert test("1") == ((100, 128), '|b1', 1600)) assert test("L") == ((100, 128), "|u1", 12800) @@ -27,6 +31,9 @@ def test_toarray(): assert test("RGBA") == ((100, 128, 4), "|u1", 51200) assert test("RGBX") == ((100, 128, 4), "|u1", 51200) + test_with_dtype(numpy.float) + test_with_dtype(numpy.uint8) + with Image.open("Tests/images/truncated_jpeg.jpg") as im_truncated: with pytest.raises(OSError): numpy.array(im_truncated) diff --git a/src/PIL/Image.py b/src/PIL/Image.py index 9debddeec..98dfec726 100644 --- a/src/PIL/Image.py +++ b/src/PIL/Image.py @@ -681,7 +681,7 @@ class Image: raise ValueError("Could not save to PNG for display") from e return b.getvalue() - def __array__(self): + def __array__(self, dtype=None): # numpy array interface support import numpy as np @@ -700,7 +700,10 @@ class Image: class ArrayData: __array_interface__ = new - return np.array(ArrayData()) + arr = np.array(ArrayData()) + if dtype is not None: + arr = arr.astype(dtype) + return arr def __getstate__(self): return [self.info, self.mode, self.size, self.getpalette(), self.tobytes()]