For save_all with palette, do not include palette with each frame

This commit is contained in:
Andrew Murray 2021-07-11 22:52:32 +10:00
parent be30792714
commit 43ea81c6db
2 changed files with 45 additions and 4 deletions

View File

@ -821,6 +821,29 @@ def test_palette_save_P(tmp_path):
assert_image_equal(reloaded, im) assert_image_equal(reloaded, im)
def test_palette_save_all_P(tmp_path):
frames = []
colors = ((255, 0, 0), (0, 255, 0))
for color in colors:
frame = Image.new("P", (100, 100))
frame.putpalette(color)
frames.append(frame)
out = str(tmp_path / "temp.gif")
frames[0].save(
out, save_all=True, palette=[255, 0, 0, 0, 255, 0], append_images=frames[1:]
)
with Image.open(out) as im:
# Assert that the frames are correct, and each frame has the same palette
assert_image_equal(im.convert("RGB"), frames[0].convert("RGB"))
assert im.palette.palette == im.global_palette.palette
im.seek(1)
assert_image_equal(im.convert("RGB"), frames[1].convert("RGB"))
assert im.palette.palette == im.global_palette.palette
def test_palette_save_ImagePalette(tmp_path): def test_palette_save_ImagePalette(tmp_path):
# Pass in a different palette, as an ImagePalette.ImagePalette # Pass in a different palette, as an ImagePalette.ImagePalette
# effectively the same as test_palette_save_P # effectively the same as test_palette_save_P

View File

@ -414,9 +414,26 @@ def _normalize_palette(im, palette, info):
source_palette = bytearray(i // 3 for i in range(768)) source_palette = bytearray(i // 3 for i in range(768))
im.palette = ImagePalette.ImagePalette("RGB", palette=source_palette) im.palette = ImagePalette.ImagePalette("RGB", palette=source_palette)
used_palette_colors = _get_optimize(im, info) if palette:
if used_palette_colors is not None: used_palette_colors = []
return im.remap_palette(used_palette_colors, source_palette) for i in range(0, len(source_palette), 3):
source_color = tuple(source_palette[i : i + 3])
try:
index = im.palette.colors[source_color]
except KeyError:
index = None
used_palette_colors.append(index)
for i, index in enumerate(used_palette_colors):
if index is None:
for j in range(len(used_palette_colors)):
if j not in used_palette_colors:
used_palette_colors[i] = j
break
im = im.remap_palette(used_palette_colors)
else:
used_palette_colors = _get_optimize(im, info)
if used_palette_colors is not None:
return im.remap_palette(used_palette_colors, source_palette)
im.palette.palette = source_palette im.palette.palette = source_palette
return im return im
@ -507,7 +524,8 @@ def _write_multiple_frames(im, fp, palette):
offset = (0, 0) offset = (0, 0)
else: else:
# compress difference # compress difference
frame_data["encoderinfo"]["include_color_table"] = True if not palette:
frame_data["encoderinfo"]["include_color_table"] = True
im_frame = im_frame.crop(frame_data["bbox"]) im_frame = im_frame.crop(frame_data["bbox"])
offset = frame_data["bbox"][:2] offset = frame_data["bbox"][:2]