From 324e548e525f2e7a3e31bb9dde10086fc202e8b5 Mon Sep 17 00:00:00 2001 From: Andrew Murray Date: Fri, 21 Jun 2024 20:41:22 +1000 Subject: [PATCH] Added type hints to ImageFilter --- Tests/test_color_lut.py | 4 +- Tests/test_image_filter.py | 2 +- Tests/test_numpy.py | 18 +++----- docs/reference/internal_modules.rst | 4 ++ src/PIL/ImageFilter.py | 72 ++++++++++++++++++++--------- src/PIL/_typing.py | 10 +++- 6 files changed, 71 insertions(+), 39 deletions(-) diff --git a/Tests/test_color_lut.py b/Tests/test_color_lut.py index c8886a779..00c8995b0 100644 --- a/Tests/test_color_lut.py +++ b/Tests/test_color_lut.py @@ -354,10 +354,10 @@ class TestColorLut3DCoreAPI: class TestColorLut3DFilter: def test_wrong_args(self) -> None: with pytest.raises(ValueError, match="should be either an integer"): - ImageFilter.Color3DLUT("small", [1]) + ImageFilter.Color3DLUT("small", [1]) # type: ignore[arg-type] with pytest.raises(ValueError, match="should be either an integer"): - ImageFilter.Color3DLUT((11, 11), [1]) + ImageFilter.Color3DLUT((11, 11), [1]) # type: ignore[arg-type] with pytest.raises(ValueError, match=r"in \[2, 65\] range"): ImageFilter.Color3DLUT((11, 11, 1), [1]) diff --git a/Tests/test_image_filter.py b/Tests/test_image_filter.py index 1f0644471..412ab44c3 100644 --- a/Tests/test_image_filter.py +++ b/Tests/test_image_filter.py @@ -137,7 +137,7 @@ def test_builtinfilter_p() -> None: builtin_filter = ImageFilter.BuiltinFilter() with pytest.raises(ValueError): - builtin_filter.filter(hopper("P")) + builtin_filter.filter(hopper("P").im) def test_kernel_not_enough_coefficients() -> None: diff --git a/Tests/test_numpy.py b/Tests/test_numpy.py index 36cdb3682..32d2cf985 100644 --- a/Tests/test_numpy.py +++ b/Tests/test_numpy.py @@ -1,17 +1,17 @@ from __future__ import annotations import warnings -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING import pytest -from PIL import Image +from PIL import Image, _typing from .helper import assert_deep_equal, assert_image, hopper, skip_unless_feature if TYPE_CHECKING: import numpy - import numpy.typing + import numpy.typing as npt else: numpy = pytest.importorskip("numpy", reason="NumPy not installed") @@ -19,9 +19,7 @@ TEST_IMAGE_SIZE = (10, 10) def test_numpy_to_image() -> None: - def to_image( - dtype: numpy.typing.DTypeLike, bands: int = 1, boolean: int = 0 - ) -> Image.Image: + def to_image(dtype: npt.DTypeLike, bands: int = 1, boolean: int = 0) -> Image.Image: if bands == 1: if boolean: data = [0, 255] * 50 @@ -106,9 +104,7 @@ def test_1d_array() -> None: assert_image(Image.fromarray(a), "L", (1, 5)) -def _test_img_equals_nparray( - img: Image.Image, np_img: numpy.typing.NDArray[Any] -) -> None: +def _test_img_equals_nparray(img: Image.Image, np_img: _typing.NumpyArray) -> None: assert len(np_img.shape) >= 2 np_size = np_img.shape[1], np_img.shape[0] assert img.size == np_size @@ -166,7 +162,7 @@ def test_save_tiff_uint16() -> None: ("HSV", numpy.uint8), ), ) -def test_to_array(mode: str, dtype: numpy.typing.DTypeLike) -> None: +def test_to_array(mode: str, dtype: npt.DTypeLike) -> None: img = hopper(mode) # Resize to non-square @@ -216,7 +212,7 @@ def test_putdata() -> None: numpy.float64, ), ) -def test_roundtrip_eye(dtype: numpy.typing.DTypeLike) -> None: +def test_roundtrip_eye(dtype: npt.DTypeLike) -> None: arr = numpy.eye(10, dtype=dtype) numpy.testing.assert_array_equal(arr, numpy.array(Image.fromarray(arr))) diff --git a/docs/reference/internal_modules.rst b/docs/reference/internal_modules.rst index 899e4966f..e4cb17c4d 100644 --- a/docs/reference/internal_modules.rst +++ b/docs/reference/internal_modules.rst @@ -33,6 +33,10 @@ Internal Modules Provides a convenient way to import type hints that are not available on some Python versions. +.. py:class:: NumpyArray + + Typing alias. + .. py:class:: StrOrBytesPath Typing alias. diff --git a/src/PIL/ImageFilter.py b/src/PIL/ImageFilter.py index 02288e135..e18b4a446 100644 --- a/src/PIL/ImageFilter.py +++ b/src/PIL/ImageFilter.py @@ -19,12 +19,16 @@ from __future__ import annotations import abc import functools from types import ModuleType -from typing import Any, Sequence +from typing import TYPE_CHECKING, Any, Callable, Sequence, cast + +if TYPE_CHECKING: + from . import _imaging + from ._typing import NumpyArray class Filter: @abc.abstractmethod - def filter(self, image): + def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore: pass @@ -33,7 +37,9 @@ class MultibandFilter(Filter): class BuiltinFilter(MultibandFilter): - def filter(self, image): + filterargs: tuple[Any, ...] + + def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore: if image.mode == "P": msg = "cannot filter palette images" raise ValueError(msg) @@ -91,7 +97,7 @@ class RankFilter(Filter): self.size = size self.rank = rank - def filter(self, image): + def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore: if image.mode == "P": msg = "cannot filter palette images" raise ValueError(msg) @@ -158,7 +164,7 @@ class ModeFilter(Filter): def __init__(self, size: int = 3) -> None: self.size = size - def filter(self, image): + def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore: return image.modefilter(self.size) @@ -176,9 +182,9 @@ class GaussianBlur(MultibandFilter): def __init__(self, radius: float | Sequence[float] = 2) -> None: self.radius = radius - def filter(self, image): + def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore: xy = self.radius - if not isinstance(xy, (tuple, list)): + if isinstance(xy, (int, float)): xy = (xy, xy) if xy == (0, 0): return image.copy() @@ -208,9 +214,9 @@ class BoxBlur(MultibandFilter): raise ValueError(msg) self.radius = radius - def filter(self, image): + def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore: xy = self.radius - if not isinstance(xy, (tuple, list)): + if isinstance(xy, (int, float)): xy = (xy, xy) if xy == (0, 0): return image.copy() @@ -241,7 +247,7 @@ class UnsharpMask(MultibandFilter): self.percent = percent self.threshold = threshold - def filter(self, image): + def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore: return image.unsharp_mask(self.radius, self.percent, self.threshold) @@ -387,8 +393,13 @@ class Color3DLUT(MultibandFilter): name = "Color 3D LUT" def __init__( - self, size, table, channels: int = 3, target_mode: str | None = None, **kwargs - ): + self, + size: int | tuple[int, int, int], + table: Sequence[float] | Sequence[Sequence[int]] | NumpyArray, + channels: int = 3, + target_mode: str | None = None, + **kwargs: bool, + ) -> None: if channels not in (3, 4): msg = "Only 3 or 4 output channels are supported" raise ValueError(msg) @@ -410,15 +421,16 @@ class Color3DLUT(MultibandFilter): pass if numpy and isinstance(table, numpy.ndarray): + numpy_table: NumpyArray = table if copy_table: - table = table.copy() + numpy_table = numpy_table.copy() - if table.shape in [ + if numpy_table.shape in [ (items * channels,), (items, channels), (size[2], size[1], size[0], channels), ]: - table = table.reshape(items * channels) + table = numpy_table.reshape(items * channels) else: wrong_size = True @@ -428,7 +440,8 @@ class Color3DLUT(MultibandFilter): # Convert to a flat list if table and isinstance(table[0], (list, tuple)): - table, raw_table = [], table + raw_table = cast(Sequence[Sequence[int]], table) + flat_table: list[int] = [] for pixel in raw_table: if len(pixel) != channels: msg = ( @@ -436,7 +449,8 @@ class Color3DLUT(MultibandFilter): f"have a length of {channels}." ) raise ValueError(msg) - table.extend(pixel) + flat_table.extend(pixel) + table = flat_table if wrong_size or len(table) != items * channels: msg = ( @@ -449,7 +463,7 @@ class Color3DLUT(MultibandFilter): self.table = table @staticmethod - def _check_size(size: Any) -> list[int]: + def _check_size(size: Any) -> tuple[int, int, int]: try: _, _, _ = size except ValueError as e: @@ -457,7 +471,7 @@ class Color3DLUT(MultibandFilter): raise ValueError(msg) from e except TypeError: size = (size, size, size) - size = [int(x) for x in size] + size = tuple(int(x) for x in size) for size_1d in size: if not 2 <= size_1d <= 65: msg = "Size should be in [2, 65] range." @@ -465,7 +479,13 @@ class Color3DLUT(MultibandFilter): return size @classmethod - def generate(cls, size, callback, channels=3, target_mode=None): + def generate( + cls, + size: int | tuple[int, int, int], + callback: Callable[[float, float, float], tuple[float, ...]], + channels: int = 3, + target_mode: str | None = None, + ) -> Color3DLUT: """Generates new LUT using provided callback. :param size: Size of the table. Passed to the constructor. @@ -482,7 +502,7 @@ class Color3DLUT(MultibandFilter): msg = "Only 3 or 4 output channels are supported" raise ValueError(msg) - table = [0] * (size_1d * size_2d * size_3d * channels) + table: list[float] = [0] * (size_1d * size_2d * size_3d * channels) idx_out = 0 for b in range(size_3d): for g in range(size_2d): @@ -500,7 +520,13 @@ class Color3DLUT(MultibandFilter): _copy_table=False, ) - def transform(self, callback, with_normals=False, channels=None, target_mode=None): + def transform( + self, + callback: Callable[..., tuple[float, ...]], + with_normals: bool = False, + channels: int | None = None, + target_mode: str | None = None, + ) -> Color3DLUT: """Transforms the table values using provided callback and returns a new LUT with altered values. @@ -564,7 +590,7 @@ class Color3DLUT(MultibandFilter): r.append(f"target_mode={self.mode}") return "<{}>".format(" ".join(r)) - def filter(self, image): + def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore: from . import Image return image.color_lut_3d( diff --git a/src/PIL/_typing.py b/src/PIL/_typing.py index 7075e8672..09ece18fa 100644 --- a/src/PIL/_typing.py +++ b/src/PIL/_typing.py @@ -2,7 +2,14 @@ from __future__ import annotations import os import sys -from typing import Protocol, Sequence, TypeVar, Union +from typing import Any, Protocol, Sequence, TypeVar, Union + +try: + import numpy.typing as npt + + NumpyArray = npt.NDArray[Any] +except ImportError: + pass if sys.version_info >= (3, 10): from typing import TypeGuard @@ -10,7 +17,6 @@ else: try: from typing_extensions import TypeGuard except ImportError: - from typing import Any class TypeGuard: # type: ignore[no-redef] def __class_getitem__(cls, item: Any) -> type[bool]: