Merge pull request #8156 from radarhere/type_hint_imagefilter

This commit is contained in:
Hugo van Kemenade 2024-06-23 07:27:45 -06:00 committed by GitHub
commit 6879956d17
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 71 additions and 39 deletions

View File

@ -354,10 +354,10 @@ class TestColorLut3DCoreAPI:
class TestColorLut3DFilter: class TestColorLut3DFilter:
def test_wrong_args(self) -> None: def test_wrong_args(self) -> None:
with pytest.raises(ValueError, match="should be either an integer"): with pytest.raises(ValueError, match="should be either an integer"):
ImageFilter.Color3DLUT("small", [1]) ImageFilter.Color3DLUT("small", [1]) # type: ignore[arg-type]
with pytest.raises(ValueError, match="should be either an integer"): with pytest.raises(ValueError, match="should be either an integer"):
ImageFilter.Color3DLUT((11, 11), [1]) ImageFilter.Color3DLUT((11, 11), [1]) # type: ignore[arg-type]
with pytest.raises(ValueError, match=r"in \[2, 65\] range"): with pytest.raises(ValueError, match=r"in \[2, 65\] range"):
ImageFilter.Color3DLUT((11, 11, 1), [1]) ImageFilter.Color3DLUT((11, 11, 1), [1])

View File

@ -137,7 +137,7 @@ def test_builtinfilter_p() -> None:
builtin_filter = ImageFilter.BuiltinFilter() builtin_filter = ImageFilter.BuiltinFilter()
with pytest.raises(ValueError): with pytest.raises(ValueError):
builtin_filter.filter(hopper("P")) builtin_filter.filter(hopper("P").im)
def test_kernel_not_enough_coefficients() -> None: def test_kernel_not_enough_coefficients() -> None:

View File

@ -1,17 +1,17 @@
from __future__ import annotations from __future__ import annotations
import warnings import warnings
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING
import pytest import pytest
from PIL import Image from PIL import Image, _typing
from .helper import assert_deep_equal, assert_image, hopper, skip_unless_feature from .helper import assert_deep_equal, assert_image, hopper, skip_unless_feature
if TYPE_CHECKING: if TYPE_CHECKING:
import numpy import numpy
import numpy.typing import numpy.typing as npt
else: else:
numpy = pytest.importorskip("numpy", reason="NumPy not installed") numpy = pytest.importorskip("numpy", reason="NumPy not installed")
@ -19,9 +19,7 @@ TEST_IMAGE_SIZE = (10, 10)
def test_numpy_to_image() -> None: def test_numpy_to_image() -> None:
def to_image( def to_image(dtype: npt.DTypeLike, bands: int = 1, boolean: int = 0) -> Image.Image:
dtype: numpy.typing.DTypeLike, bands: int = 1, boolean: int = 0
) -> Image.Image:
if bands == 1: if bands == 1:
if boolean: if boolean:
data = [0, 255] * 50 data = [0, 255] * 50
@ -106,9 +104,7 @@ def test_1d_array() -> None:
assert_image(Image.fromarray(a), "L", (1, 5)) assert_image(Image.fromarray(a), "L", (1, 5))
def _test_img_equals_nparray( def _test_img_equals_nparray(img: Image.Image, np_img: _typing.NumpyArray) -> None:
img: Image.Image, np_img: numpy.typing.NDArray[Any]
) -> None:
assert len(np_img.shape) >= 2 assert len(np_img.shape) >= 2
np_size = np_img.shape[1], np_img.shape[0] np_size = np_img.shape[1], np_img.shape[0]
assert img.size == np_size assert img.size == np_size
@ -166,7 +162,7 @@ def test_save_tiff_uint16() -> None:
("HSV", numpy.uint8), ("HSV", numpy.uint8),
), ),
) )
def test_to_array(mode: str, dtype: numpy.typing.DTypeLike) -> None: def test_to_array(mode: str, dtype: npt.DTypeLike) -> None:
img = hopper(mode) img = hopper(mode)
# Resize to non-square # Resize to non-square
@ -216,7 +212,7 @@ def test_putdata() -> None:
numpy.float64, numpy.float64,
), ),
) )
def test_roundtrip_eye(dtype: numpy.typing.DTypeLike) -> None: def test_roundtrip_eye(dtype: npt.DTypeLike) -> None:
arr = numpy.eye(10, dtype=dtype) arr = numpy.eye(10, dtype=dtype)
numpy.testing.assert_array_equal(arr, numpy.array(Image.fromarray(arr))) numpy.testing.assert_array_equal(arr, numpy.array(Image.fromarray(arr)))

