Populate duration before load

This commit is contained in:
Andrew Murray 2025-12-13 21:32:50 +11:00
parent 5e3dc407b4
commit 612de5ae91
6 changed files with 102 additions and 154 deletions

View File

@ -37,7 +37,7 @@ class TestFileJpegXl:
def test_read_rgb(self) -> None:
"""
Can we read a RGB mode Jpeg XL file without error?
Can we read an RGB mode JPEG XL file without error?
Does it have the bits we expect?
"""
@ -52,9 +52,22 @@ class TestFileJpegXl:
# djxl hopper.jxl hopper_jxl_bits.ppm
assert_image_similar_tofile(im, "Tests/images/hopper_jxl_bits.ppm", 1)
def test_read_rgba(self) -> None:
# Generated with `cjxl transparent.png transparent.jxl -q 100 -e 8`
with Image.open("Tests/images/transparent.jxl") as im:
assert im.mode == "RGBA"
assert im.size == (200, 150)
assert im.format == "JPEG XL"
im.load()
im.getdata()
im.tobytes()
assert_image_similar_tofile(im, "Tests/images/transparent.png", 1)
def test_read_i16(self) -> None:
"""
Can we read 16-bit Grayscale Jpeg XL image?
Can we read 16-bit Grayscale JPEG XL image?
"""
with Image.open("Tests/images/jxl/16bit_subcutaneous.cropped.jxl") as im:

View File

@ -1,26 +0,0 @@
from __future__ import annotations
from PIL import Image
from .helper import assert_image_similar_tofile, skip_unless_feature
pytestmark = skip_unless_feature("jpegxl")
def test_read_rgba() -> None:
"""
Can we read an RGBA mode file without error?
Does it have the bits we expect?
"""
# Generated with `cjxl transparent.png transparent.jxl -q 100 -e 8`
with Image.open("Tests/images/transparent.jxl") as im:
assert im.mode == "RGBA"
assert im.size == (200, 150)
assert im.format == "JPEG XL"
im.load()
im.getdata()
im.tobytes()
assert_image_similar_tofile(im, "Tests/images/transparent.png", 1)

View File

@ -21,10 +21,14 @@ def test_n_frames() -> None:
assert im.is_animated
def test_float_duration() -> None:
def test_duration() -> None:
with Image.open("Tests/images/iss634.jxl") as im:
im.load()
assert im.info["duration"] == 70
assert im.info["timestamp"] == 0
im.seek(2)
assert im.info["duration"] == 60
assert im.info["timestamp"] == 140
def test_seek() -> None:
@ -62,8 +66,11 @@ def test_seek() -> None:
def test_seek_errors() -> None:
with Image.open("Tests/images/iss634.jxl") as im:
with pytest.raises(EOFError):
with pytest.raises(EOFError, match="attempt to seek outside sequence"):
im.seek(-1)
with pytest.raises(EOFError):
im.seek(1)
with pytest.raises(EOFError, match="no more images in JPEG XL file"):
im.seek(47)
assert im.tell() == 1

View File

@ -39,7 +39,7 @@ def test_read_exif_metadata() -> None:
with Image.open("Tests/images/flower.jpg") as im_jpeg:
expected_exif = im_jpeg.info["exif"]
# jpeg xl always returns exif without 'Exif\0\0' prefix
# JPEG XL always returns exif without 'Exif\0\0' prefix
assert exif_data == expected_exif[6:]
@ -97,8 +97,8 @@ def test_4_byte_exif(monkeypatch: pytest.MonkeyPatch) -> None:
def __init__(self, b: bytes) -> None:
pass
def get_info(self) -> tuple[tuple[int, int], str, int, int, int, int]:
return ((1, 1), "L", 0, 0, 0, 0)
def get_info(self) -> tuple[tuple[int, int], str, int, int, int, int, int]:
return ((1, 1), "L", 0, 0, 0, 0, 0)
def get_icc(self) -> None:
pass

View File

