diff --git a/Tests/test_file_gimppalette.py b/Tests/test_file_gimppalette.py index c122b37b3..c3d2bfeb7 100644 --- a/Tests/test_file_gimppalette.py +++ b/Tests/test_file_gimppalette.py @@ -16,11 +16,11 @@ def test_sanity() -> None: GimpPaletteFile(fp) with open("Tests/images/bad_palette_file.gpl", "rb") as fp: - with pytest.raises(SyntaxError): + with pytest.raises(SyntaxError, match="bad palette file"): GimpPaletteFile(fp) with open("Tests/images/bad_palette_entry.gpl", "rb") as fp: - with pytest.raises(ValueError): + with pytest.raises(ValueError, match="bad palette entry"): GimpPaletteFile(fp) @@ -40,12 +40,26 @@ def test_get_palette(filename: str, size: int) -> None: assert len(palette) / 3 == size -def test_palette_limit() -> None: +def test_frombytes() -> None: with open("Tests/images/full_gimp_palette.gpl", "rb") as fp: - data = fp.read() + full_data = fp.read() # Test that __init__ only reads 256 entries - data = data.replace(b"#\n", b"") + b" 0 0 0 Index 256" + data = full_data.replace(b"#\n", b"") + b" 0 0 0 Index 256" b = BytesIO(data) palette = GimpPaletteFile(b) assert len(palette.palette) / 3 == 256 + + # Test that frombytes() can read beyond that + palette = GimpPaletteFile.frombytes(data) + assert len(palette.palette) / 3 == 257 + + # Test that __init__ raises an error if a comment is too long + data = full_data[:-1] + b"a" * 100 + b = BytesIO(data) + with pytest.raises(SyntaxError, match="bad palette file"): + palette = GimpPaletteFile(b) + + # Test that frombytes() can read the data regardless + palette = GimpPaletteFile.frombytes(data) + assert len(palette.palette) / 3 == 256 diff --git a/src/PIL/GimpPaletteFile.py b/src/PIL/GimpPaletteFile.py index a87487209..379ffd739 100644 --- a/src/PIL/GimpPaletteFile.py +++ b/src/PIL/GimpPaletteFile.py @@ -16,6 +16,7 @@ from __future__ import annotations import re +from io import BytesIO from typing import IO @@ -24,13 +25,18 @@ class GimpPaletteFile: rawmode = "RGB" - def __init__(self, fp: IO[bytes]) -> None: + def _read(self, fp: IO[bytes], limit: bool = True) -> None: if not fp.readline().startswith(b"GIMP Palette"): msg = "not a GIMP palette file" raise SyntaxError(msg) palette: list[int] = [] - for _ in range(256 + 3): + i = 0 + while True: + if limit and i == 256 + 3: + break + + i += 1 s = fp.readline() if not s: break @@ -38,7 +44,7 @@ class GimpPaletteFile: # skip fields and comment lines if re.match(rb"\w+:|#", s): continue - if len(s) > 100: + if limit and len(s) > 100: msg = "bad palette file" raise SyntaxError(msg) @@ -48,10 +54,19 @@ class GimpPaletteFile: raise ValueError(msg) palette += (int(v[i]) for i in range(3)) - if len(palette) == 768: + if limit and len(palette) == 768: break self.palette = bytes(palette) + def __init__(self, fp: IO[bytes]) -> None: + self._read(fp) + + @classmethod + def frombytes(cls, data: bytes) -> GimpPaletteFile: + self = cls.__new__(cls) + self._read(BytesIO(data), False) + return self + def getpalette(self) -> tuple[bytes, str]: return self.palette, self.rawmode