diff --git a/Tests/test_file_iptc.py b/Tests/test_file_iptc.py index 3c4c892c8..5a8aaa3ef 100644 --- a/Tests/test_file_iptc.py +++ b/Tests/test_file_iptc.py @@ -2,6 +2,8 @@ from __future__ import annotations from io import BytesIO +import pytest + from PIL import Image, IptcImagePlugin, TiffImagePlugin, TiffTags from .helper import assert_image_equal, hopper @@ -9,21 +11,78 @@ from .helper import assert_image_equal, hopper TEST_FILE = "Tests/images/iptc.jpg" +def create_iptc_image(info: dict[str, int] = {}) -> BytesIO: + def field(tag, value): + return bytes((0x1C,) + tag + (0, len(value))) + value + + data = field((3, 60), bytes((info.get("layers", 1), info.get("component", 0)))) + data += field((3, 120), bytes((info.get("compression", 1),))) + if "band" in info: + data += field((3, 65), bytes((info["band"] + 1,))) + data += field((3, 20), b"\x01") # width + data += field((3, 30), b"\x01") # height + data += field( + (8, 10), + bytes((info.get("data", 0),)), + ) + + return BytesIO(data) + + def test_open() -> None: expected = Image.new("L", (1, 1)) - f = BytesIO( - b"\x1c\x03<\x00\x02\x01\x00\x1c\x03x\x00\x01\x01\x1c\x03\x14\x00\x01\x01" - b"\x1c\x03\x1e\x00\x01\x01\x1c\x08\n\x00\x01\x00" - ) + f = create_iptc_image() with Image.open(f) as im: - assert im.tile == [("iptc", (0, 0, 1, 1), 25, "raw")] + assert im.tile == [("iptc", (0, 0, 1, 1), 25, ("raw", None))] assert_image_equal(im, expected) with Image.open(f) as im: assert im.load() is not None +def test_field_length() -> None: + f = create_iptc_image() + f.seek(28) + f.write(b"\xff") + with pytest.raises(OSError, match="illegal field length in IPTC/NAA file"): + with Image.open(f): + pass + + +@pytest.mark.parametrize("layers, mode", ((3, "RGB"), (4, "CMYK"))) +def test_layers(layers: int, mode: str) -> None: + for band in range(-1, layers): + info = {"layers": layers, "component": 1, "data": 5} + if band != -1: + info["band"] = band + f = create_iptc_image(info) + with Image.open(f) as im: + assert im.mode == mode + + data = [0] * layers + data[max(band, 0)] = 5 + assert im.getpixel((0, 0)) == tuple(data) + + +def test_unknown_compression() -> None: + f = create_iptc_image({"compression": 2}) + with pytest.raises(OSError, match="Unknown IPTC image compression"): + with Image.open(f): + pass + + +def test_getiptcinfo() -> None: + f = create_iptc_image() + with Image.open(f) as im: + assert IptcImagePlugin.getiptcinfo(im) == { + (3, 60): b"\x01\x00", + (3, 120): b"\x01", + (3, 20): b"\x01", + (3, 30): b"\x01", + } + + def test_getiptcinfo_jpg_none() -> None: # Arrange with hopper() as im: diff --git a/src/PIL/IptcImagePlugin.py b/src/PIL/IptcImagePlugin.py index b1fbb1bf1..c28f4dcc7 100644 --- a/src/PIL/IptcImagePlugin.py +++ b/src/PIL/IptcImagePlugin.py @@ -34,10 +34,6 @@ def _i(c: bytes) -> int: return i32((b"\0\0\0\0" + c)[-4:]) -def _i8(c: int | bytes) -> int: - return c if isinstance(c, int) else c[0] - - ## # Image plugin for IPTC/NAA datastreams. To read IPTC/NAA fields # from TIFF and JPEG files, use the getiptcinfo function. @@ -100,16 +96,18 @@ class IptcImageFile(ImageFile.ImageFile): # mode layers = self.info[(3, 60)][0] component = self.info[(3, 60)][1] - if (3, 65) in self.info: - id = self.info[(3, 65)][0] - 1 - else: - id = 0 if layers == 1 and not component: self._mode = "L" - elif layers == 3 and component: - self._mode = "RGB"[id] - elif layers == 4 and component: - self._mode = "CMYK"[id] + band = None + else: + if layers == 3 and component: + self._mode = "RGB" + elif layers == 4 and component: + self._mode = "CMYK" + if (3, 65) in self.info: + band = self.info[(3, 65)][0] - 1 + else: + band = 0 # size self._size = self.getint((3, 20)), self.getint((3, 30)) @@ -124,39 +122,44 @@ class IptcImageFile(ImageFile.ImageFile): # tile if tag == (8, 10): self.tile = [ - ImageFile._Tile("iptc", (0, 0) + self.size, offset, compression) + ImageFile._Tile("iptc", (0, 0) + self.size, offset, (compression, band)) ] def load(self) -> Image.core.PixelAccess | None: - if len(self.tile) != 1 or self.tile[0][0] != "iptc": - return ImageFile.ImageFile.load(self) + if self.tile: + args = self.tile[0].args + assert isinstance(args, tuple) + compression, band = args - offset, compression = self.tile[0][2:] + self.fp.seek(self.tile[0].offset) - self.fp.seek(offset) - - # Copy image data to temporary file - o = BytesIO() - if compression == "raw": - # To simplify access to the extracted file, - # prepend a PPM header - o.write(b"P5\n%d %d\n255\n" % self.size) - while True: - type, size = self.field() - if type != (8, 10): - break - while size > 0: - s = self.fp.read(min(size, 8192)) - if not s: + # Copy image data to temporary file + o = BytesIO() + if compression == "raw": + # To simplify access to the extracted file, + # prepend a PPM header + o.write(b"P5\n%d %d\n255\n" % self.size) + while True: + type, size = self.field() + if type != (8, 10): break - o.write(s) - size -= len(s) + while size > 0: + s = self.fp.read(min(size, 8192)) + if not s: + break + o.write(s) + size -= len(s) - with Image.open(o) as _im: - _im.load() - self.im = _im.im - self.tile = [] - return Image.Image.load(self) + with Image.open(o) as _im: + if band is not None: + bands = [Image.new("L", _im.size)] * Image.getmodebands(self.mode) + bands[band] = _im + _im = Image.merge(self.mode, bands) + else: + _im.load() + self.im = _im.im + self.tile = [] + return ImageFile.ImageFile.load(self) Image.register_open(IptcImageFile.format, IptcImageFile)