Simplified test using assert_image_equal

This commit is contained in:
Andrew Murray 2024-03-21 19:11:19 +11:00
parent fd80b2e1d9
commit c3997050b0

View File

@ -1047,44 +1047,31 @@ class TestImage:
class TestImageBytes: class TestImageBytes:
@pytest.mark.parametrize("mode", image_mode_names) @pytest.mark.parametrize("mode", image_mode_names)
def test_roundtrip_bytes_constructor(self, mode: str): def test_roundtrip_bytes_constructor(self, mode: str):
source_image = hopper(mode) im = hopper(mode)
source_bytes = source_image.tobytes() source_bytes = im.tobytes()
copy_image = Image.frombytes(mode, source_image.size, source_bytes)
assert copy_image.tobytes() == source_bytes reloaded = Image.frombytes(mode, im.size, source_bytes)
assert reloaded.tobytes() == source_bytes
@pytest.mark.parametrize("mode", image_mode_names) @pytest.mark.parametrize("mode", image_mode_names)
def test_roundtrip_bytes_method(self, mode: str): def test_roundtrip_bytes_method(self, mode: str):
source_image = hopper(mode) im = hopper(mode)
source_bytes = source_image.tobytes() source_bytes = im.tobytes()
copy_image = Image.new(mode, source_image.size)
copy_image.frombytes(source_bytes) reloaded = Image.new(mode, im.size)
assert copy_image.tobytes() == source_bytes reloaded.frombytes(source_bytes)
assert reloaded.tobytes() == source_bytes
@pytest.mark.parametrize(("mode", "num_bands", "pixelsize"), image_modes) @pytest.mark.parametrize(("mode", "num_bands", "pixelsize"), image_modes)
def test_getdata_putdata( def test_getdata_putdata(
self, mode: str, num_bands: int, pixelsize: int self, mode: str, num_bands: int, pixelsize: int
): ):
image_byte_size = 2 * 2 * pixelsize start_bytes = bytes(range(2 * 2 * pixelsize))
start_bytes = bytes(range(image_byte_size)) im = Image.frombytes(mode, (2, 2), start_bytes)
image = Image.frombytes(mode, (2, 2), start_bytes)
start_pixels = ( reloaded = Image.new(mode, im.size)
image.getpixel((0, 0)), reloaded.putdata(im.getdata())
image.getpixel((0, 1)), assert_image_equal(im, reloaded)
image.getpixel((1, 0)),
image.getpixel((1, 1)),
)
image.putdata(image.getdata())
end_pixels = (
image.getpixel((0, 0)),
image.getpixel((0, 1)),
image.getpixel((1, 0)),
image.getpixel((1, 1)),
)
assert start_pixels == end_pixels
class MockEncoder(ImageFile.PyEncoder): class MockEncoder(ImageFile.PyEncoder):