Transparently store numpy arrays in ColorLut

This commit is contained in:
Alexander 2018-04-12 01:54:08 +03:00
parent b01ba0f50c
commit c8405ef706

View File

@ -19,6 +19,11 @@ from __future__ import division
import functools import functools
try:
import numpy
except ImportError: # pragma: no cover
numpy = None
class Filter(object): class Filter(object):
pass pass
@ -334,25 +339,40 @@ class Color3DLUT(MultibandFilter):
# Hidden flag `_copy_table=False` could be used to avoid extra copying # Hidden flag `_copy_table=False` could be used to avoid extra copying
# of the table if the table is specially made for the constructor. # of the table if the table is specially made for the constructor.
if kwargs.get('_copy_table', True): copy_table = kwargs.get('_copy_table', True)
table = list(table) items = size[0] * size[1] * size[2]
wrong_size = False
# Convert to a flat list if numpy and isinstance(table, numpy.ndarray):
if table and isinstance(table[0], (list, tuple)): if copy_table:
table, raw_table = [], table table = table.copy()
for pixel in raw_table:
if len(pixel) != channels:
raise ValueError("The elements of the table should have "
"a length of {}.".format(channels))
for color in pixel:
table.append(color)
if len(table) != channels * size[0] * size[1] * size[2]: if table.shape in [(items * channels,), (items, channels),
(size[2], size[1], size[0], channels)]:
table = table.reshape(items * channels)
else:
wrong_size = True
else:
if copy_table:
table = list(table)
# Convert to a flat list
if table and isinstance(table[0], (list, tuple)):
table, raw_table = [], table
for pixel in raw_table:
if len(pixel) != channels:
raise ValueError(
"The elements of the table should "
"have a length of {}.".format(channels))
table.extend(pixel)
if wrong_size or len(table) != items * channels:
raise ValueError( raise ValueError(
"The table should have either channels * size**3 float items " "The table should have either channels * size**3 float items "
"or size**3 items of channels-sized tuples with floats. " "or size**3 items of channels-sized tuples with floats. "
"Table size: {}x{}x{}. Table length: {}".format( "Table should be: {}x{}x{}x{}. Actual length: {}".format(
size[0], size[1], size[2], len(table))) channels, size[0], size[1], size[2], len(table)))
self.table = table self.table = table
@staticmethod @staticmethod