Do not count frames on image open

This commit is contained in:
Andrew Murray 2025-12-06 23:06:23 +11:00
parent 762235cd56
commit 5e3dc407b4
3 changed files with 42 additions and 60 deletions

View File

@ -97,8 +97,8 @@ def test_4_byte_exif(monkeypatch: pytest.MonkeyPatch) -> None:
def __init__(self, b: bytes) -> None:
pass
def get_info(self) -> tuple[tuple[int, int], str, int, int, int, int, int]:
return ((1, 1), "L", 0, 0, 0, 0, 0)
def get_info(self) -> tuple[tuple[int, int], str, int, int, int, int]:
return ((1, 1), "L", 0, 0, 0, 0)
def get_icc(self) -> None:
pass

View File

@ -13,15 +13,6 @@ except ImportError:
SUPPORTED = False
## Future idea:
## it's not known how many frames an animated image has
## by default, _jxl_decoder_new will iterate over all frames without decoding them
## then libjxl decoder is rewinded and we're ready to decode frame by frame
## if OPEN_COUNTS_FRAMES is False, n_frames will be None until the last frame is decoded
## it only applies to animated jpeg xl images
# OPEN_COUNTS_FRAMES = True
def _accept(prefix: bytes) -> bool | str:
is_jxl = prefix.startswith(
(b"\xff\x0a", b"\x00\x00\x00\x0c\x4a\x58\x4c\x20\x0d\x0a\x87\x0a")
@ -34,8 +25,9 @@ def _accept(prefix: bytes) -> bool | str:
class JpegXlImageFile(ImageFile.ImageFile):
format = "JPEG XL"
format_description = "JPEG XL image"
__loaded = 0
__loaded = -1
__logical_frame = 0
__physical_frame = 0
def _open(self) -> None:
self._decoder = _jpegxl.JpegXlDecoder(self.fp.read())
@ -47,16 +39,10 @@ class JpegXlImageFile(ImageFile.ImageFile):
tps_num,
tps_denom,
self.info["loop"],
n_frames,
) = self._decoder.get_info()
self._tps_dur_secs = 1
self.n_frames: int | None = 1
if self.is_animated:
self.n_frames = None
if n_frames > 0:
self.n_frames = n_frames
self._tps_dur_secs = tps_num / tps_denom
self._n_frames = None if self.is_animated else 1
self._tps_dur_secs = tps_num / tps_denom if tps_denom != 0 else 1
# TODO: handle libjxl time codes
self.__timestamp = 0
@ -72,7 +58,14 @@ class JpegXlImageFile(ImageFile.ImageFile):
if xmp := self._decoder.get_xmp():
self.info["xmp"] = xmp
self._rewind()
@property
def n_frames(self) -> int:
if self._n_frames is None:
current = self.tell()
self._n_frames = current + self._decoder.get_frames_left()
self.seek(current)
return self._n_frames
def _get_next(self) -> tuple[bytes, float, float]:
# Get next frame
@ -85,9 +78,9 @@ class JpegXlImageFile(ImageFile.ImageFile):
raise EOFError(msg)
data, tps_duration, is_last = next_frame
if is_last and self.n_frames is None:
if is_last and self._n_frames is None:
# libjxl said this frame is the last one
self.n_frames = self.__physical_frame
self._n_frames = self.__physical_frame
# duration in milliseconds
duration = 1000 * tps_duration * (1 / self._tps_dur_secs)
@ -96,24 +89,22 @@ class JpegXlImageFile(ImageFile.ImageFile):
return data, timestamp, duration
def _rewind(self, hard: bool = False) -> None:
if hard:
self._decoder.rewind()
self.__physical_frame = 0
self.__loaded = -1
self.__timestamp = 0
def _seek(self, frame: int) -> None:
if frame == self.__physical_frame:
return # Nothing to do
if frame < self.__physical_frame:
# also rewind libjxl decoder instance
self._rewind(hard=True)
self._decoder.rewind()
self.__physical_frame = 0
self.__loaded = -1
self.__timestamp = 0
while self.__physical_frame < frame:
self._get_next() # Advance to the requested frame
def seek(self, frame: int) -> None:
if self._n_frames is None:
self.n_frames
if not self._seek_check(frame):
return

View File