@ -25,9 +25,7 @@ def _accept(prefix: bytes) -> bool | str:
class JpegXlImageFile(ImageFile.ImageFile):
format = "JPEG XL"
format_description = "JPEG XL image"
__loaded = -1
__logical_frame = 0
__physical_frame = 0
__frame = 0
def _open(self) -> None:
self._decoder = _jpegxl.JpegXlDecoder(self.fp.read())
@ -39,24 +37,27 @@ class JpegXlImageFile(ImageFile.ImageFile):
tps_num,
tps_denom,
self.info["loop"],
tps_duration,
) = self._decoder.get_info()
self._n_frames = None if self.is_animated else 1
self._tps_dur_secs = tps_num / tps_denom if tps_denom != 0 else 1
self.info["duration"] = 1000 * tps_duration * (1 / self._tps_dur_secs)
# TODO: handle libjxl time codes
self.__timestamp = 0
self.info["timestamp"] = 0
if icc := self._decoder.get_icc():
self.info["icc_profile"] = icc
if exif := self._decoder.get_exif():
# jpeg xl does some weird shenanigans when storing exif
# JPEG XL does some weird shenanigans when storing exif
# it omits first 6 bytes of tiff header but adds 4 byte offset instead
if len(exif) > 4:
exif_start_offset = struct.unpack(">I", exif[:4])[0]
self.info["exif"] = exif[exif_start_offset + 4 :]
if xmp := self._decoder.get_xmp():
self.info["xmp"] = xmp
self.tile = [ImageFile._Tile("raw", (0, 0) + self.size, 0, self.mode)]
@property
def n_frames(self) -> int:
@ -67,64 +68,45 @@ class JpegXlImageFile(ImageFile.ImageFile):
return self._n_frames
def _get_next(self) -> tuple[bytes, float, float]:
# Get next frame
next_frame = self._decoder.get_next()
self.__physical_frame += 1
def _get_next(self) -> bytes:
data, tps_duration, is_last = self._decoder.get_next()
# this actually means EOF, errors are raised in _jxl
if next_frame is None:
msg = "failed to decode next frame in JXL file"
raise EOFError(msg)
data, tps_duration, is_last = next_frame
if is_last and self._n_frames is None:
# libjxl said this frame is the last one
self._n_frames = self.__physical_frame
self._n_frames = self.__frame
# duration in milliseconds
duration = 1000 * tps_duration * (1 / self._tps_dur_secs)
timestamp = self.__timestamp
self.__timestamp += duration
self.info["timestamp"] += self.info["duration"]
self.info["duration"] = 1000 * tps_duration * (1 / self._tps_dur_secs)
return data, timestamp, duration
def _seek(self, frame: int) -> None:
if frame == self.__physical_frame:
return # Nothing to do
if frame < self.__physical_frame:
# also rewind libjxl decoder instance
self._decoder.rewind()
self.__physical_frame = 0
self.__loaded = -1
self.__timestamp = 0
while self.__physical_frame < frame:
self._get_next() # Advance to the requested frame
return data
def seek(self, frame: int) -> None:
if self._n_frames is None:
self.n_frames
if not self._seek_check(frame):
return
# Set logical frame to requested position
self.__logical_frame = frame
if frame < self.__frame:
self.__frame = 0
self._decoder.rewind()
self.info["timestamp"] = 0
last_frame = self.__frame
while self.__frame < frame:
self._get_next()
self.__frame += 1
if self._n_frames is not None and self._n_frames < frame:
self.seek(last_frame)
msg = "no more images in JPEG XL file"
raise EOFError(msg)
self.tile = [ImageFile._Tile("raw", (0, 0) + self.size, 0, self.mode)]
def load(self) -> Image.core.PixelAccess | None:
if self.__loaded != self.__logical_frame:
self._seek(self.__logical_frame)
if self.tile:
data = self._get_next()
data, self.info["timestamp"], self.info["duration"] = self._get_next()
self.__loaded = self.__logical_frame
# Set tile
if self.fp and self._exclusive_fp:
self.fp.close()
# this is horribly memory inefficient
# you need probably 2*(raw image plane) bytes of memory
self.fp = BytesIO(data)
self.tile = [ImageFile._Tile("raw", (0, 0) + self.size, 0, self.mode)]
return super().load()
@ -132,7 +114,7 @@ class JpegXlImageFile(ImageFile.ImageFile):
pass
def tell(self) -> int:
return self.__logical_frame
return self.__frame
Image.register_open(JpegXlImageFile.format, JpegXlImageFile, _accept)

