diff --git a/Tests/test_file_apng.py b/Tests/test_file_apng.py index 8cb9a814e..579288808 100644 --- a/Tests/test_file_apng.py +++ b/Tests/test_file_apng.py @@ -673,10 +673,16 @@ def test_seek_after_close(): @pytest.mark.parametrize("mode", ("RGBA", "RGB", "P")) -def test_different_modes_in_later_frames(mode, tmp_path): +@pytest.mark.parametrize("default_image", (True, False)) +def test_different_modes_in_later_frames(mode, default_image, tmp_path): test_file = str(tmp_path / "temp.png") im = Image.new("L", (1, 1)) - im.save(test_file, save_all=True, append_images=[Image.new(mode, (1, 1))]) + im.save( + test_file, + save_all=True, + default_image=default_image, + append_images=[Image.new(mode, (1, 1))], + ) with Image.open(test_file) as reloaded: assert reloaded.mode == mode diff --git a/src/PIL/PngImagePlugin.py b/src/PIL/PngImagePlugin.py index 2c7ae68d5..5e5a8cf6a 100644 --- a/src/PIL/PngImagePlugin.py +++ b/src/PIL/PngImagePlugin.py @@ -1105,10 +1105,7 @@ def _write_multiple_frames(im, fp, chunk, rawmode, default_image, append_images) if im_frame.mode == rawmode: im_frame = im_frame.copy() else: - if rawmode == "P": - im_frame = im_frame.convert(rawmode, palette=im.palette) - else: - im_frame = im_frame.convert(rawmode) + im_frame = im_frame.convert(rawmode) encoderinfo = im.encoderinfo.copy() if isinstance(duration, (list, tuple)): encoderinfo["duration"] = duration[frame_count] @@ -1167,6 +1164,8 @@ def _write_multiple_frames(im, fp, chunk, rawmode, default_image, append_images) # default image IDAT (if it exists) if default_image: + if im.mode != rawmode: + im = im.convert(rawmode) ImageFile._save(im, _idat(fp, chunk), [("zip", (0, 0) + im.size, 0, rawmode)]) seq_num = 0 @@ -1228,11 +1227,7 @@ def _save(im, fp, filename, chunk=putchunk, save_all=False): ) modes = set() append_images = im.encoderinfo.get("append_images", []) - if default_image: - chain = itertools.chain(append_images) - else: - chain = itertools.chain([im], append_images) - for im_seq in chain: + for im_seq in itertools.chain([im], append_images): for im_frame in ImageSequence.Iterator(im_seq): modes.add(im_frame.mode) for mode in ("RGBA", "RGB", "P"):