From c8405ef7069c1aa902615d79bb899443bede686b Mon Sep 17 00:00:00 2001 From: Alexander Date: Thu, 12 Apr 2018 01:54:08 +0300 Subject: [PATCH] Transparently store numpy arrays in ColorLut --- src/PIL/ImageFilter.py | 48 ++++++++++++++++++++++++++++++------------ 1 file changed, 34 insertions(+), 14 deletions(-) diff --git a/src/PIL/ImageFilter.py b/src/PIL/ImageFilter.py index ff9348b21..c3c71b252 100644 --- a/src/PIL/ImageFilter.py +++ b/src/PIL/ImageFilter.py @@ -19,6 +19,11 @@ from __future__ import division import functools +try: + import numpy +except ImportError: # pragma: no cover + numpy = None + class Filter(object): pass @@ -334,25 +339,40 @@ class Color3DLUT(MultibandFilter): # Hidden flag `_copy_table=False` could be used to avoid extra copying # of the table if the table is specially made for the constructor. - if kwargs.get('_copy_table', True): - table = list(table) + copy_table = kwargs.get('_copy_table', True) + items = size[0] * size[1] * size[2] + wrong_size = False - # Convert to a flat list - if table and isinstance(table[0], (list, tuple)): - table, raw_table = [], table - for pixel in raw_table: - if len(pixel) != channels: - raise ValueError("The elements of the table should have " - "a length of {}.".format(channels)) - for color in pixel: - table.append(color) + if numpy and isinstance(table, numpy.ndarray): + if copy_table: + table = table.copy() - if len(table) != channels * size[0] * size[1] * size[2]: + if table.shape in [(items * channels,), (items, channels), + (size[2], size[1], size[0], channels)]: + table = table.reshape(items * channels) + else: + wrong_size = True + + else: + if copy_table: + table = list(table) + + # Convert to a flat list + if table and isinstance(table[0], (list, tuple)): + table, raw_table = [], table + for pixel in raw_table: + if len(pixel) != channels: + raise ValueError( + "The elements of the table should " + "have a length of {}.".format(channels)) + table.extend(pixel) + + if wrong_size or len(table) != items * channels: raise ValueError( "The table should have either channels * size**3 float items " "or size**3 items of channels-sized tuples with floats. " - "Table size: {}x{}x{}. Table length: {}".format( - size[0], size[1], size[2], len(table))) + "Table should be: {}x{}x{}x{}. Actual length: {}".format( + channels, size[0], size[1], size[2], len(table))) self.table = table @staticmethod