diff --git a/Tests/test_file_tar.py b/Tests/test_file_tar.py index 9e02ab1a5..bbf2a803b 100644 --- a/Tests/test_file_tar.py +++ b/Tests/test_file_tar.py @@ -15,21 +15,25 @@ class TestFileTar(PillowTestCase): self.skipTest("neither jpeg nor zip support available") def test_sanity(self): - if "zip_decoder" in codecs: - tar = TarIO.TarIO(TEST_TAR_FILE, 'hopper.png') - im = Image.open(tar) - im.load() - self.assertEqual(im.mode, "RGB") - self.assertEqual(im.size, (128, 128)) - self.assertEqual(im.format, "PNG") + for codec, test_path, format in [ + ['zip_decoder', 'hopper.png', 'PNG'], + ['jpeg_decoder', 'hopper.jpg', 'JPEG'] + ]: + if codec in codecs: + tar = TarIO.TarIO(TEST_TAR_FILE, test_path) + im = Image.open(tar) + im.load() + self.assertEqual(im.mode, "RGB") + self.assertEqual(im.size, (128, 128)) + self.assertEqual(im.format, format) - if "jpeg_decoder" in codecs: - tar = TarIO.TarIO(TEST_TAR_FILE, 'hopper.jpg') - im = Image.open(tar) - im.load() - self.assertEqual(im.mode, "RGB") - self.assertEqual(im.size, (128, 128)) - self.assertEqual(im.format, "JPEG") + def test_close(self): + tar = TarIO.TarIO(TEST_TAR_FILE, 'hopper.jpg') + tar.close() + + def test_contextmanager(self): + with TarIO.TarIO(TEST_TAR_FILE, 'hopper.jpg') as tar: + pass if __name__ == '__main__': diff --git a/src/PIL/TarIO.py b/src/PIL/TarIO.py index 0e949ff88..7c09685f7 100644 --- a/src/PIL/TarIO.py +++ b/src/PIL/TarIO.py @@ -14,6 +14,7 @@ # See the README file for information on usage and redistribution. # +import sys from . import ContainerIO @@ -30,11 +31,11 @@ class TarIO(ContainerIO.ContainerIO): :param tarfile: Name of TAR file. :param file: Name of member file. """ - fh = open(tarfile, "rb") + self.fh = open(tarfile, "rb") while True: - s = fh.read(512) + s = self.fh.read(512) if len(s) != 512: raise IOError("unexpected end of tar file") @@ -50,7 +51,21 @@ class TarIO(ContainerIO.ContainerIO): if file == name: break - fh.seek((size + 511) & (~511), 1) + self.fh.seek((size + 511) & (~511), 1) # Open region - ContainerIO.ContainerIO.__init__(self, fh, fh.tell(), size) + ContainerIO.ContainerIO.__init__(self, self.fh, self.fh.tell(), size) + + # Context Manager Support + def __enter__(self): + return self + + def __exit__(self, *args): + self.close() + + if sys.version_info.major >= 3: + def __del__(self): + self.close() + + def close(self): + self.fh.close()