diff --git a/Tests/test_file_jxl.py b/Tests/test_file_jxl.py index 8816702e8..be6c35680 100644 --- a/Tests/test_file_jxl.py +++ b/Tests/test_file_jxl.py @@ -37,7 +37,7 @@ class TestFileJpegXl: def test_read_rgb(self) -> None: """ - Can we read a RGB mode Jpeg XL file without error? + Can we read an RGB mode JPEG XL file without error? Does it have the bits we expect? """ @@ -52,9 +52,22 @@ class TestFileJpegXl: # djxl hopper.jxl hopper_jxl_bits.ppm assert_image_similar_tofile(im, "Tests/images/hopper_jxl_bits.ppm", 1) + def test_read_rgba(self) -> None: + # Generated with `cjxl transparent.png transparent.jxl -q 100 -e 8` + with Image.open("Tests/images/transparent.jxl") as im: + assert im.mode == "RGBA" + assert im.size == (200, 150) + assert im.format == "JPEG XL" + im.load() + im.getdata() + + im.tobytes() + + assert_image_similar_tofile(im, "Tests/images/transparent.png", 1) + def test_read_i16(self) -> None: """ - Can we read 16-bit Grayscale Jpeg XL image? + Can we read 16-bit Grayscale JPEG XL image? """ with Image.open("Tests/images/jxl/16bit_subcutaneous.cropped.jxl") as im: diff --git a/Tests/test_file_jxl_alpha.py b/Tests/test_file_jxl_alpha.py deleted file mode 100644 index a5fa15019..000000000 --- a/Tests/test_file_jxl_alpha.py +++ /dev/null @@ -1,26 +0,0 @@ -from __future__ import annotations - -from PIL import Image - -from .helper import assert_image_similar_tofile, skip_unless_feature - -pytestmark = skip_unless_feature("jpegxl") - - -def test_read_rgba() -> None: - """ - Can we read an RGBA mode file without error? - Does it have the bits we expect? - """ - - # Generated with `cjxl transparent.png transparent.jxl -q 100 -e 8` - with Image.open("Tests/images/transparent.jxl") as im: - assert im.mode == "RGBA" - assert im.size == (200, 150) - assert im.format == "JPEG XL" - im.load() - im.getdata() - - im.tobytes() - - assert_image_similar_tofile(im, "Tests/images/transparent.png", 1) diff --git a/Tests/test_file_jxl_animated.py b/Tests/test_file_jxl_animated.py index 54bcae41e..ebb342d3e 100644 --- a/Tests/test_file_jxl_animated.py +++ b/Tests/test_file_jxl_animated.py @@ -21,10 +21,14 @@ def test_n_frames() -> None: assert im.is_animated -def test_float_duration() -> None: +def test_duration() -> None: with Image.open("Tests/images/iss634.jxl") as im: - im.load() assert im.info["duration"] == 70 + assert im.info["timestamp"] == 0 + + im.seek(2) + assert im.info["duration"] == 60 + assert im.info["timestamp"] == 140 def test_seek() -> None: @@ -62,8 +66,11 @@ def test_seek() -> None: def test_seek_errors() -> None: with Image.open("Tests/images/iss634.jxl") as im: - with pytest.raises(EOFError): + with pytest.raises(EOFError, match="attempt to seek outside sequence"): im.seek(-1) - with pytest.raises(EOFError): + im.seek(1) + with pytest.raises(EOFError, match="no more images in JPEG XL file"): im.seek(47) + + assert im.tell() == 1 diff --git a/Tests/test_file_jxl_metadata.py b/Tests/test_file_jxl_metadata.py index aad4044a5..d68d41791 100644 --- a/Tests/test_file_jxl_metadata.py +++ b/Tests/test_file_jxl_metadata.py @@ -39,7 +39,7 @@ def test_read_exif_metadata() -> None: with Image.open("Tests/images/flower.jpg") as im_jpeg: expected_exif = im_jpeg.info["exif"] - # jpeg xl always returns exif without 'Exif\0\0' prefix + # JPEG XL always returns exif without 'Exif\0\0' prefix assert exif_data == expected_exif[6:] @@ -97,8 +97,8 @@ def test_4_byte_exif(monkeypatch: pytest.MonkeyPatch) -> None: def __init__(self, b: bytes) -> None: pass - def get_info(self) -> tuple[tuple[int, int], str, int, int, int, int]: - return ((1, 1), "L", 0, 0, 0, 0) + def get_info(self) -> tuple[tuple[int, int], str, int, int, int, int, int]: + return ((1, 1), "L", 0, 0, 0, 0, 0) def get_icc(self) -> None: pass diff --git a/src/PIL/JpegXlImagePlugin.py b/src/PIL/JpegXlImagePlugin.py index ef24e43bf..9a5416068 100644 --- a/src/PIL/JpegXlImagePlugin.py +++ b/src/PIL/JpegXlImagePlugin.py @@ -25,9 +25,7 @@ def _accept(prefix: bytes) -> bool | str: class JpegXlImageFile(ImageFile.ImageFile): format = "JPEG XL" format_description = "JPEG XL image" - __loaded = -1 - __logical_frame = 0 - __physical_frame = 0 + __frame = 0 def _open(self) -> None: self._decoder = _jpegxl.JpegXlDecoder(self.fp.read()) @@ -39,24 +37,27 @@ class JpegXlImageFile(ImageFile.ImageFile): tps_num, tps_denom, self.info["loop"], + tps_duration, ) = self._decoder.get_info() self._n_frames = None if self.is_animated else 1 self._tps_dur_secs = tps_num / tps_denom if tps_denom != 0 else 1 + self.info["duration"] = 1000 * tps_duration * (1 / self._tps_dur_secs) # TODO: handle libjxl time codes - self.__timestamp = 0 + self.info["timestamp"] = 0 if icc := self._decoder.get_icc(): self.info["icc_profile"] = icc if exif := self._decoder.get_exif(): - # jpeg xl does some weird shenanigans when storing exif + # JPEG XL does some weird shenanigans when storing exif # it omits first 6 bytes of tiff header but adds 4 byte offset instead if len(exif) > 4: exif_start_offset = struct.unpack(">I", exif[:4])[0] self.info["exif"] = exif[exif_start_offset + 4 :] if xmp := self._decoder.get_xmp(): self.info["xmp"] = xmp + self.tile = [ImageFile._Tile("raw", (0, 0) + self.size, 0, self.mode)] @property def n_frames(self) -> int: @@ -67,64 +68,45 @@ class JpegXlImageFile(ImageFile.ImageFile): return self._n_frames - def _get_next(self) -> tuple[bytes, float, float]: - # Get next frame - next_frame = self._decoder.get_next() - self.__physical_frame += 1 + def _get_next(self) -> bytes: + data, tps_duration, is_last = self._decoder.get_next() - # this actually means EOF, errors are raised in _jxl - if next_frame is None: - msg = "failed to decode next frame in JXL file" - raise EOFError(msg) - - data, tps_duration, is_last = next_frame if is_last and self._n_frames is None: - # libjxl said this frame is the last one - self._n_frames = self.__physical_frame + self._n_frames = self.__frame # duration in milliseconds - duration = 1000 * tps_duration * (1 / self._tps_dur_secs) - timestamp = self.__timestamp - self.__timestamp += duration + self.info["timestamp"] += self.info["duration"] + self.info["duration"] = 1000 * tps_duration * (1 / self._tps_dur_secs) - return data, timestamp, duration - - def _seek(self, frame: int) -> None: - if frame == self.__physical_frame: - return # Nothing to do - if frame < self.__physical_frame: - # also rewind libjxl decoder instance - self._decoder.rewind() - self.__physical_frame = 0 - self.__loaded = -1 - self.__timestamp = 0 - - while self.__physical_frame < frame: - self._get_next() # Advance to the requested frame + return data def seek(self, frame: int) -> None: - if self._n_frames is None: - self.n_frames if not self._seek_check(frame): return - # Set logical frame to requested position - self.__logical_frame = frame + if frame < self.__frame: + self.__frame = 0 + self._decoder.rewind() + self.info["timestamp"] = 0 + + last_frame = self.__frame + while self.__frame < frame: + self._get_next() + self.__frame += 1 + if self._n_frames is not None and self._n_frames < frame: + self.seek(last_frame) + msg = "no more images in JPEG XL file" + raise EOFError(msg) + + self.tile = [ImageFile._Tile("raw", (0, 0) + self.size, 0, self.mode)] def load(self) -> Image.core.PixelAccess | None: - if self.__loaded != self.__logical_frame: - self._seek(self.__logical_frame) + if self.tile: + data = self._get_next() - data, self.info["timestamp"], self.info["duration"] = self._get_next() - self.__loaded = self.__logical_frame - - # Set tile if self.fp and self._exclusive_fp: self.fp.close() - # this is horribly memory inefficient - # you need probably 2*(raw image plane) bytes of memory self.fp = BytesIO(data) - self.tile = [ImageFile._Tile("raw", (0, 0) + self.size, 0, self.mode)] return super().load() @@ -132,7 +114,7 @@ class JpegXlImageFile(ImageFile.ImageFile): pass def tell(self) -> int: - return self.__logical_frame + return self.__frame Image.register_open(JpegXlImageFile.format, JpegXlImageFile, _accept) diff --git a/src/_jpegxl.c b/src/_jpegxl.c index 9f0d6aa75..b812061f6 100644 --- a/src/_jpegxl.c +++ b/src/_jpegxl.c @@ -18,10 +18,10 @@ void _jxl_get_pixel_format(JxlPixelFormat *pf, const JxlBasicInfo *bi) { pf->num_channels = bi->num_color_channels + bi->num_extra_channels; - if (bi->exponent_bits_per_sample > 0 || bi->alpha_exponent_bits > 0) { - pf->data_type = JXL_TYPE_FLOAT; // not yet supported + if (bi->exponent_bits_per_sample) { + pf->data_type = JXL_TYPE_FLOAT; } else if (bi->bits_per_sample > 8) { - pf->data_type = JXL_TYPE_UINT16; // not yet supported + pf->data_type = JXL_TYPE_UINT16; } else { pf->data_type = JXL_TYPE_UINT8; } @@ -35,41 +35,36 @@ _jxl_get_pixel_format(JxlPixelFormat *pf, const JxlBasicInfo *bi) { // TODO: floating point mode char * _jxl_get_mode(const JxlBasicInfo *bi) { - // 16-bit single channel images are supported - if (bi->bits_per_sample == 16 && bi->num_color_channels == 1 && - bi->alpha_bits == 0 && !bi->alpha_premultiplied) { - return "I;16"; - } - - // PIL doesn't support high bit depth images - // it will throw an exception but that's for your own good - // you wouldn't want to see distorted image - if (bi->bits_per_sample != 8) { - return NULL; - } - - // image has transparency - if (bi->alpha_bits > 0) { - if (bi->num_color_channels == 3) { - if (bi->alpha_premultiplied) { - return "RGBa"; - } - return "RGBA"; - } - if (bi->num_color_channels == 1) { - if (bi->alpha_premultiplied) { - return "La"; - } - return "LA"; + if (bi->num_color_channels == 1 && !bi->alpha_bits) { + if (bi->bits_per_sample == 16) { + return "I;16"; } } - // image has no transparency - if (bi->num_color_channels == 3) { - return "RGB"; - } - if (bi->num_color_channels == 1) { - return "L"; + if (bi->bits_per_sample == 8) { + // image has transparency + if (bi->alpha_bits) { + if (bi->num_color_channels == 3) { + if (bi->alpha_premultiplied) { + return "RGBa"; + } + return "RGBA"; + } + if (bi->num_color_channels == 1) { + if (bi->alpha_premultiplied) { + return "La"; + } + return "LA"; + } + } else { + // image has no transparency + if (bi->num_color_channels == 3) { + return "RGB"; + } + if (bi->num_color_channels == 1) { + return "L"; + } + } } // could not recognize mode @@ -85,10 +80,10 @@ typedef struct { Py_ssize_t jxl_data_len; // length of input jxl bitstream uint8_t *outbuf; - Py_ssize_t outbuf_len; + size_t outbuf_len; uint8_t *jxl_icc; - Py_ssize_t jxl_icc_len; + size_t jxl_icc_len; uint8_t *jxl_exif; Py_ssize_t jxl_exif_len; uint8_t *jxl_xmp; @@ -262,26 +257,16 @@ decoder_loop_skip_process: goto end; } - // got basic info if (decp->status == JXL_DEC_BASIC_INFO) { decp->status = JxlDecoderGetBasicInfo(decp->decoder, &decp->basic_info); _JXL_CHECK("JxlDecoderGetBasicInfo"); _jxl_get_pixel_format(&decp->pixel_format, &decp->basic_info); - if (decp->pixel_format.data_type != JXL_TYPE_UINT8 && - decp->pixel_format.data_type != JXL_TYPE_UINT16) { - // only 8 bit integer value images are supported for now - PyErr_SetString( - PyExc_NotImplementedError, "unsupported pixel data type" - ); - goto end_with_custom_error; - } decp->mode = _jxl_get_mode(&decp->basic_info); continue; } - // got color encoding if (decp->status == JXL_DEC_COLOR_ENCODING) { decp->status = JxlDecoderGetICCProfileSize( decp->decoder, JXL_COLOR_PROFILE_TARGET_DATA, &decp->jxl_icc_len @@ -317,7 +302,7 @@ decoder_loop_skip_process: continue; } - size_t cur_compr_box_size; + uint64_t cur_compr_box_size; decp->status = JxlDecoderGetBoxSizeRaw(decp->decoder, &cur_compr_box_size); _JXL_CHECK("JxlDecoderGetBoxSizeRaw"); @@ -361,12 +346,6 @@ decoder_loop_skip_process: } while (decp->status != JXL_DEC_FRAME); - // couldn't determine Image mode or it is unsupported - if (!decp->mode) { - PyErr_SetString(PyExc_NotImplementedError, "only 8-bit images are supported"); - goto end_with_custom_error; - } - return (PyObject *)decp; // on success we should never reach here @@ -396,16 +375,21 @@ end_with_custom_error: PyObject * _jxl_decoder_get_info(PyObject *self) { JpegXlDecoderObject *decp = (JpegXlDecoderObject *)self; - + JxlFrameHeader fhdr = {}; + if (JxlDecoderGetFrameHeader(decp->decoder, &fhdr) != JXL_DEC_SUCCESS) { + PyErr_SetString(PyExc_OSError, "Error determining duration"); + return NULL; + } return Py_BuildValue( - "(II)sOIII", + "(II)sOIIII", decp->basic_info.xsize, decp->basic_info.ysize, decp->mode, decp->basic_info.have_animation ? Py_True : Py_False, decp->basic_info.animation.tps_numerator, decp->basic_info.animation.tps_denominator, - decp->basic_info.animation.num_loops + decp->basic_info.animation.num_loops, + fhdr.duration ); } @@ -426,11 +410,6 @@ _jxl_decoder_get_next(PyObject *self) { while (decp->status != JXL_DEC_NEED_IMAGE_OUT_BUFFER) { decp->status = JxlDecoderProcessInput(decp->decoder); - // every frame was decoded successfully - if (decp->status == JXL_DEC_SUCCESS) { - Py_RETURN_NONE; - } - // this should only occur after rewind if (decp->status == JXL_DEC_NEED_MORE_INPUT) { _jxl_decoder_set_input((PyObject *)decp); @@ -454,7 +433,7 @@ _jxl_decoder_get_next(PyObject *self) { uint8_t *_new_outbuf = realloc(decp->outbuf, decp->outbuf_len); if (!_new_outbuf) { PyErr_SetString(PyExc_OSError, "failed to allocate outbuf"); - goto end_with_custom_error; + return NULL; } decp->outbuf = _new_outbuf; } @@ -469,7 +448,7 @@ _jxl_decoder_get_next(PyObject *self) { if (decp->status != JXL_DEC_FULL_IMAGE) { PyErr_SetString(PyExc_OSError, "failed to read next frame"); - goto end_with_custom_error; + return NULL; } bytes = PyBytes_FromStringAndSize((char *)(decp->outbuf), decp->outbuf_len); @@ -493,13 +472,6 @@ end: decp->status ); PyErr_SetString(PyExc_OSError, err_msg); - -end_with_custom_error: - - // no need to deallocate anything here - // user can just ignore error - - return NULL; } PyObject *