diff --git a/Tests/test_imagepalette.py b/Tests/test_imagepalette.py index ee2cfdf9b..ecfbda1d8 100644 --- a/Tests/test_imagepalette.py +++ b/Tests/test_imagepalette.py @@ -49,6 +49,29 @@ def test_getcolor(): palette.getcolor("unknown") +@pytest.mark.parametrize( + "index, palette", + [ + # Test when the palette is not full + (0, ImagePalette.ImagePalette()), + # Test when the palette is full + (255, ImagePalette.ImagePalette("RGB", list(range(256)) * 3)), + ], +) +def test_getcolor_not_special(index, palette): + im = Image.new("P", (1, 1)) + + # Do not use transparency index as a new color + im.info["transparency"] = index + index1 = palette.getcolor((0, 0, 0), im) + assert index1 != index + + # Do not use background index as a new color + im.info["background"] = index1 + index2 = palette.getcolor((0, 0, 1), im) + assert index2 not in (index, index1) + + def test_file(tmp_path): palette = ImagePalette.ImagePalette("RGB", list(range(256)) * 3) diff --git a/src/PIL/ImagePalette.py b/src/PIL/ImagePalette.py index 76c4c46d5..b0c722b29 100644 --- a/src/PIL/ImagePalette.py +++ b/src/PIL/ImagePalette.py @@ -118,11 +118,19 @@ class ImagePalette: if not isinstance(self.palette, bytearray): self._palette = bytearray(self.palette) index = len(self.palette) // 3 + special_colors = () + if image: + special_colors = ( + image.info.get("background"), + image.info.get("transparency"), + ) + while index in special_colors: + index += 1 if index >= 256: if image: # Search for an unused index for i, count in reversed(list(enumerate(image.histogram()))): - if count == 0: + if count == 0 and i not in special_colors: index = i break if index >= 256: