Merge pull request #3099 from uploadcare/lut-numpy

NumPy support for LUTs
This commit is contained in:
Hugo 2018-07-01 13:18:03 +03:00 committed by GitHub
commit 2b09e7fa6a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 202 additions and 41 deletions

View File

@ -15,6 +15,7 @@ pip install -U pytest
pip install -U pytest-cov pip install -U pytest-cov
pip install pyroma pip install pyroma
pip install test-image-results pip install test-image-results
pip install numpy
# docs only on Python 2.7 # docs only on Python 2.7
if [ "$TRAVIS_PYTHON_VERSION" == "2.7" ]; then pip install -r requirements.txt ; fi if [ "$TRAVIS_PYTHON_VERSION" == "2.7" ]; then pip install -r requirements.txt ; fi

View File

@ -5,6 +5,11 @@ from array import array
from PIL import Image, ImageFilter from PIL import Image, ImageFilter
from helper import unittest, PillowTestCase from helper import unittest, PillowTestCase
try:
import numpy
except ImportError:
numpy = None
class TestColorLut3DCoreAPI(PillowTestCase): class TestColorLut3DCoreAPI(PillowTestCase):
def generate_identity_table(self, channels, size): def generate_identity_table(self, channels, size):
@ -279,6 +284,80 @@ class TestColorLut3DFilter(PillowTestCase):
lut = ImageFilter.Color3DLUT((2, 2, 2), [(0, 1, 2, 3)] * 8, lut = ImageFilter.Color3DLUT((2, 2, 2), [(0, 1, 2, 3)] * 8,
channels=4) channels=4)
@unittest.skipIf(numpy is None, "Numpy is not installed")
def test_numpy_sources(self):
table = numpy.ones((5, 6, 7, 3), dtype=numpy.float16)
with self.assertRaisesRegex(ValueError, "should have either channels"):
lut = ImageFilter.Color3DLUT((5, 6, 7), table)
table = numpy.ones((7, 6, 5, 3), dtype=numpy.float16)
lut = ImageFilter.Color3DLUT((5, 6, 7), table)
self.assertIsInstance(lut.table, numpy.ndarray)
self.assertEqual(lut.table.dtype, table.dtype)
self.assertEqual(lut.table.shape, (table.size,))
table = numpy.ones((7 * 6 * 5, 3), dtype=numpy.float16)
lut = ImageFilter.Color3DLUT((5, 6, 7), table)
self.assertEqual(lut.table.shape, (table.size,))
table = numpy.ones((7 * 6 * 5 * 3), dtype=numpy.float16)
lut = ImageFilter.Color3DLUT((5, 6, 7), table)
self.assertEqual(lut.table.shape, (table.size,))
# Check application
Image.new('RGB', (10, 10), 0).filter(lut)
# Check copy
table[0] = 33
self.assertEqual(lut.table[0], 1)
# Check not copy
table = numpy.ones((7 * 6 * 5 * 3), dtype=numpy.float16)
lut = ImageFilter.Color3DLUT((5, 6, 7), table, _copy_table=False)
table[0] = 33
self.assertEqual(lut.table[0], 33)
@unittest.skipIf(numpy is None, "Numpy is not installed")
def test_numpy_formats(self):
g = Image.linear_gradient('L')
im = Image.merge('RGB', [g, g.transpose(Image.ROTATE_90),
g.transpose(Image.ROTATE_180)])
lut = ImageFilter.Color3DLUT.generate((7, 9, 11),
lambda r, g, b: (r, g, b))
lut.table = numpy.array(lut.table, dtype=numpy.float32)[:-1]
with self.assertRaisesRegex(ValueError, "should have table_channels"):
im.filter(lut)
lut = ImageFilter.Color3DLUT.generate((7, 9, 11),
lambda r, g, b: (r, g, b))
lut.table = (numpy.array(lut.table, dtype=numpy.float32)
.reshape((7 * 9 * 11), 3))
with self.assertRaisesRegex(ValueError, "should have table_channels"):
im.filter(lut)
lut = ImageFilter.Color3DLUT.generate((7, 9, 11),
lambda r, g, b: (r, g, b))
lut.table = numpy.array(lut.table, dtype=numpy.float16)
self.assert_image_equal(im, im.filter(lut))
lut = ImageFilter.Color3DLUT.generate((7, 9, 11),
lambda r, g, b: (r, g, b))
lut.table = numpy.array(lut.table, dtype=numpy.float32)
self.assert_image_equal(im, im.filter(lut))
lut = ImageFilter.Color3DLUT.generate((7, 9, 11),
lambda r, g, b: (r, g, b))
lut.table = numpy.array(lut.table, dtype=numpy.float64)
self.assert_image_equal(im, im.filter(lut))
lut = ImageFilter.Color3DLUT.generate((7, 9, 11),
lambda r, g, b: (r, g, b))
lut.table = numpy.array(lut.table, dtype=numpy.int32)
im.filter(lut)
lut.table = numpy.array(lut.table, dtype=numpy.int8)
im.filter(lut)
def test_repr(self): def test_repr(self):
lut = ImageFilter.Color3DLUT(2, [0, 1, 2] * 8) lut = ImageFilter.Color3DLUT(2, [0, 1, 2] * 8)
self.assertEqual(repr(lut), self.assertEqual(repr(lut),

View File

@ -4,28 +4,16 @@ from helper import PillowTestCase, hopper, unittest
from PIL import Image from PIL import Image
try: try:
import site
import numpy import numpy
assert site # silence warning
assert numpy # silence warning
except ImportError: except ImportError:
# Skip via setUp() numpy = None
pass
TEST_IMAGE_SIZE = (10, 10) TEST_IMAGE_SIZE = (10, 10)
@unittest.skipIf(numpy is None, "Numpy is not installed")
class TestNumpy(PillowTestCase): class TestNumpy(PillowTestCase):
def setUp(self):
try:
import site
import numpy
assert site # silence warning
assert numpy # silence warning
except ImportError:
self.skipTest("ImportError")
def test_numpy_to_image(self): def test_numpy_to_image(self):
def to_image(dtype, bands=1, boolean=0): def to_image(dtype, bands=1, boolean=0):

View File

@ -19,6 +19,11 @@ from __future__ import division
import functools import functools
try:
import numpy
except ImportError: # pragma: no cover
numpy = None
class Filter(object): class Filter(object):
pass pass
@ -310,6 +315,8 @@ class Color3DLUT(MultibandFilter):
This method allows you to apply almost any color transformation This method allows you to apply almost any color transformation
in constant time by using pre-calculated decimated tables. in constant time by using pre-calculated decimated tables.
.. versionadded:: 5.2.0
:param size: Size of the table. One int or tuple of (int, int, int). :param size: Size of the table. One int or tuple of (int, int, int).
Minimal size in any dimension is 2, maximum is 65. Minimal size in any dimension is 2, maximum is 65.
:param table: Flat lookup table. A list of ``channels * size**3`` :param table: Flat lookup table. A list of ``channels * size**3``
@ -334,25 +341,40 @@ class Color3DLUT(MultibandFilter):
# Hidden flag `_copy_table=False` could be used to avoid extra copying # Hidden flag `_copy_table=False` could be used to avoid extra copying
# of the table if the table is specially made for the constructor. # of the table if the table is specially made for the constructor.
if kwargs.get('_copy_table', True): copy_table = kwargs.get('_copy_table', True)
table = list(table) items = size[0] * size[1] * size[2]
wrong_size = False
# Convert to a flat list if numpy and isinstance(table, numpy.ndarray):
if table and isinstance(table[0], (list, tuple)): if copy_table:
table, raw_table = [], table table = table.copy()
for pixel in raw_table:
if len(pixel) != channels:
raise ValueError("The elements of the table should have "
"a length of {}.".format(channels))
for color in pixel:
table.append(color)
if len(table) != channels * size[0] * size[1] * size[2]: if table.shape in [(items * channels,), (items, channels),
(size[2], size[1], size[0], channels)]:
table = table.reshape(items * channels)
else:
wrong_size = True
else:
if copy_table:
table = list(table)
# Convert to a flat list
if table and isinstance(table[0], (list, tuple)):
table, raw_table = [], table
for pixel in raw_table:
if len(pixel) != channels:
raise ValueError(
"The elements of the table should "
"have a length of {}.".format(channels))
table.extend(pixel)
if wrong_size or len(table) != items * channels:
raise ValueError( raise ValueError(
"The table should have either channels * size**3 float items " "The table should have either channels * size**3 float items "
"or size**3 items of channels-sized tuples with floats. " "or size**3 items of channels-sized tuples with floats. "
"Table size: {}x{}x{}. Table length: {}".format( "Table should be: {}x{}x{}x{}. Actual length: {}".format(
size[0], size[1], size[2], len(table))) channels, size[0], size[1], size[2], len(table)))
self.table = table self.table = table
@staticmethod @staticmethod

View File

@ -354,6 +354,7 @@ getbands(const char* mode)
#define TYPE_UINT8 (0x100|sizeof(UINT8)) #define TYPE_UINT8 (0x100|sizeof(UINT8))
#define TYPE_INT32 (0x200|sizeof(INT32)) #define TYPE_INT32 (0x200|sizeof(INT32))
#define TYPE_FLOAT16 (0x500|sizeof(FLOAT16))
#define TYPE_FLOAT32 (0x300|sizeof(FLOAT32)) #define TYPE_FLOAT32 (0x300|sizeof(FLOAT32))
#define TYPE_DOUBLE (0x400|sizeof(double)) #define TYPE_DOUBLE (0x400|sizeof(double))
@ -437,6 +438,30 @@ getlist(PyObject* arg, Py_ssize_t* length, const char* wrong_length, int type)
return list; return list;
} }
FLOAT32
float16tofloat32(const FLOAT16 in) {
UINT32 t1;
UINT32 t2;
UINT32 t3;
FLOAT32 out[1] = {0};
t1 = in & 0x7fff; // Non-sign bits
t2 = in & 0x8000; // Sign bit
t3 = in & 0x7c00; // Exponent
t1 <<= 13; // Align mantissa on MSB
t2 <<= 16; // Shift sign bit into position
t1 += 0x38000000; // Adjust bias
t1 = (t3 == 0 ? 0 : t1); // Denormals-as-zero
t1 |= t2; // Re-insert sign bit
memcpy(out, &t1, 4);
return out[0];
}
static inline PyObject* static inline PyObject*
getpixel(Imaging im, ImagingAccess access, int x, int y) getpixel(Imaging im, ImagingAccess access, int x, int y)
{ {
@ -700,22 +725,54 @@ _blend(ImagingObject* self, PyObject* args)
/* METHODS */ /* METHODS */
/* -------------------------------------------------------------------- */ /* -------------------------------------------------------------------- */
static INT16* static INT16*
_prepare_lut_table(PyObject* table, Py_ssize_t table_size) _prepare_lut_table(PyObject* table, Py_ssize_t table_size)
{ {
int i; int i;
FLOAT32* table_data; Py_buffer buffer_info;
INT32 data_type = TYPE_FLOAT32;
float item = 0;
void* table_data = NULL;
int free_table_data = 0;
INT16* prepared; INT16* prepared;
/* NOTE: This value should be the same as in ColorLUT.c */ /* NOTE: This value should be the same as in ColorLUT.c */
#define PRECISION_BITS (16 - 8 - 2) #define PRECISION_BITS (16 - 8 - 2)
table_data = (FLOAT32*) getlist(table, &table_size, const char* wrong_size = ("The table should have table_channels * "
"The table should have table_channels * " "size1D * size2D * size3D float items.");
"size1D * size2D * size3D float items.", TYPE_FLOAT32);
if (PyObject_CheckBuffer(table)) {
if ( ! PyObject_GetBuffer(table, &buffer_info,
PyBUF_CONTIG_RO | PyBUF_FORMAT)) {
if (buffer_info.ndim == 1 && buffer_info.shape[0] == table_size) {
if (strlen(buffer_info.format) == 1) {
switch (buffer_info.format[0]) {
case 'e':
data_type = TYPE_FLOAT16;
table_data = buffer_info.buf;
break;
case 'f':
data_type = TYPE_FLOAT32;
table_data = buffer_info.buf;
break;
case 'd':
data_type = TYPE_DOUBLE;
table_data = buffer_info.buf;
break;
}
}
}
PyBuffer_Release(&buffer_info);
}
}
if ( ! table_data) { if ( ! table_data) {
return NULL; free_table_data = 1;
table_data = getlist(table, &table_size, wrong_size, TYPE_FLOAT32);
if ( ! table_data) {
return NULL;
}
} }
/* malloc check ok, max is 2 * 4 * 65**3 = 2197000 */ /* malloc check ok, max is 2 * 4 * 65**3 = 2197000 */
@ -726,25 +783,38 @@ _prepare_lut_table(PyObject* table, Py_ssize_t table_size)
} }
for (i = 0; i < table_size; i++) { for (i = 0; i < table_size; i++) {
switch (data_type) {
case TYPE_FLOAT16:
item = float16tofloat32(((FLOAT16*) table_data)[i]);
break;
case TYPE_FLOAT32:
item = ((FLOAT32*) table_data)[i];
break;
case TYPE_DOUBLE:
item = ((double*) table_data)[i];
break;
}
/* Max value for INT16 */ /* Max value for INT16 */
if (table_data[i] >= (0x7fff - 0.5) / (255 << PRECISION_BITS)) { if (item >= (0x7fff - 0.5) / (255 << PRECISION_BITS)) {
prepared[i] = 0x7fff; prepared[i] = 0x7fff;
continue; continue;
} }
/* Min value for INT16 */ /* Min value for INT16 */
if (table_data[i] <= (-0x8000 + 0.5) / (255 << PRECISION_BITS)) { if (item <= (-0x8000 + 0.5) / (255 << PRECISION_BITS)) {
prepared[i] = -0x8000; prepared[i] = -0x8000;
continue; continue;
} }
if (table_data[i] < 0) { if (item < 0) {
prepared[i] = table_data[i] * (255 << PRECISION_BITS) - 0.5; prepared[i] = item * (255 << PRECISION_BITS) - 0.5;
} else { } else {
prepared[i] = table_data[i] * (255 << PRECISION_BITS) + 0.5; prepared[i] = item * (255 << PRECISION_BITS) + 0.5;
} }
} }
#undef PRECISION_BITS #undef PRECISION_BITS
free(table_data); if (free_table_data) {
free(table_data);
}
return prepared; return prepared;
} }

View File

@ -71,6 +71,7 @@
#endif #endif
/* assume IEEE; tweak if necessary (patches are welcome) */ /* assume IEEE; tweak if necessary (patches are welcome) */
#define FLOAT16 UINT16
#define FLOAT32 float #define FLOAT32 float
#define FLOAT64 double #define FLOAT64 double