@ -98,8 +98,6 @@ typedef struct {
JxlBasicInfo basic_info;
JxlPixelFormat pixel_format;
Py_ssize_t n_frames;
char *mode;
} JpegXlDecoderObject;
@ -166,27 +164,26 @@ _jxl_decoder_rewind(PyObject *self) {
Py_RETURN_NONE;
}
bool
_jxl_decoder_count_frames(PyObject *self) {
JpegXlDecoderObject *decp = (JpegXlDecoderObject *)self;
decp->n_frames = 0;
PyObject *
_jxl_decoder_get_frames_left(PyObject *self) {
int frames_left = 0;
// count all JXL_DEC_NEED_IMAGE_OUT_BUFFER events
JpegXlDecoderObject *decp = (JpegXlDecoderObject *)self;
while (decp->status != JXL_DEC_SUCCESS) {
decp->status = JxlDecoderProcessInput(decp->decoder);
if (decp->status == JXL_DEC_NEED_IMAGE_OUT_BUFFER) {
if (JxlDecoderSkipCurrentFrame(decp->decoder) != JXL_DEC_SUCCESS) {
return false;
PyErr_SetString(PyExc_OSError, "Error when counting frames");
break;
}
decp->n_frames++;
frames_left++;
}
}
JxlDecoderRewind(decp->decoder);
_jxl_decoder_rewind((PyObject *)decp);
return true;
return Py_BuildValue("i", frames_left);
}
PyObject *
@ -206,7 +203,6 @@ _jxl_decoder_new(PyObject *self, PyObject *args) {
decp->jxl_exif_len = 0;
decp->jxl_xmp = NULL;
decp->jxl_xmp_len = 0;
decp->n_frames = 0;
// used for printing more detailed error messages
char *jxl_call_name;
@ -371,14 +367,6 @@ decoder_loop_skip_process:
goto end_with_custom_error;
}
if (decp->basic_info.have_animation) {
// get frame count by iterating over image out events
if (!_jxl_decoder_count_frames((PyObject *)decp)) {
PyErr_SetString(PyExc_OSError, "something went wrong when counting frames");
goto end_with_custom_error;
}
}
return (PyObject *)decp;
// on success we should never reach here
@ -410,15 +398,14 @@ _jxl_decoder_get_info(PyObject *self) {
JpegXlDecoderObject *decp = (JpegXlDecoderObject *)self;
return Py_BuildValue(
"(II)sOIIII",
"(II)sOIII",
decp->basic_info.xsize,
decp->basic_info.ysize,
decp->mode,
decp->basic_info.have_animation ? Py_True : Py_False,
decp->basic_info.animation.tps_numerator,
decp->basic_info.animation.tps_denominator,
decp->basic_info.animation.num_loops,
decp->n_frames
decp->basic_info.animation.num_loops
);
}
@ -432,6 +419,10 @@ _jxl_decoder_get_next(PyObject *self) {
char *jxl_call_name;
// process events until next frame output is ready
if (decp->status == JXL_DEC_FRAME) {
decp->status = JxlDecoderGetFrameHeader(decp->decoder, &fhdr);
_JXL_CHECK("JxlDecoderGetFrameHeader");
}
while (decp->status != JXL_DEC_NEED_IMAGE_OUT_BUFFER) {
decp->status = JxlDecoderProcessInput(decp->decoder);
@ -444,14 +435,10 @@ _jxl_decoder_get_next(PyObject *self) {
if (decp->status == JXL_DEC_NEED_MORE_INPUT) {
_jxl_decoder_set_input((PyObject *)decp);
_JXL_CHECK("JxlDecoderSetInput")
continue;
}
if (decp->status == JXL_DEC_FRAME) {
} else if (decp->status == JXL_DEC_FRAME) {
// decode frame header
decp->status = JxlDecoderGetFrameHeader(decp->decoder, &fhdr);
_JXL_CHECK("JxlDecoderGetFrameHeader");
continue;
}
}
@ -573,6 +560,10 @@ static struct PyMethodDef _jpegxl_decoder_methods[] = {
{"get_icc", (PyCFunction)_jxl_decoder_get_icc, METH_NOARGS, "get_icc"},
{"get_exif", (PyCFunction)_jxl_decoder_get_exif, METH_NOARGS, "get_exif"},
{"get_xmp", (PyCFunction)_jxl_decoder_get_xmp, METH_NOARGS, "get_xmp"},
{"get_frames_left",
(PyCFunction)_jxl_decoder_get_frames_left,
METH_NOARGS,
"get_frames_left"},
{"rewind", (PyCFunction)_jxl_decoder_rewind, METH_NOARGS, "rewind"},
{NULL, NULL} /* sentinel */
};