Transparently store numpy arrays in ColorLut

This commit is contained in:
Alexander 2018-04-12 01:54:08 +03:00
parent b01ba0f50c
commit c8405ef706

View File

@ -19,6 +19,11 @@ from __future__ import division
import functools
try:
import numpy
except ImportError: # pragma: no cover
numpy = None
class Filter(object):
pass
@ -334,25 +339,40 @@ class Color3DLUT(MultibandFilter):
# Hidden flag `_copy_table=False` could be used to avoid extra copying
# of the table if the table is specially made for the constructor.
if kwargs.get('_copy_table', True):
table = list(table)
copy_table = kwargs.get('_copy_table', True)
items = size[0] * size[1] * size[2]
wrong_size = False
# Convert to a flat list
if table and isinstance(table[0], (list, tuple)):
table, raw_table = [], table
for pixel in raw_table:
if len(pixel) != channels:
raise ValueError("The elements of the table should have "
"a length of {}.".format(channels))
for color in pixel:
table.append(color)
if numpy and isinstance(table, numpy.ndarray):
if copy_table:
table = table.copy()
if len(table) != channels * size[0] * size[1] * size[2]:
if table.shape in [(items * channels,), (items, channels),
(size[2], size[1], size[0], channels)]:
table = table.reshape(items * channels)
else:
wrong_size = True
else:
if copy_table:
table = list(table)
# Convert to a flat list
if table and isinstance(table[0], (list, tuple)):
table, raw_table = [], table
for pixel in raw_table:
if len(pixel) != channels:
raise ValueError(
"The elements of the table should "
"have a length of {}.".format(channels))
table.extend(pixel)
if wrong_size or len(table) != items * channels:
raise ValueError(
"The table should have either channels * size**3 float items "
"or size**3 items of channels-sized tuples with floats. "
"Table size: {}x{}x{}. Table length: {}".format(
size[0], size[1], size[2], len(table)))
"Table should be: {}x{}x{}x{}. Actual length: {}".format(
channels, size[0], size[1], size[2], len(table)))
self.table = table
@staticmethod