From e16ab0ad2ead98205143d06940f79a45300b4f33 Mon Sep 17 00:00:00 2001 From: Alexander Date: Sat, 12 Aug 2017 14:10:39 +0300 Subject: [PATCH] add tests, fix implementation --- PIL/Image.py | 9 ++++++--- Tests/test_image.py | 24 +++++++++++++++++++----- 2 files changed, 25 insertions(+), 8 deletions(-) diff --git a/PIL/Image.py b/PIL/Image.py index f23c4e32c..4555561a1 100644 --- a/PIL/Image.py +++ b/PIL/Image.py @@ -1975,9 +1975,12 @@ class Image(object): """ self.load() - if isStringType(channel) and len(channel) == 1: - if channel in self.im.mode: - channel = self.im.mode.index(channel) + if isStringType(channel): + try: + channel = self.getbands().index(channel) + except ValueError: + raise ValueError( + 'The image has no channel "{}"'.format(channel)) return self._new(self.im.getband(channel)) diff --git a/Tests/test_image.py b/Tests/test_image.py index 1f9c4d798..205c58e43 100644 --- a/Tests/test_image.py +++ b/Tests/test_image.py @@ -145,14 +145,28 @@ class TestImage(PillowTestCase): self.assertEqual(im.size[1], orig_size[1] + 2*ymargin) def test_getbands(self): - # Arrange + # Assert + self.assertEqual(hopper('RGB').getbands(), ('R', 'G', 'B')) + self.assertEqual(hopper('YCbCr').getbands(), ('Y', 'Cb', 'Cr')) + + def test_getchannel_wrong_params(self): im = hopper() - # Act - bands = im.getbands() + self.assertRaises(ValueError, im.getchannel, -1) + self.assertRaises(ValueError, im.getchannel, 3) + self.assertRaises(ValueError, im.getchannel, 'Z') + self.assertRaises(ValueError, im.getchannel, '1') - # Assert - self.assertEqual(bands, ('R', 'G', 'B')) + def test_getchannel(self): + im = hopper('YCbCr') + Y, Cb, Cr = im.split() + + self.assert_image_equal(Y, im.getchannel(0)) + self.assert_image_equal(Y, im.getchannel('Y')) + self.assert_image_equal(Cb, im.getchannel(1)) + self.assert_image_equal(Cb, im.getchannel('Cb')) + self.assert_image_equal(Cr, im.getchannel(2)) + self.assert_image_equal(Cr, im.getchannel('Cr')) def test_getbbox(self): # Arrange