diff --git a/Tests/test_imagepalette.py b/Tests/test_imagepalette.py index 782022f51..6ad21502f 100644 --- a/Tests/test_imagepalette.py +++ b/Tests/test_imagepalette.py @@ -49,6 +49,12 @@ def test_getcolor() -> None: palette.getcolor("unknown") # type: ignore[arg-type] +def test_getcolor_rgba() -> None: + palette = ImagePalette.ImagePalette("RGBA", (1, 2, 3, 4)) + palette.getcolor((5, 6, 7, 8)) + assert palette.palette == b"\x01\x02\x03\x04\x05\x06\x07\x08" + + def test_getcolor_rgba_color_rgb_palette() -> None: palette = ImagePalette.ImagePalette("RGB") diff --git a/src/PIL/ImagePalette.py b/src/PIL/ImagePalette.py index 103697117..eae7aea8f 100644 --- a/src/PIL/ImagePalette.py +++ b/src/PIL/ImagePalette.py @@ -118,7 +118,7 @@ class ImagePalette: ) -> int: if not isinstance(self.palette, bytearray): self._palette = bytearray(self.palette) - index = len(self.palette) // 3 + index = len(self.palette) // len(self.mode) special_colors: tuple[int | tuple[int, ...] | None, ...] = () if image: special_colors = ( @@ -168,11 +168,12 @@ class ImagePalette: index = self._new_color_index(image, e) assert isinstance(self._palette, bytearray) self.colors[color] = index - if index * 3 < len(self.palette): + mode_len = len(self.mode) + if index * mode_len < len(self.palette): self._palette = ( - self._palette[: index * 3] + self._palette[: index * mode_len] + bytes(color) - + self._palette[index * 3 + 3 :] + + self._palette[index * mode_len + mode_len :] ) else: self._palette += bytes(color)