View File

@ -18,10 +18,10 @@ void
_jxl_get_pixel_format(JxlPixelFormat *pf, const JxlBasicInfo *bi) {
pf->num_channels = bi->num_color_channels + bi->num_extra_channels;
if (bi->exponent_bits_per_sample > 0 || bi->alpha_exponent_bits > 0) {
pf->data_type = JXL_TYPE_FLOAT; // not yet supported
if (bi->exponent_bits_per_sample) {
pf->data_type = JXL_TYPE_FLOAT;
} else if (bi->bits_per_sample > 8) {
pf->data_type = JXL_TYPE_UINT16; // not yet supported
pf->data_type = JXL_TYPE_UINT16;
} else {
pf->data_type = JXL_TYPE_UINT8;
}
@ -35,41 +35,36 @@ _jxl_get_pixel_format(JxlPixelFormat *pf, const JxlBasicInfo *bi) {
// TODO: floating point mode
char *
_jxl_get_mode(const JxlBasicInfo *bi) {
// 16-bit single channel images are supported
if (bi->bits_per_sample == 16 && bi->num_color_channels == 1 &&
bi->alpha_bits == 0 && !bi->alpha_premultiplied) {
return "I;16";
}
// PIL doesn't support high bit depth images
// it will throw an exception but that's for your own good
// you wouldn't want to see distorted image
if (bi->bits_per_sample != 8) {
return NULL;
}
// image has transparency
if (bi->alpha_bits > 0) {
if (bi->num_color_channels == 3) {
if (bi->alpha_premultiplied) {
return "RGBa";
}
return "RGBA";
}
if (bi->num_color_channels == 1) {
if (bi->alpha_premultiplied) {
return "La";
}
return "LA";
if (bi->num_color_channels == 1 && !bi->alpha_bits) {
if (bi->bits_per_sample == 16) {
return "I;16";
}
}
// image has no transparency
if (bi->num_color_channels == 3) {
return "RGB";
}
if (bi->num_color_channels == 1) {
return "L";
if (bi->bits_per_sample == 8) {
// image has transparency
if (bi->alpha_bits) {
if (bi->num_color_channels == 3) {
if (bi->alpha_premultiplied) {
return "RGBa";
}
return "RGBA";
}
if (bi->num_color_channels == 1) {
if (bi->alpha_premultiplied) {
return "La";
}
return "LA";
}
} else {
// image has no transparency
if (bi->num_color_channels == 3) {
return "RGB";
}
if (bi->num_color_channels == 1) {
return "L";
}
}
}
// could not recognize mode
@ -85,10 +80,10 @@ typedef struct {
Py_ssize_t jxl_data_len; // length of input jxl bitstream
uint8_t *outbuf;
Py_ssize_t outbuf_len;
size_t outbuf_len;
uint8_t *jxl_icc;
Py_ssize_t jxl_icc_len;
size_t jxl_icc_len;
uint8_t *jxl_exif;
Py_ssize_t jxl_exif_len;
uint8_t *jxl_xmp;
@ -262,26 +257,16 @@ decoder_loop_skip_process:
goto end;
}
// got basic info
if (decp->status == JXL_DEC_BASIC_INFO) {
decp->status = JxlDecoderGetBasicInfo(decp->decoder, &decp->basic_info);
_JXL_CHECK("JxlDecoderGetBasicInfo");
_jxl_get_pixel_format(&decp->pixel_format, &decp->basic_info);
if (decp->pixel_format.data_type != JXL_TYPE_UINT8 &&
decp->pixel_format.data_type != JXL_TYPE_UINT16) {
// only 8 bit integer value images are supported for now
PyErr_SetString(
PyExc_NotImplementedError, "unsupported pixel data type"
);
goto end_with_custom_error;
}
decp->mode = _jxl_get_mode(&decp->basic_info);
continue;
}
// got color encoding
if (decp->status == JXL_DEC_COLOR_ENCODING) {
decp->status = JxlDecoderGetICCProfileSize(
decp->decoder, JXL_COLOR_PROFILE_TARGET_DATA, &decp->jxl_icc_len
@ -317,7 +302,7 @@ decoder_loop_skip_process:
continue;
}
size_t cur_compr_box_size;
uint64_t cur_compr_box_size;
decp->status = JxlDecoderGetBoxSizeRaw(decp->decoder, &cur_compr_box_size);
_JXL_CHECK("JxlDecoderGetBoxSizeRaw");
@ -361,12 +346,6 @@ decoder_loop_skip_process:
} while (decp->status != JXL_DEC_FRAME);
// couldn't determine Image mode or it is unsupported
if (!decp->mode) {
PyErr_SetString(PyExc_NotImplementedError, "only 8-bit images are supported");
goto end_with_custom_error;
}
return (PyObject *)decp;
// on success we should never reach here
@ -396,16 +375,21 @@ end_with_custom_error:
PyObject *
_jxl_decoder_get_info(PyObject *self) {
JpegXlDecoderObject *decp = (JpegXlDecoderObject *)self;
JxlFrameHeader fhdr = {};
if (JxlDecoderGetFrameHeader(decp->decoder, &fhdr) != JXL_DEC_SUCCESS) {
PyErr_SetString(PyExc_OSError, "Error determining duration");
return NULL;
}
return Py_BuildValue(
"(II)sOIII",
"(II)sOIIII",
decp->basic_info.xsize,
decp->basic_info.ysize,
decp->mode,
decp->basic_info.have_animation ? Py_True : Py_False,
decp->basic_info.animation.tps_numerator,
decp->basic_info.animation.tps_denominator,
decp->basic_info.animation.num_loops
decp->basic_info.animation.num_loops,
fhdr.duration
);
}
@ -426,11 +410,6 @@ _jxl_decoder_get_next(PyObject *self) {
while (decp->status != JXL_DEC_NEED_IMAGE_OUT_BUFFER) {
decp->status = JxlDecoderProcessInput(decp->decoder);
// every frame was decoded successfully
if (decp->status == JXL_DEC_SUCCESS) {
Py_RETURN_NONE;
}
// this should only occur after rewind
if (decp->status == JXL_DEC_NEED_MORE_INPUT) {
_jxl_decoder_set_input((PyObject *)decp);
@ -454,7 +433,7 @@ _jxl_decoder_get_next(PyObject *self) {
uint8_t *_new_outbuf = realloc(decp->outbuf, decp->outbuf_len);
if (!_new_outbuf) {
PyErr_SetString(PyExc_OSError, "failed to allocate outbuf");
goto end_with_custom_error;
return NULL;
}
decp->outbuf = _new_outbuf;
}
@ -469,7 +448,7 @@ _jxl_decoder_get_next(PyObject *self) {
if (decp->status != JXL_DEC_FULL_IMAGE) {
PyErr_SetString(PyExc_OSError, "failed to read next frame");
goto end_with_custom_error;
return NULL;
}
bytes = PyBytes_FromStringAndSize((char *)(decp->outbuf), decp->outbuf_len);
@ -493,13 +472,6 @@ end:
decp->status
);
PyErr_SetString(PyExc_OSError, err_msg);
end_with_custom_error:
// no need to deallocate anything here
// user can just ignore error
return NULL;
}
PyObject *