diff --git a/Tests/test_imageops.py b/Tests/test_imageops.py index f7e6c68bf..6aa1cf35e 100644 --- a/Tests/test_imageops.py +++ b/Tests/test_imageops.py @@ -156,23 +156,31 @@ def test_scale(): assert newimg.size == (25, 25) -def test_expand_palette(): - im = Image.open("Tests/images/p_16.tga") - im_expanded = ImageOps.expand(im, 10, (255, 0, 0)) +@pytest.mark.parametrize("border", (10, (1, 2, 3, 4))) +def test_expand_palette(border): + with Image.open("Tests/images/p_16.tga") as im: + im_expanded = ImageOps.expand(im, border, (255, 0, 0)) - px = im_expanded.convert("RGB").load() - for b in range(10): + if isinstance(border, int): + left = top = right = bottom = border + else: + left, top, right, bottom = border + px = im_expanded.convert("RGB").load() for x in range(im_expanded.width): - assert px[x, b] == (255, 0, 0) - assert px[x, im_expanded.height - 1 - b] == (255, 0, 0) + for b in range(top): + assert px[x, b] == (255, 0, 0) + for b in range(bottom): + assert px[x, im_expanded.height - 1 - b] == (255, 0, 0) for y in range(im_expanded.height): - assert px[b, x] == (255, 0, 0) - assert px[b, im_expanded.width - 1 - b] == (255, 0, 0) + for b in range(left): + assert px[b, y] == (255, 0, 0) + for b in range(right): + assert px[im_expanded.width - 1 - b, y] == (255, 0, 0) - im_cropped = im_expanded.crop( - (10, 10, im_expanded.width - 10, im_expanded.height - 10) - ) - assert_image_equal(im_cropped, im) + im_cropped = im_expanded.crop( + (left, top, im_expanded.width - right, im_expanded.height - bottom) + ) + assert_image_equal(im_cropped, im) def test_colorize_2color(): diff --git a/src/PIL/ImageOps.py b/src/PIL/ImageOps.py index e06a7eaca..f0c932d33 100644 --- a/src/PIL/ImageOps.py +++ b/src/PIL/ImageOps.py @@ -21,7 +21,7 @@ import functools import operator import re -from . import Image, ImageDraw +from . import Image # # helpers @@ -395,15 +395,16 @@ def expand(image, border=0, fill=0): height = top + image.size[1] + bottom color = _color(fill, image.mode) if image.mode == "P" and image.palette: - out = Image.new(image.mode, (width, height)) - out.putpalette(image.palette) - out.paste(image, (left, top)) - - draw = ImageDraw.Draw(out) - draw.rectangle((0, 0, width - 1, height - 1), outline=color, width=border) + image.load() + palette = image.palette.copy() + if isinstance(color, tuple): + color = palette.getcolor(color) else: - out = Image.new(image.mode, (width, height), color) - out.paste(image, (left, top)) + palette = None + out = Image.new(image.mode, (width, height), color) + if palette: + out.putpalette(palette.palette) + out.paste(image, (left, top)) return out