From 5e3dc407b40f11b33e507ef895fa9efec39ac036 Mon Sep 17 00:00:00 2001 From: Andrew Murray Date: Sat, 6 Dec 2025 23:06:23 +1100 Subject: [PATCH] Do not count frames on image open --- Tests/test_file_jxl_metadata.py | 4 +-- src/PIL/JpegXlImagePlugin.py | 49 ++++++++++++++------------------- src/_jpegxl.c | 49 ++++++++++++++------------------- 3 files changed, 42 insertions(+), 60 deletions(-) diff --git a/Tests/test_file_jxl_metadata.py b/Tests/test_file_jxl_metadata.py index 267ea7a12..aad4044a5 100644 --- a/Tests/test_file_jxl_metadata.py +++ b/Tests/test_file_jxl_metadata.py @@ -97,8 +97,8 @@ def test_4_byte_exif(monkeypatch: pytest.MonkeyPatch) -> None: def __init__(self, b: bytes) -> None: pass - def get_info(self) -> tuple[tuple[int, int], str, int, int, int, int, int]: - return ((1, 1), "L", 0, 0, 0, 0, 0) + def get_info(self) -> tuple[tuple[int, int], str, int, int, int, int]: + return ((1, 1), "L", 0, 0, 0, 0) def get_icc(self) -> None: pass diff --git a/src/PIL/JpegXlImagePlugin.py b/src/PIL/JpegXlImagePlugin.py index d824449ea..ef24e43bf 100644 --- a/src/PIL/JpegXlImagePlugin.py +++ b/src/PIL/JpegXlImagePlugin.py @@ -13,15 +13,6 @@ except ImportError: SUPPORTED = False -## Future idea: -## it's not known how many frames an animated image has -## by default, _jxl_decoder_new will iterate over all frames without decoding them -## then libjxl decoder is rewinded and we're ready to decode frame by frame -## if OPEN_COUNTS_FRAMES is False, n_frames will be None until the last frame is decoded -## it only applies to animated jpeg xl images -# OPEN_COUNTS_FRAMES = True - - def _accept(prefix: bytes) -> bool | str: is_jxl = prefix.startswith( (b"\xff\x0a", b"\x00\x00\x00\x0c\x4a\x58\x4c\x20\x0d\x0a\x87\x0a") @@ -34,8 +25,9 @@ def _accept(prefix: bytes) -> bool | str: class JpegXlImageFile(ImageFile.ImageFile): format = "JPEG XL" format_description = "JPEG XL image" - __loaded = 0 + __loaded = -1 __logical_frame = 0 + __physical_frame = 0 def _open(self) -> None: self._decoder = _jpegxl.JpegXlDecoder(self.fp.read()) @@ -47,16 +39,10 @@ class JpegXlImageFile(ImageFile.ImageFile): tps_num, tps_denom, self.info["loop"], - n_frames, ) = self._decoder.get_info() - self._tps_dur_secs = 1 - self.n_frames: int | None = 1 - if self.is_animated: - self.n_frames = None - if n_frames > 0: - self.n_frames = n_frames - self._tps_dur_secs = tps_num / tps_denom + self._n_frames = None if self.is_animated else 1 + self._tps_dur_secs = tps_num / tps_denom if tps_denom != 0 else 1 # TODO: handle libjxl time codes self.__timestamp = 0 @@ -72,7 +58,14 @@ class JpegXlImageFile(ImageFile.ImageFile): if xmp := self._decoder.get_xmp(): self.info["xmp"] = xmp - self._rewind() + @property + def n_frames(self) -> int: + if self._n_frames is None: + current = self.tell() + self._n_frames = current + self._decoder.get_frames_left() + self.seek(current) + + return self._n_frames def _get_next(self) -> tuple[bytes, float, float]: # Get next frame @@ -85,9 +78,9 @@ class JpegXlImageFile(ImageFile.ImageFile): raise EOFError(msg) data, tps_duration, is_last = next_frame - if is_last and self.n_frames is None: + if is_last and self._n_frames is None: # libjxl said this frame is the last one - self.n_frames = self.__physical_frame + self._n_frames = self.__physical_frame # duration in milliseconds duration = 1000 * tps_duration * (1 / self._tps_dur_secs) @@ -96,24 +89,22 @@ class JpegXlImageFile(ImageFile.ImageFile): return data, timestamp, duration - def _rewind(self, hard: bool = False) -> None: - if hard: - self._decoder.rewind() - self.__physical_frame = 0 - self.__loaded = -1 - self.__timestamp = 0 - def _seek(self, frame: int) -> None: if frame == self.__physical_frame: return # Nothing to do if frame < self.__physical_frame: # also rewind libjxl decoder instance - self._rewind(hard=True) + self._decoder.rewind() + self.__physical_frame = 0 + self.__loaded = -1 + self.__timestamp = 0 while self.__physical_frame < frame: self._get_next() # Advance to the requested frame def seek(self, frame: int) -> None: + if self._n_frames is None: + self.n_frames if not self._seek_check(frame): return diff --git a/src/_jpegxl.c b/src/_jpegxl.c index ca7fdf7f4..9f0d6aa75 100644 --- a/src/_jpegxl.c +++ b/src/_jpegxl.c @@ -98,8 +98,6 @@ typedef struct { JxlBasicInfo basic_info; JxlPixelFormat pixel_format; - Py_ssize_t n_frames; - char *mode; } JpegXlDecoderObject; @@ -166,27 +164,26 @@ _jxl_decoder_rewind(PyObject *self) { Py_RETURN_NONE; } -bool -_jxl_decoder_count_frames(PyObject *self) { - JpegXlDecoderObject *decp = (JpegXlDecoderObject *)self; - - decp->n_frames = 0; +PyObject * +_jxl_decoder_get_frames_left(PyObject *self) { + int frames_left = 0; // count all JXL_DEC_NEED_IMAGE_OUT_BUFFER events + JpegXlDecoderObject *decp = (JpegXlDecoderObject *)self; while (decp->status != JXL_DEC_SUCCESS) { decp->status = JxlDecoderProcessInput(decp->decoder); if (decp->status == JXL_DEC_NEED_IMAGE_OUT_BUFFER) { if (JxlDecoderSkipCurrentFrame(decp->decoder) != JXL_DEC_SUCCESS) { - return false; + PyErr_SetString(PyExc_OSError, "Error when counting frames"); + break; } - decp->n_frames++; + frames_left++; } } + JxlDecoderRewind(decp->decoder); - _jxl_decoder_rewind((PyObject *)decp); - - return true; + return Py_BuildValue("i", frames_left); } PyObject * @@ -206,7 +203,6 @@ _jxl_decoder_new(PyObject *self, PyObject *args) { decp->jxl_exif_len = 0; decp->jxl_xmp = NULL; decp->jxl_xmp_len = 0; - decp->n_frames = 0; // used for printing more detailed error messages char *jxl_call_name; @@ -371,14 +367,6 @@ decoder_loop_skip_process: goto end_with_custom_error; } - if (decp->basic_info.have_animation) { - // get frame count by iterating over image out events - if (!_jxl_decoder_count_frames((PyObject *)decp)) { - PyErr_SetString(PyExc_OSError, "something went wrong when counting frames"); - goto end_with_custom_error; - } - } - return (PyObject *)decp; // on success we should never reach here @@ -410,15 +398,14 @@ _jxl_decoder_get_info(PyObject *self) { JpegXlDecoderObject *decp = (JpegXlDecoderObject *)self; return Py_BuildValue( - "(II)sOIIII", + "(II)sOIII", decp->basic_info.xsize, decp->basic_info.ysize, decp->mode, decp->basic_info.have_animation ? Py_True : Py_False, decp->basic_info.animation.tps_numerator, decp->basic_info.animation.tps_denominator, - decp->basic_info.animation.num_loops, - decp->n_frames + decp->basic_info.animation.num_loops ); } @@ -432,6 +419,10 @@ _jxl_decoder_get_next(PyObject *self) { char *jxl_call_name; // process events until next frame output is ready + if (decp->status == JXL_DEC_FRAME) { + decp->status = JxlDecoderGetFrameHeader(decp->decoder, &fhdr); + _JXL_CHECK("JxlDecoderGetFrameHeader"); + } while (decp->status != JXL_DEC_NEED_IMAGE_OUT_BUFFER) { decp->status = JxlDecoderProcessInput(decp->decoder); @@ -444,14 +435,10 @@ _jxl_decoder_get_next(PyObject *self) { if (decp->status == JXL_DEC_NEED_MORE_INPUT) { _jxl_decoder_set_input((PyObject *)decp); _JXL_CHECK("JxlDecoderSetInput") - continue; - } - - if (decp->status == JXL_DEC_FRAME) { + } else if (decp->status == JXL_DEC_FRAME) { // decode frame header decp->status = JxlDecoderGetFrameHeader(decp->decoder, &fhdr); _JXL_CHECK("JxlDecoderGetFrameHeader"); - continue; } } @@ -573,6 +560,10 @@ static struct PyMethodDef _jpegxl_decoder_methods[] = { {"get_icc", (PyCFunction)_jxl_decoder_get_icc, METH_NOARGS, "get_icc"}, {"get_exif", (PyCFunction)_jxl_decoder_get_exif, METH_NOARGS, "get_exif"}, {"get_xmp", (PyCFunction)_jxl_decoder_get_xmp, METH_NOARGS, "get_xmp"}, + {"get_frames_left", + (PyCFunction)_jxl_decoder_get_frames_left, + METH_NOARGS, + "get_frames_left"}, {"rewind", (PyCFunction)_jxl_decoder_rewind, METH_NOARGS, "rewind"}, {NULL, NULL} /* sentinel */ };