Allow for sys.stdout.buffer to be missing

This commit is contained in:
Andrew Murray 2021-05-03 18:07:05 +10:00
parent ce3d69baf9
commit 0f68e63793
6 changed files with 41 additions and 15 deletions

View File

@ -713,13 +713,20 @@ class TestFilePng:
with pytest.raises(EOFError): with pytest.raises(EOFError):
im.seek(1) im.seek(1)
def test_save_stdout(self): @pytest.mark.parametrize("buffer", (True, False))
def test_save_stdout(self, buffer):
old_stdout = sys.stdout.buffer old_stdout = sys.stdout.buffer
if buffer:
class MyStdOut: class MyStdOut:
buffer = BytesIO() buffer = BytesIO()
sys.stdout = mystdout = MyStdOut() mystdout = MyStdOut()
else:
mystdout = BytesIO()
sys.stdout = mystdout
with Image.open(TEST_PNG_FILE) as im: with Image.open(TEST_PNG_FILE) as im:
im.save(sys.stdout, "PNG") im.save(sys.stdout, "PNG")
@ -727,7 +734,9 @@ class TestFilePng:
# Reset stdout # Reset stdout
sys.stdout = old_stdout sys.stdout = old_stdout
reloaded = Image.open(mystdout.buffer) if buffer:
mystdout = mystdout.buffer
reloaded = Image.open(mystdout)
assert_image_equal_tofile(reloaded, TEST_PNG_FILE) assert_image_equal_tofile(reloaded, TEST_PNG_FILE)

View File

@ -2,6 +2,8 @@ import os
import sys import sys
from io import BytesIO from io import BytesIO
import pytest
from PIL import Image, PSDraw from PIL import Image, PSDraw
@ -44,14 +46,21 @@ def test_draw_postscript(tmp_path):
assert os.path.getsize(tempfile) > 0 assert os.path.getsize(tempfile) > 0
def test_stdout(): @pytest.mark.parametrize("buffer", (True, False))
def test_stdout(buffer):
# Temporarily redirect stdout # Temporarily redirect stdout
old_stdout = sys.stdout.buffer old_stdout = sys.stdout.buffer
if buffer:
class MyStdOut: class MyStdOut:
buffer = BytesIO() buffer = BytesIO()
sys.stdout = mystdout = MyStdOut() mystdout = MyStdOut()
else:
mystdout = BytesIO()
sys.stdout = mystdout
ps = PSDraw.PSDraw() ps = PSDraw.PSDraw()
_create_document(ps) _create_document(ps)
@ -59,4 +68,6 @@ def test_stdout():
# Reset stdout # Reset stdout
sys.stdout = old_stdout sys.stdout = old_stdout
assert mystdout.buffer.getvalue() != b"" if buffer:
mystdout = mystdout.buffer
assert mystdout.getvalue() != b""

View File

@ -424,7 +424,7 @@ Drawing PostScript
title = "hopper" title = "hopper"
box = (1*72, 2*72, 7*72, 10*72) # in points box = (1*72, 2*72, 7*72, 10*72) # in points
ps = PSDraw.PSDraw() # default is sys.stdout.buffer ps = PSDraw.PSDraw() # default is sys.stdout or sys.stdout.buffer
ps.begin_document(title) ps.begin_document(title)
# draw the image (75 dpi) # draw the image (75 dpi)

View File

@ -2136,7 +2136,10 @@ class Image:
filename = str(fp) filename = str(fp)
open_fp = True open_fp = True
elif fp == sys.stdout: elif fp == sys.stdout:
try:
fp = sys.stdout.buffer fp = sys.stdout.buffer
except AttributeError:
pass
if not filename and hasattr(fp, "name") and isPath(fp.name): if not filename and hasattr(fp, "name") and isPath(fp.name):
# only set the name for metadata purposes # only set the name for metadata purposes
filename = fp.name filename = fp.name

View File

@ -493,7 +493,7 @@ def _save(im, fp, tile, bufsize=0):
# But, it would need at least the image size in most cases. RawEncode is # But, it would need at least the image size in most cases. RawEncode is
# a tricky case. # a tricky case.
bufsize = max(MAXBLOCK, bufsize, im.size[0] * 4) # see RawEncode.c bufsize = max(MAXBLOCK, bufsize, im.size[0] * 4) # see RawEncode.c
if fp == sys.stdout.buffer: if fp == sys.stdout or (hasattr(sys.stdout, "buffer") and fp == sys.stdout.buffer):
fp.flush() fp.flush()
return return
try: try:

View File

@ -26,12 +26,15 @@ from . import EpsImagePlugin
class PSDraw: class PSDraw:
""" """
Sets up printing to the given file. If ``fp`` is omitted, Sets up printing to the given file. If ``fp`` is omitted,
``sys.stdout.buffer`` is assumed. ``sys.stdout.buffer`` or ``sys.stdout`` is assumed.
""" """
def __init__(self, fp=None): def __init__(self, fp=None):
if not fp: if not fp:
try:
fp = sys.stdout.buffer fp = sys.stdout.buffer
except AttributeError:
fp = sys.stdout
self.fp = fp self.fp = fp
def begin_document(self, id=None): def begin_document(self, id=None):