mirror of
https://github.com/python-pillow/Pillow.git
synced 2025-01-25 00:34:14 +03:00
Merge pull request #8156 from radarhere/type_hint_imagefilter
This commit is contained in:
commit
6879956d17
|
@ -354,10 +354,10 @@ class TestColorLut3DCoreAPI:
|
||||||
class TestColorLut3DFilter:
|
class TestColorLut3DFilter:
|
||||||
def test_wrong_args(self) -> None:
|
def test_wrong_args(self) -> None:
|
||||||
with pytest.raises(ValueError, match="should be either an integer"):
|
with pytest.raises(ValueError, match="should be either an integer"):
|
||||||
ImageFilter.Color3DLUT("small", [1])
|
ImageFilter.Color3DLUT("small", [1]) # type: ignore[arg-type]
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="should be either an integer"):
|
with pytest.raises(ValueError, match="should be either an integer"):
|
||||||
ImageFilter.Color3DLUT((11, 11), [1])
|
ImageFilter.Color3DLUT((11, 11), [1]) # type: ignore[arg-type]
|
||||||
|
|
||||||
with pytest.raises(ValueError, match=r"in \[2, 65\] range"):
|
with pytest.raises(ValueError, match=r"in \[2, 65\] range"):
|
||||||
ImageFilter.Color3DLUT((11, 11, 1), [1])
|
ImageFilter.Color3DLUT((11, 11, 1), [1])
|
||||||
|
|
|
@ -137,7 +137,7 @@ def test_builtinfilter_p() -> None:
|
||||||
builtin_filter = ImageFilter.BuiltinFilter()
|
builtin_filter = ImageFilter.BuiltinFilter()
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
builtin_filter.filter(hopper("P"))
|
builtin_filter.filter(hopper("P").im)
|
||||||
|
|
||||||
|
|
||||||
def test_kernel_not_enough_coefficients() -> None:
|
def test_kernel_not_enough_coefficients() -> None:
|
||||||
|
|
|
@ -1,17 +1,17 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from PIL import Image
|
from PIL import Image, _typing
|
||||||
|
|
||||||
from .helper import assert_deep_equal, assert_image, hopper, skip_unless_feature
|
from .helper import assert_deep_equal, assert_image, hopper, skip_unless_feature
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import numpy
|
import numpy
|
||||||
import numpy.typing
|
import numpy.typing as npt
|
||||||
else:
|
else:
|
||||||
numpy = pytest.importorskip("numpy", reason="NumPy not installed")
|
numpy = pytest.importorskip("numpy", reason="NumPy not installed")
|
||||||
|
|
||||||
|
@ -19,9 +19,7 @@ TEST_IMAGE_SIZE = (10, 10)
|
||||||
|
|
||||||
|
|
||||||
def test_numpy_to_image() -> None:
|
def test_numpy_to_image() -> None:
|
||||||
def to_image(
|
def to_image(dtype: npt.DTypeLike, bands: int = 1, boolean: int = 0) -> Image.Image:
|
||||||
dtype: numpy.typing.DTypeLike, bands: int = 1, boolean: int = 0
|
|
||||||
) -> Image.Image:
|
|
||||||
if bands == 1:
|
if bands == 1:
|
||||||
if boolean:
|
if boolean:
|
||||||
data = [0, 255] * 50
|
data = [0, 255] * 50
|
||||||
|
@ -106,9 +104,7 @@ def test_1d_array() -> None:
|
||||||
assert_image(Image.fromarray(a), "L", (1, 5))
|
assert_image(Image.fromarray(a), "L", (1, 5))
|
||||||
|
|
||||||
|
|
||||||
def _test_img_equals_nparray(
|
def _test_img_equals_nparray(img: Image.Image, np_img: _typing.NumpyArray) -> None:
|
||||||
img: Image.Image, np_img: numpy.typing.NDArray[Any]
|
|
||||||
) -> None:
|
|
||||||
assert len(np_img.shape) >= 2
|
assert len(np_img.shape) >= 2
|
||||||
np_size = np_img.shape[1], np_img.shape[0]
|
np_size = np_img.shape[1], np_img.shape[0]
|
||||||
assert img.size == np_size
|
assert img.size == np_size
|
||||||
|
@ -166,7 +162,7 @@ def test_save_tiff_uint16() -> None:
|
||||||
("HSV", numpy.uint8),
|
("HSV", numpy.uint8),
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
def test_to_array(mode: str, dtype: numpy.typing.DTypeLike) -> None:
|
def test_to_array(mode: str, dtype: npt.DTypeLike) -> None:
|
||||||
img = hopper(mode)
|
img = hopper(mode)
|
||||||
|
|
||||||
# Resize to non-square
|
# Resize to non-square
|
||||||
|
@ -216,7 +212,7 @@ def test_putdata() -> None:
|
||||||
numpy.float64,
|
numpy.float64,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
def test_roundtrip_eye(dtype: numpy.typing.DTypeLike) -> None:
|
def test_roundtrip_eye(dtype: npt.DTypeLike) -> None:
|
||||||
arr = numpy.eye(10, dtype=dtype)
|
arr = numpy.eye(10, dtype=dtype)
|
||||||
numpy.testing.assert_array_equal(arr, numpy.array(Image.fromarray(arr)))
|
numpy.testing.assert_array_equal(arr, numpy.array(Image.fromarray(arr)))
|
||||||
|
|
||||||
|
|
|
@ -33,6 +33,10 @@ Internal Modules
|
||||||
Provides a convenient way to import type hints that are not available
|
Provides a convenient way to import type hints that are not available
|
||||||
on some Python versions.
|
on some Python versions.
|
||||||
|
|
||||||
|
.. py:class:: NumpyArray
|
||||||
|
|
||||||
|
Typing alias.
|
||||||
|
|
||||||
.. py:class:: StrOrBytesPath
|
.. py:class:: StrOrBytesPath
|
||||||
|
|
||||||
Typing alias.
|
Typing alias.
|
||||||
|
|
|
@ -19,12 +19,16 @@ from __future__ import annotations
|
||||||
import abc
|
import abc
|
||||||
import functools
|
import functools
|
||||||
from types import ModuleType
|
from types import ModuleType
|
||||||
from typing import Any, Sequence
|
from typing import TYPE_CHECKING, Any, Callable, Sequence, cast
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from . import _imaging
|
||||||
|
from ._typing import NumpyArray
|
||||||
|
|
||||||
|
|
||||||
class Filter:
|
class Filter:
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def filter(self, image):
|
def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@ -33,7 +37,9 @@ class MultibandFilter(Filter):
|
||||||
|
|
||||||
|
|
||||||
class BuiltinFilter(MultibandFilter):
|
class BuiltinFilter(MultibandFilter):
|
||||||
def filter(self, image):
|
filterargs: tuple[Any, ...]
|
||||||
|
|
||||||
|
def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:
|
||||||
if image.mode == "P":
|
if image.mode == "P":
|
||||||
msg = "cannot filter palette images"
|
msg = "cannot filter palette images"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
@ -91,7 +97,7 @@ class RankFilter(Filter):
|
||||||
self.size = size
|
self.size = size
|
||||||
self.rank = rank
|
self.rank = rank
|
||||||
|
|
||||||
def filter(self, image):
|
def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:
|
||||||
if image.mode == "P":
|
if image.mode == "P":
|
||||||
msg = "cannot filter palette images"
|
msg = "cannot filter palette images"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
@ -158,7 +164,7 @@ class ModeFilter(Filter):
|
||||||
def __init__(self, size: int = 3) -> None:
|
def __init__(self, size: int = 3) -> None:
|
||||||
self.size = size
|
self.size = size
|
||||||
|
|
||||||
def filter(self, image):
|
def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:
|
||||||
return image.modefilter(self.size)
|
return image.modefilter(self.size)
|
||||||
|
|
||||||
|
|
||||||
|
@ -176,9 +182,9 @@ class GaussianBlur(MultibandFilter):
|
||||||
def __init__(self, radius: float | Sequence[float] = 2) -> None:
|
def __init__(self, radius: float | Sequence[float] = 2) -> None:
|
||||||
self.radius = radius
|
self.radius = radius
|
||||||
|
|
||||||
def filter(self, image):
|
def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:
|
||||||
xy = self.radius
|
xy = self.radius
|
||||||
if not isinstance(xy, (tuple, list)):
|
if isinstance(xy, (int, float)):
|
||||||
xy = (xy, xy)
|
xy = (xy, xy)
|
||||||
if xy == (0, 0):
|
if xy == (0, 0):
|
||||||
return image.copy()
|
return image.copy()
|
||||||
|
@ -208,9 +214,9 @@ class BoxBlur(MultibandFilter):
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
self.radius = radius
|
self.radius = radius
|
||||||
|
|
||||||
def filter(self, image):
|
def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:
|
||||||
xy = self.radius
|
xy = self.radius
|
||||||
if not isinstance(xy, (tuple, list)):
|
if isinstance(xy, (int, float)):
|
||||||
xy = (xy, xy)
|
xy = (xy, xy)
|
||||||
if xy == (0, 0):
|
if xy == (0, 0):
|
||||||
return image.copy()
|
return image.copy()
|
||||||
|
@ -241,7 +247,7 @@ class UnsharpMask(MultibandFilter):
|
||||||
self.percent = percent
|
self.percent = percent
|
||||||
self.threshold = threshold
|
self.threshold = threshold
|
||||||
|
|
||||||
def filter(self, image):
|
def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:
|
||||||
return image.unsharp_mask(self.radius, self.percent, self.threshold)
|
return image.unsharp_mask(self.radius, self.percent, self.threshold)
|
||||||
|
|
||||||
|
|
||||||
|
@ -387,8 +393,13 @@ class Color3DLUT(MultibandFilter):
|
||||||
name = "Color 3D LUT"
|
name = "Color 3D LUT"
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, size, table, channels: int = 3, target_mode: str | None = None, **kwargs
|
self,
|
||||||
):
|
size: int | tuple[int, int, int],
|
||||||
|
table: Sequence[float] | Sequence[Sequence[int]] | NumpyArray,
|
||||||
|
channels: int = 3,
|
||||||
|
target_mode: str | None = None,
|
||||||
|
**kwargs: bool,
|
||||||
|
) -> None:
|
||||||
if channels not in (3, 4):
|
if channels not in (3, 4):
|
||||||
msg = "Only 3 or 4 output channels are supported"
|
msg = "Only 3 or 4 output channels are supported"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
@ -410,15 +421,16 @@ class Color3DLUT(MultibandFilter):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
if numpy and isinstance(table, numpy.ndarray):
|
if numpy and isinstance(table, numpy.ndarray):
|
||||||
|
numpy_table: NumpyArray = table
|
||||||
if copy_table:
|
if copy_table:
|
||||||
table = table.copy()
|
numpy_table = numpy_table.copy()
|
||||||
|
|
||||||
if table.shape in [
|
if numpy_table.shape in [
|
||||||
(items * channels,),
|
(items * channels,),
|
||||||
(items, channels),
|
(items, channels),
|
||||||
(size[2], size[1], size[0], channels),
|
(size[2], size[1], size[0], channels),
|
||||||
]:
|
]:
|
||||||
table = table.reshape(items * channels)
|
table = numpy_table.reshape(items * channels)
|
||||||
else:
|
else:
|
||||||
wrong_size = True
|
wrong_size = True
|
||||||
|
|
||||||
|
@ -428,7 +440,8 @@ class Color3DLUT(MultibandFilter):
|
||||||
|
|
||||||
# Convert to a flat list
|
# Convert to a flat list
|
||||||
if table and isinstance(table[0], (list, tuple)):
|
if table and isinstance(table[0], (list, tuple)):
|
||||||
table, raw_table = [], table
|
raw_table = cast(Sequence[Sequence[int]], table)
|
||||||
|
flat_table: list[int] = []
|
||||||
for pixel in raw_table:
|
for pixel in raw_table:
|
||||||
if len(pixel) != channels:
|
if len(pixel) != channels:
|
||||||
msg = (
|
msg = (
|
||||||
|
@ -436,7 +449,8 @@ class Color3DLUT(MultibandFilter):
|
||||||
f"have a length of {channels}."
|
f"have a length of {channels}."
|
||||||
)
|
)
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
table.extend(pixel)
|
flat_table.extend(pixel)
|
||||||
|
table = flat_table
|
||||||
|
|
||||||
if wrong_size or len(table) != items * channels:
|
if wrong_size or len(table) != items * channels:
|
||||||
msg = (
|
msg = (
|
||||||
|
@ -449,7 +463,7 @@ class Color3DLUT(MultibandFilter):
|
||||||
self.table = table
|
self.table = table
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _check_size(size: Any) -> list[int]:
|
def _check_size(size: Any) -> tuple[int, int, int]:
|
||||||
try:
|
try:
|
||||||
_, _, _ = size
|
_, _, _ = size
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
|
@ -457,7 +471,7 @@ class Color3DLUT(MultibandFilter):
|
||||||
raise ValueError(msg) from e
|
raise ValueError(msg) from e
|
||||||
except TypeError:
|
except TypeError:
|
||||||
size = (size, size, size)
|
size = (size, size, size)
|
||||||
size = [int(x) for x in size]
|
size = tuple(int(x) for x in size)
|
||||||
for size_1d in size:
|
for size_1d in size:
|
||||||
if not 2 <= size_1d <= 65:
|
if not 2 <= size_1d <= 65:
|
||||||
msg = "Size should be in [2, 65] range."
|
msg = "Size should be in [2, 65] range."
|
||||||
|
@ -465,7 +479,13 @@ class Color3DLUT(MultibandFilter):
|
||||||
return size
|
return size
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def generate(cls, size, callback, channels=3, target_mode=None):
|
def generate(
|
||||||
|
cls,
|
||||||
|
size: int | tuple[int, int, int],
|
||||||
|
callback: Callable[[float, float, float], tuple[float, ...]],
|
||||||
|
channels: int = 3,
|
||||||
|
target_mode: str | None = None,
|
||||||
|
) -> Color3DLUT:
|
||||||
"""Generates new LUT using provided callback.
|
"""Generates new LUT using provided callback.
|
||||||
|
|
||||||
:param size: Size of the table. Passed to the constructor.
|
:param size: Size of the table. Passed to the constructor.
|
||||||
|
@ -482,7 +502,7 @@ class Color3DLUT(MultibandFilter):
|
||||||
msg = "Only 3 or 4 output channels are supported"
|
msg = "Only 3 or 4 output channels are supported"
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
table = [0] * (size_1d * size_2d * size_3d * channels)
|
table: list[float] = [0] * (size_1d * size_2d * size_3d * channels)
|
||||||
idx_out = 0
|
idx_out = 0
|
||||||
for b in range(size_3d):
|
for b in range(size_3d):
|
||||||
for g in range(size_2d):
|
for g in range(size_2d):
|
||||||
|
@ -500,7 +520,13 @@ class Color3DLUT(MultibandFilter):
|
||||||
_copy_table=False,
|
_copy_table=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
def transform(self, callback, with_normals=False, channels=None, target_mode=None):
|
def transform(
|
||||||
|
self,
|
||||||
|
callback: Callable[..., tuple[float, ...]],
|
||||||
|
with_normals: bool = False,
|
||||||
|
channels: int | None = None,
|
||||||
|
target_mode: str | None = None,
|
||||||
|
) -> Color3DLUT:
|
||||||
"""Transforms the table values using provided callback and returns
|
"""Transforms the table values using provided callback and returns
|
||||||
a new LUT with altered values.
|
a new LUT with altered values.
|
||||||
|
|
||||||
|
@ -564,7 +590,7 @@ class Color3DLUT(MultibandFilter):
|
||||||
r.append(f"target_mode={self.mode}")
|
r.append(f"target_mode={self.mode}")
|
||||||
return "<{}>".format(" ".join(r))
|
return "<{}>".format(" ".join(r))
|
||||||
|
|
||||||
def filter(self, image):
|
def filter(self, image: _imaging.ImagingCore) -> _imaging.ImagingCore:
|
||||||
from . import Image
|
from . import Image
|
||||||
|
|
||||||
return image.color_lut_3d(
|
return image.color_lut_3d(
|
||||||
|
|
|
@ -2,7 +2,14 @@ from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
from typing import Protocol, Sequence, TypeVar, Union
|
from typing import Any, Protocol, Sequence, TypeVar, Union
|
||||||
|
|
||||||
|
try:
|
||||||
|
import numpy.typing as npt
|
||||||
|
|
||||||
|
NumpyArray = npt.NDArray[Any]
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
if sys.version_info >= (3, 10):
|
if sys.version_info >= (3, 10):
|
||||||
from typing import TypeGuard
|
from typing import TypeGuard
|
||||||
|
@ -10,7 +17,6 @@ else:
|
||||||
try:
|
try:
|
||||||
from typing_extensions import TypeGuard
|
from typing_extensions import TypeGuard
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
class TypeGuard: # type: ignore[no-redef]
|
class TypeGuard: # type: ignore[no-redef]
|
||||||
def __class_getitem__(cls, item: Any) -> type[bool]:
|
def __class_getitem__(cls, item: Any) -> type[bool]:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user