diff --git a/Tests/test_file_png.py b/Tests/test_file_png.py index c8d441485..bbf5f5772 100644 --- a/Tests/test_file_png.py +++ b/Tests/test_file_png.py @@ -634,6 +634,16 @@ class TestFilePng: with Image.open(out) as reloaded: assert len(reloaded.png.im_palette[1]) == 48 + def test_plte_length(self, tmp_path): + im = Image.new("P", (1, 1)) + im.putpalette((1, 1, 1)) + + out = str(tmp_path / "temp.png") + im.save(str(tmp_path / "temp.png")) + + with Image.open(out) as reloaded: + assert len(reloaded.png.im_palette[1]) == 3 + def test_exif(self): # With an EXIF chunk with Image.open("Tests/images/exif.png") as im: diff --git a/src/PIL/PngImagePlugin.py b/src/PIL/PngImagePlugin.py index 30eb13aa3..07bbc5228 100644 --- a/src/PIL/PngImagePlugin.py +++ b/src/PIL/PngImagePlugin.py @@ -1186,23 +1186,21 @@ def _save(im, fp, filename, chunk=putchunk, save_all=False): # attempt to minimize storage requirements for palette images if "bits" in im.encoderinfo: # number of bits specified by user - colors = 1 << im.encoderinfo["bits"] + colors = min(1 << im.encoderinfo["bits"], 256) else: # check palette contents if im.palette: - colors = max(min(len(im.palette.getdata()[1]) // 3, 256), 2) + colors = max(min(len(im.palette.getdata()[1]) // 3, 256), 1) else: colors = 256 - if colors <= 2: - bits = 1 - elif colors <= 4: - bits = 2 - elif colors <= 16: - bits = 4 - else: - bits = 8 - if bits != 8: + if colors <= 16: + if colors <= 2: + bits = 1 + elif colors <= 4: + bits = 2 + else: + bits = 4 mode = f"{mode};{bits}" # encoder options @@ -1270,7 +1268,7 @@ def _save(im, fp, filename, chunk=putchunk, save_all=False): chunk(fp, cid, data) if im.mode == "P": - palette_byte_number = (2 ** bits) * 3 + palette_byte_number = colors * 3 palette_bytes = im.im.getpalette("RGB")[:palette_byte_number] while len(palette_bytes) < palette_byte_number: palette_bytes += b"\0" @@ -1281,7 +1279,7 @@ def _save(im, fp, filename, chunk=putchunk, save_all=False): if transparency or transparency == 0: if im.mode == "P": # limit to actual palette size - alpha_bytes = 2 ** bits + alpha_bytes = colors if isinstance(transparency, bytes): chunk(fp, b"tRNS", transparency[:alpha_bytes]) else: @@ -1302,7 +1300,7 @@ def _save(im, fp, filename, chunk=putchunk, save_all=False): else: if im.mode == "P" and im.im.getpalettemode() == "RGBA": alpha = im.im.getpalette("RGBA", "A") - alpha_bytes = 2 ** bits + alpha_bytes = colors chunk(fp, b"tRNS", alpha[:alpha_bytes]) dpi = im.encoderinfo.get("dpi")