Allow fewer palette entries than the bit depth maximum

This commit is contained in:
Andrew Murray 2021-03-21 09:22:01 +11:00
parent 7235cf3135
commit 754752e78f
2 changed files with 22 additions and 14 deletions

View File

@ -634,6 +634,16 @@ class TestFilePng:
with Image.open(out) as reloaded: with Image.open(out) as reloaded:
assert len(reloaded.png.im_palette[1]) == 48 assert len(reloaded.png.im_palette[1]) == 48
def test_plte_length(self, tmp_path):
im = Image.new("P", (1, 1))
im.putpalette((1, 1, 1))
out = str(tmp_path / "temp.png")
im.save(str(tmp_path / "temp.png"))
with Image.open(out) as reloaded:
assert len(reloaded.png.im_palette[1]) == 3
def test_exif(self): def test_exif(self):
# With an EXIF chunk # With an EXIF chunk
with Image.open("Tests/images/exif.png") as im: with Image.open("Tests/images/exif.png") as im:

View File

@ -1186,23 +1186,21 @@ def _save(im, fp, filename, chunk=putchunk, save_all=False):
# attempt to minimize storage requirements for palette images # attempt to minimize storage requirements for palette images
if "bits" in im.encoderinfo: if "bits" in im.encoderinfo:
# number of bits specified by user # number of bits specified by user
colors = 1 << im.encoderinfo["bits"] colors = min(1 << im.encoderinfo["bits"], 256)
else: else:
# check palette contents # check palette contents
if im.palette: if im.palette:
colors = max(min(len(im.palette.getdata()[1]) // 3, 256), 2) colors = max(min(len(im.palette.getdata()[1]) // 3, 256), 1)
else: else:
colors = 256 colors = 256
if colors <= 2: if colors <= 16:
bits = 1 if colors <= 2:
elif colors <= 4: bits = 1
bits = 2 elif colors <= 4:
elif colors <= 16: bits = 2
bits = 4 else:
else: bits = 4
bits = 8
if bits != 8:
mode = f"{mode};{bits}" mode = f"{mode};{bits}"
# encoder options # encoder options
@ -1270,7 +1268,7 @@ def _save(im, fp, filename, chunk=putchunk, save_all=False):
chunk(fp, cid, data) chunk(fp, cid, data)
if im.mode == "P": if im.mode == "P":
palette_byte_number = (2 ** bits) * 3 palette_byte_number = colors * 3
palette_bytes = im.im.getpalette("RGB")[:palette_byte_number] palette_bytes = im.im.getpalette("RGB")[:palette_byte_number]
while len(palette_bytes) < palette_byte_number: while len(palette_bytes) < palette_byte_number:
palette_bytes += b"\0" palette_bytes += b"\0"
@ -1281,7 +1279,7 @@ def _save(im, fp, filename, chunk=putchunk, save_all=False):
if transparency or transparency == 0: if transparency or transparency == 0:
if im.mode == "P": if im.mode == "P":
# limit to actual palette size # limit to actual palette size
alpha_bytes = 2 ** bits alpha_bytes = colors
if isinstance(transparency, bytes): if isinstance(transparency, bytes):
chunk(fp, b"tRNS", transparency[:alpha_bytes]) chunk(fp, b"tRNS", transparency[:alpha_bytes])
else: else:
@ -1302,7 +1300,7 @@ def _save(im, fp, filename, chunk=putchunk, save_all=False):
else: else:
if im.mode == "P" and im.im.getpalettemode() == "RGBA": if im.mode == "P" and im.im.getpalettemode() == "RGBA":
alpha = im.im.getpalette("RGBA", "A") alpha = im.im.getpalette("RGBA", "A")
alpha_bytes = 2 ** bits alpha_bytes = colors
chunk(fp, b"tRNS", alpha[:alpha_bytes]) chunk(fp, b"tRNS", alpha[:alpha_bytes])
dpi = im.encoderinfo.get("dpi") dpi = im.encoderinfo.get("dpi")