Do not use sys.stdout in PSDraw

This commit is contained in:
Andrew Murray 2024-08-19 08:47:35 +10:00
parent f8d3e36176
commit d1d567bb59
2 changed files with 7 additions and 15 deletions

View File

@ -5,8 +5,6 @@ import sys
from io import BytesIO from io import BytesIO
from pathlib import Path from pathlib import Path
import pytest
from PIL import Image, PSDraw from PIL import Image, PSDraw
@ -49,15 +47,14 @@ def test_draw_postscript(tmp_path: Path) -> None:
assert os.path.getsize(tempfile) > 0 assert os.path.getsize(tempfile) > 0
@pytest.mark.parametrize("buffer", (True, False)) def test_stdout() -> None:
def test_stdout(buffer: bool) -> None:
# Temporarily redirect stdout # Temporarily redirect stdout
old_stdout = sys.stdout old_stdout = sys.stdout
class MyStdOut: class MyStdOut:
buffer = BytesIO() buffer = BytesIO()
mystdout: MyStdOut | BytesIO = MyStdOut() if buffer else BytesIO() mystdout = MyStdOut()
sys.stdout = mystdout sys.stdout = mystdout
@ -67,6 +64,4 @@ def test_stdout(buffer: bool) -> None:
# Reset stdout # Reset stdout
sys.stdout = old_stdout sys.stdout = old_stdout
if isinstance(mystdout, MyStdOut): assert mystdout.buffer.getvalue() != b""
mystdout = mystdout.buffer
assert mystdout.getvalue() != b""

View File

@ -17,7 +17,7 @@
from __future__ import annotations from __future__ import annotations
import sys import sys
from typing import TYPE_CHECKING from typing import IO, TYPE_CHECKING
from . import EpsImagePlugin from . import EpsImagePlugin
@ -28,15 +28,12 @@ from . import EpsImagePlugin
class PSDraw: class PSDraw:
""" """
Sets up printing to the given file. If ``fp`` is omitted, Sets up printing to the given file. If ``fp`` is omitted,
``sys.stdout.buffer`` or ``sys.stdout`` is assumed. ``sys.stdout.buffer`` is assumed.
""" """
def __init__(self, fp=None): def __init__(self, fp: IO[bytes] | None = None) -> None:
if not fp: if not fp:
try:
fp = sys.stdout.buffer fp = sys.stdout.buffer
except AttributeError:
fp = sys.stdout
self.fp = fp self.fp = fp
def begin_document(self, id: str | None = None) -> None: def begin_document(self, id: str | None = None) -> None: