diff --git a/Tests/test_file_spider.py b/Tests/test_file_spider.py index 3b3c3b4a5..03494523b 100644 --- a/Tests/test_file_spider.py +++ b/Tests/test_file_spider.py @@ -14,6 +14,10 @@ from .helper import assert_image_equal, hopper, is_pypy TEST_FILE = "Tests/images/hopper.spider" +def teardown_module() -> None: + del Image.EXTENSION[".spider"] + + def test_sanity() -> None: with Image.open(TEST_FILE) as im: im.load() diff --git a/Tests/test_image.py b/Tests/test_image.py index afc6e8e16..c064939e8 100644 --- a/Tests/test_image.py +++ b/Tests/test_image.py @@ -466,6 +466,9 @@ class TestImage: # Assert assert Image._initialized == 2 + for extension in Image.EXTENSION: + assert extension in Image._EXTENSION_PLUGIN + def test_registered_extensions(self) -> None: # Arrange # Open an image to trigger plugin registration diff --git a/src/PIL/Image.py b/src/PIL/Image.py index eb5616117..8a28b56bd 100644 --- a/src/PIL/Image.py +++ b/src/PIL/Image.py @@ -323,6 +323,108 @@ def getmodebands(mode: str) -> int: _initialized = 0 +# Mapping from file extension to plugin module name for lazy importing +_EXTENSION_PLUGIN: dict[str, str] = { + # Common formats (preinit) + ".bmp": "BmpImagePlugin", + ".dib": "BmpImagePlugin", + ".gif": "GifImagePlugin", + ".jfif": "JpegImagePlugin", + ".jpe": "JpegImagePlugin", + ".jpg": "JpegImagePlugin", + ".jpeg": "JpegImagePlugin", + ".pbm": "PpmImagePlugin", + ".pgm": "PpmImagePlugin", + ".pnm": "PpmImagePlugin", + ".ppm": "PpmImagePlugin", + ".pfm": "PpmImagePlugin", + ".png": "PngImagePlugin", + ".apng": "PngImagePlugin", + # Less common formats (init) + ".avif": "AvifImagePlugin", + ".avifs": "AvifImagePlugin", + ".blp": "BlpImagePlugin", + ".bufr": "BufrStubImagePlugin", + ".cur": "CurImagePlugin", + ".dcx": "DcxImagePlugin", + ".dds": "DdsImagePlugin", + ".ps": "EpsImagePlugin", + ".eps": "EpsImagePlugin", + ".fit": "FitsImagePlugin", + ".fits": "FitsImagePlugin", + ".fli": "FliImagePlugin", + ".flc": "FliImagePlugin", + ".fpx": "FpxImagePlugin", + ".ftc": "FtexImagePlugin", + ".ftu": "FtexImagePlugin", + ".gbr": "GbrImagePlugin", + ".grib": "GribStubImagePlugin", + ".h5": "Hdf5StubImagePlugin", + ".hdf": "Hdf5StubImagePlugin", + ".icns": "IcnsImagePlugin", + ".ico": "IcoImagePlugin", + ".im": "ImImagePlugin", + ".iim": "IptcImagePlugin", + ".jp2": "Jpeg2KImagePlugin", + ".j2k": "Jpeg2KImagePlugin", + ".jpc": "Jpeg2KImagePlugin", + ".jpf": "Jpeg2KImagePlugin", + ".jpx": "Jpeg2KImagePlugin", + ".j2c": "Jpeg2KImagePlugin", + ".mic": "MicImagePlugin", + ".mpg": "MpegImagePlugin", + ".mpeg": "MpegImagePlugin", + ".mpo": "MpoImagePlugin", + ".msp": "MspImagePlugin", + ".palm": "PalmImagePlugin", + ".pcd": "PcdImagePlugin", + ".pcx": "PcxImagePlugin", + ".pdf": "PdfImagePlugin", + ".pxr": "PixarImagePlugin", + ".psd": "PsdImagePlugin", + ".qoi": "QoiImagePlugin", + ".bw": "SgiImagePlugin", + ".rgb": "SgiImagePlugin", + ".rgba": "SgiImagePlugin", + ".sgi": "SgiImagePlugin", + ".ras": "SunImagePlugin", + ".tga": "TgaImagePlugin", + ".icb": "TgaImagePlugin", + ".vda": "TgaImagePlugin", + ".vst": "TgaImagePlugin", + ".tif": "TiffImagePlugin", + ".tiff": "TiffImagePlugin", + ".webp": "WebPImagePlugin", + ".wmf": "WmfImagePlugin", + ".emf": "WmfImagePlugin", + ".xbm": "XbmImagePlugin", + ".xpm": "XpmImagePlugin", +} + + +def _import_plugin_for_extension(ext: str | bytes) -> bool: + """Import only the plugin needed for a specific file extension.""" + if not ext: + return False + + if isinstance(ext, bytes): + ext = ext.decode() + ext = ext.lower() + if ext in EXTENSION: + return True + + plugin = _EXTENSION_PLUGIN.get(ext) + if plugin is None: + return False + + try: + logger.debug("Importing %s", plugin) + __import__(f"{__spec__.parent}.{plugin}", globals(), locals(), []) + return True + except ImportError as e: + logger.debug("Image: failed to import %s: %s", plugin, e) + return False + def preinit() -> None: """ @@ -382,11 +484,10 @@ def init() -> bool: if _initialized >= 2: return False - parent_name = __name__.rpartition(".")[0] for plugin in _plugins: try: logger.debug("Importing %s", plugin) - __import__(f"{parent_name}.{plugin}", globals(), locals(), []) + __import__(f"{__spec__.parent}.{plugin}", globals(), locals(), []) except ImportError as e: logger.debug("Image: failed to import %s: %s", plugin, e) @@ -2535,12 +2636,20 @@ class Image: # only set the name for metadata purposes filename = os.fspath(fp.name) - preinit() + if format: + preinit() + else: + filename_ext = os.path.splitext(filename)[1].lower() + ext = ( + filename_ext.decode() + if isinstance(filename_ext, bytes) + else filename_ext + ) - filename_ext = os.path.splitext(filename)[1].lower() - ext = filename_ext.decode() if isinstance(filename_ext, bytes) else filename_ext + # Try importing only the plugin for this extension first + if not _import_plugin_for_extension(ext): + preinit() - if not format: if ext not in EXTENSION: init() try: @@ -3524,7 +3633,11 @@ def open( prefix = fp.read(16) - preinit() + # Try to import just the plugin needed for this file extension + # before falling back to preinit() which imports common plugins + ext = os.path.splitext(filename)[1] if filename else "" + if not _import_plugin_for_extension(ext): + preinit() warning_messages: list[str] = [] @@ -3560,14 +3673,19 @@ def open( im = _open_core(fp, filename, prefix, formats) if im is None and formats is ID: - checked_formats = ID.copy() - if init(): - im = _open_core( - fp, - filename, - prefix, - tuple(format for format in formats if format not in checked_formats), - ) + # Try preinit (few common plugins) then init (all plugins) + for loader in (preinit, init): + checked_formats = ID.copy() + loader() + if formats != checked_formats: + im = _open_core( + fp, + filename, + prefix, + tuple(f for f in formats if f not in checked_formats), + ) + if im is not None: + break if im: im._exclusive_fp = exclusive_fp diff --git a/src/PIL/SpiderImagePlugin.py b/src/PIL/SpiderImagePlugin.py index 866292243..848dccda5 100644 --- a/src/PIL/SpiderImagePlugin.py +++ b/src/PIL/SpiderImagePlugin.py @@ -290,9 +290,9 @@ def _save(im: Image.Image, fp: IO[bytes], filename: str | bytes) -> None: def _save_spider(im: Image.Image, fp: IO[bytes], filename: str | bytes) -> None: # get the filename extension and register it with Image - filename_ext = os.path.splitext(filename)[1] - ext = filename_ext.decode() if isinstance(filename_ext, bytes) else filename_ext - Image.register_extension(SpiderImageFile.format, ext) + if filename_ext := os.path.splitext(filename)[1]: + ext = filename_ext.decode() if isinstance(filename_ext, bytes) else filename_ext + Image.register_extension(SpiderImageFile.format, ext) _save(im, fp, filename)