View File

@ -33,6 +33,10 @@ Internal Modules
Provides a convenient way to import type hints that are not available Provides a convenient way to import type hints that are not available
on some Python versions. on some Python versions.
.. py:class:: NumpyArray
Typing alias.
.. py:class:: StrOrBytesPath .. py:class:: StrOrBytesPath
Typing alias. Typing alias.

View File

@ -19,12 +19,16 @@ from __future__ import annotations
import abc import abc
import functools import functools
from types import ModuleType from types import ModuleType
from typing import Any, Sequence from typing import TYPE_CHECKING, Any, Callable, Sequence, cast
if TYPE_CHECKING:
from . import _imaging
from ._typing import NumpyArray
class Filter: class Filter:
@abc.abstractmethod @abc.abstractmethod
def filter(self, image): def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:
pass pass
@ -33,7 +37,9 @@ class MultibandFilter(Filter):
class BuiltinFilter(MultibandFilter): class BuiltinFilter(MultibandFilter):
def filter(self, image): filterargs: tuple[Any, ...]
def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:
if image.mode == "P": if image.mode == "P":
msg = "cannot filter palette images" msg = "cannot filter palette images"
raise ValueError(msg) raise ValueError(msg)
@ -91,7 +97,7 @@ class RankFilter(Filter):
self.size = size self.size = size
self.rank = rank self.rank = rank
def filter(self, image): def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:
if image.mode == "P": if image.mode == "P":
msg = "cannot filter palette images" msg = "cannot filter palette images"
raise ValueError(msg) raise ValueError(msg)
@ -158,7 +164,7 @@ class ModeFilter(Filter):
def __init__(self, size: int = 3) -> None: def __init__(self, size: int = 3) -> None:
self.size = size self.size = size
def filter(self, image): def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:
return image.modefilter(self.size) return image.modefilter(self.size)
@ -176,9 +182,9 @@ class GaussianBlur(MultibandFilter):
def __init__(self, radius: float | Sequence[float] = 2) -> None: def __init__(self, radius: float | Sequence[float] = 2) -> None:
self.radius = radius self.radius = radius
def filter(self, image): def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:
xy = self.radius xy = self.radius
if not isinstance(xy, (tuple, list)): if isinstance(xy, (int, float)):
xy = (xy, xy) xy = (xy, xy)
if xy == (0, 0): if xy == (0, 0):
return image.copy() return image.copy()
@ -208,9 +214,9 @@ class BoxBlur(MultibandFilter):
raise ValueError(msg) raise ValueError(msg)
self.radius = radius self.radius = radius
def filter(self, image): def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:
xy = self.radius xy = self.radius
if not isinstance(xy, (tuple, list)): if isinstance(xy, (int, float)):
xy = (xy, xy) xy = (xy, xy)
if xy == (0, 0): if xy == (0, 0):
return image.copy() return image.copy()
@ -241,7 +247,7 @@ class UnsharpMask(MultibandFilter):
self.percent = percent self.percent = percent
self.threshold = threshold self.threshold = threshold
def filter(self, image): def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:
return image.unsharp_mask(self.radius, self.percent, self.threshold) return image.unsharp_mask(self.radius, self.percent, self.threshold)
@ -387,8 +393,13 @@ class Color3DLUT(MultibandFilter):
name = "Color 3D LUT" name = "Color 3D LUT"
def __init__( def __init__(
self, size, table, channels: int = 3, target_mode: str | None = None, **kwargs self,
): size: int | tuple[int, int, int],
table: Sequence[float] | Sequence[Sequence[int]] | NumpyArray,
channels: int = 3,
target_mode: str | None = None,
**kwargs: bool,
) -> None:
if channels not in (3, 4): if channels not in (3, 4):
msg = "Only 3 or 4 output channels are supported" msg = "Only 3 or 4 output channels are supported"
raise ValueError(msg) raise ValueError(msg)
@ -410,15 +421,16 @@ class Color3DLUT(MultibandFilter):
pass pass
if numpy and isinstance(table, numpy.ndarray): if numpy and isinstance(table, numpy.ndarray):
numpy_table: NumpyArray = table
if copy_table: if copy_table:
table = table.copy() numpy_table = numpy_table.copy()
if table.shape in [ if numpy_table.shape in [
(items * channels,), (items * channels,),
(items, channels), (items, channels),
(size[2], size[1], size[0], channels), (size[2], size[1], size[0], channels),
]: ]:
table = table.reshape(items * channels) table = numpy_table.reshape(items * channels)
else: else:
wrong_size = True wrong_size = True
@ -428,7 +440,8 @@ class Color3DLUT(MultibandFilter):
# Convert to a flat list # Convert to a flat list
if table and isinstance(table[0], (list, tuple)): if table and isinstance(table[0], (list, tuple)):
table, raw_table = [], table raw_table = cast(Sequence[Sequence[int]], table)
flat_table: list[int] = []
for pixel in raw_table: for pixel in raw_table:
if len(pixel) != channels: if len(pixel) != channels:
msg = ( msg = (
@ -436,7 +449,8 @@ class Color3DLUT(MultibandFilter):
f"have a length of {channels}." f"have a length of {channels}."
) )
raise ValueError(msg) raise ValueError(msg)
table.extend(pixel) flat_table.extend(pixel)
table = flat_table
if wrong_size or len(table) != items * channels: if wrong_size or len(table) != items * channels:
msg = ( msg = (
@ -449,7 +463,7 @@ class Color3DLUT(MultibandFilter):
self.table = table self.table = table
@staticmethod @staticmethod
def _check_size(size: Any) -> list[int]: def _check_size(size: Any) -> tuple[int, int, int]:
try: try:
_, _, _ = size _, _, _ = size
except ValueError as e: except ValueError as e:
@ -457,7 +471,7 @@ class Color3DLUT(MultibandFilter):
raise ValueError(msg) from e raise ValueError(msg) from e
except TypeError: except TypeError:
size = (size, size, size) size = (size, size, size)
size = [int(x) for x in size] size = tuple(int(x) for x in size)
for size_1d in size: for size_1d in size:
if not 2 <= size_1d <= 65: if not 2 <= size_1d <= 65:
msg = "Size should be in [2, 65] range." msg = "Size should be in [2, 65] range."
@ -465,7 +479,13 @@ class Color3DLUT(MultibandFilter):
return size return size
@classmethod @classmethod
def generate(cls, size, callback, channels=3, target_mode=None): def generate(
cls,
size: int | tuple[int, int, int],
callback: Callable[[float, float, float], tuple[float, ...]],
channels: int = 3,
target_mode: str | None = None,
) -> Color3DLUT:
"""Generates new LUT using provided callback. """Generates new LUT using provided callback.
:param size: Size of the table. Passed to the constructor. :param size: Size of the table. Passed to the constructor.
@ -482,7 +502,7 @@ class Color3DLUT(MultibandFilter):
msg = "Only 3 or 4 output channels are supported" msg = "Only 3 or 4 output channels are supported"
raise ValueError(msg) raise ValueError(msg)
table = [0] * (size_1d * size_2d * size_3d * channels) table: list[float] = [0] * (size_1d * size_2d * size_3d * channels)
idx_out = 0 idx_out = 0
for b in range(size_3d): for b in range(size_3d):
for g in range(size_2d): for g in range(size_2d):
@ -500,7 +520,13 @@ class Color3DLUT(MultibandFilter):
_copy_table=False, _copy_table=False,
) )
def transform(self, callback, with_normals=False, channels=None, target_mode=None): def transform(
self,
callback: Callable[..., tuple[float, ...]],
with_normals: bool = False,
channels: int | None = None,
target_mode: str | None = None,
) -> Color3DLUT:
"""Transforms the table values using provided callback and returns """Transforms the table values using provided callback and returns
a new LUT with altered values. a new LUT with altered values.
@ -564,7 +590,7 @@ class Color3DLUT(MultibandFilter):
r.append(f"target_mode={self.mode}") r.append(f"target_mode={self.mode}")
return "<{}>".format(" ".join(r)) return "<{}>".format(" ".join(r))
def filter(self, image): def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:
from . import Image from . import Image
return image.color_lut_3d( return image.color_lut_3d(

View File

@ -2,7 +2,14 @@ from __future__ import annotations
import os import os
import sys import sys
from typing import Protocol, Sequence, TypeVar, Union from typing import Any, Protocol, Sequence, TypeVar, Union
try:
import numpy.typing as npt
NumpyArray = npt.NDArray[Any]
except ImportError:
pass
if sys.version_info >= (3, 10): if sys.version_info >= (3, 10):
from typing import TypeGuard from typing import TypeGuard
@ -10,7 +17,6 @@ else:
try: try:
from typing_extensions import TypeGuard from typing_extensions import TypeGuard
except ImportError: except ImportError:
from typing import Any
class TypeGuard: # type: ignore[no-redef] class TypeGuard: # type: ignore[no-redef]
def __class_getitem__(cls, item: Any) -> type[bool]: def __class_getitem__(cls, item: Any) -> type[bool]: