diff --git a/src/PIL/ImageFilter.py b/src/PIL/ImageFilter.py index 93cd7ad47..d7801bedb 100644 --- a/src/PIL/ImageFilter.py +++ b/src/PIL/ImageFilter.py @@ -15,6 +15,8 @@ # See the README file for information on usage and redistribution. # +from __future__ import division + import functools @@ -323,12 +325,16 @@ class Color3DLUT(MultibandFilter): """ name = "Color 3D LUT" - def __init__(self, size, table, channels=3, target_mode=None): + def __init__(self, size, table, channels=3, target_mode=None, **kwargs): self.size = size = self._check_size(size) self.channels = channels self.mode = target_mode - table = list(table) + # Hidden flag `_copy_table=False` could be used to avoid extra copying + # of the table if the table is specially made for the constructor. + if kwargs.get('_copy_table', True): + table = list(table) + # Convert to a flat list if table and isinstance(table[0], (list, tuple)): table, raw_table = [], table @@ -371,20 +377,45 @@ class Color3DLUT(MultibandFilter): three color channels. Will be called ``size**3`` times with values from 0.0 to 1.0 and should return a tuple with ``channels`` elements. - :param channels: Passed to the constructor. + :param channels: The number of channels which should return callback. :param target_mode: Passed to the constructor. """ size1D, size2D, size3D = cls._check_size(size) - table = [] + if channels not in (3, 4): + raise ValueError("Only 3 or 4 output channels are supported") + + table = [0] * (size1D * size2D * size3D * channels) + idx_out = 0 for b in range(size3D): for g in range(size2D): for r in range(size1D): - table.append(callback( - r / float(size1D-1), - g / float(size2D-1), - b / float(size3D-1))) + table[idx_out:idx_out + channels] = callback( + r / (size1D-1), g / (size2D-1), b / (size3D-1)) + idx_out += channels - return cls((size1D, size2D, size3D), table, channels, target_mode) + return cls((size1D, size2D, size3D), table, channels=channels, + target_mode=target_mode, _copy_table=False) + + def alter(self, callback, channels=None, target_mode=None): + if channels not in (None, 3, 4): + raise ValueError("Only 3 or 4 output channels are supported") + ch_in = self.channels + ch_out = channels or ch_in + + table = [0] * (self.size[0] * self.size[1] * self.size[2] * ch_out) + idx_in = 0 + idx_out = 0 + for b in range(self.size[2]): + for g in range(self.size[1]): + for r in range(self.size[0]): + values = callback(*self.table[idx_in:idx_in + ch_in]) + table[idx_out:idx_out + ch_out] = values + idx_in += ch_in + idx_out += ch_out + + return type(self)(self.size, table, channels=ch_out, + target_mode=target_mode or self.mode, + _copy_table=False) def filter(self, image): from . import Image