diff --git a/src/PIL/VtfImagePlugin.py b/src/PIL/VtfImagePlugin.py index 76d8a002e..580c315c6 100644 --- a/src/PIL/VtfImagePlugin.py +++ b/src/PIL/VtfImagePlugin.py @@ -188,27 +188,44 @@ def _write_image(fp: BufferedIOBase, im: Image.Image, pixel_format: VtfPF): if pixel_format == VtfPF.DXT1: encoder = 'bcn' encoder_args = (1, "DXT1") + im = im.convert('RGB') + elif pixel_format == VtfPF.DXT1_ONEBITALPHA: + encoder = 'bcn' + encoder_args = (1, "DXT1A") + im = im.convert('RGBA') elif pixel_format == VtfPF.DXT3: encoder = 'bcn' encoder_args = (3, "DXT3") + im = im.convert('RGBA') elif pixel_format == VtfPF.DXT5: encoder = 'bcn' encoder_args = (5, "DXT5") + im = im.convert('RGBA') elif pixel_format == VtfPF.RGB888: encoder = 'raw' encoder_args = ("RGB", 0, 0) + im = im.convert('RGB') elif pixel_format == VtfPF.BGR888: encoder = 'raw' encoder_args = ("BGR", 0, 0) + im = im.convert('RGB') elif pixel_format == VtfPF.RGBA8888: encoder = 'raw' encoder_args = ("RGBA", 0, 0) + im = im.convert('RGBA') + elif pixel_format == VtfPF.A8: + encoder = 'raw' + encoder_args = ("L", 0, 0) + *_, a = im.split() + im = Image.merge('L', (a,)) elif pixel_format == VtfPF.I8: encoder = 'raw' encoder_args = ("L", 0, 0) + im = im.convert('L') elif pixel_format == VtfPF.IA88: encoder = 'raw' encoder_args = ("LA", 0, 0) + im = im.convert('LA') elif pixel_format == VtfPF.UV88: encoder = 'raw' r, g, *_ = im.split() @@ -316,22 +333,17 @@ def _save(im, fp, filename): generate_mips = encoderinfo.get('generate_mips', True) flags = CompiledVtfFlags(0) - if pixel_format in RGBA_FORMATS: - im = im.convert('RGBA') - if pixel_format in RGB_FORMATS: - im = im.convert('RGB') - if pixel_format in L_FORMATS: - im = im.convert('L') - if pixel_format in LA_FORMATS: - im = im.convert('LA') - if "A" in im.mode: - if pixel_format == VtfPF.DXT1_ONEBITALPHA: - flags |= CompiledVtfFlags.ONEBITALPHA - elif pixel_format == VtfPF.DXT1: - im = im.convert("RGB") - else: - flags |= CompiledVtfFlags.EIGHTBITALPHA + if pixel_format == VtfPF.DXT1_ONEBITALPHA: + flags |= CompiledVtfFlags.ONEBITALPHA + elif pixel_format == VtfPF.A8: + flags |= CompiledVtfFlags.EIGHTBITALPHA + elif pixel_format in RGBA_FORMATS + LA_FORMATS: + flags |= CompiledVtfFlags.EIGHTBITALPHA + elif pixel_format in RGB_FORMATS + L_FORMATS: + pass + else: + raise VTFException('Unhandled case') im = im.resize((_closest_power(im.width), _closest_power(im.height))) width, height = im.size