diff --git a/tests/test_connection.py b/tests/test_connection.py index 75abcc02..09b2d9a0 100755 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -22,14 +22,16 @@ # FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public # License for more details. -import ctypes import gc import os import re -import subprocess as sp import sys -import threading import time +import ctypes +import shutil +import tempfile +import threading +import subprocess as sp from collections import deque from operator import attrgetter from weakref import ref @@ -359,10 +361,11 @@ class ConnectionTests(ConnectingTestCase): @slow def test_multiprocess_close(self): - script = ("""\ + dir = tempfile.mkdtemp() + try: + with open(os.path.join(dir, "mptest.py"), 'w') as f: + f.write("""\ import time -import threading -import multiprocessing import psycopg2 def thread(): @@ -374,16 +377,28 @@ def thread(): def process(): time.sleep(0.2) - -t = threading.Thread(target=thread, name='mythread') -t.start() -time.sleep(0.2) -multiprocessing.Process(target=process, name='myprocess').start() -t.join() """ % {'dsn': dsn}) - out = sp.check_output([sys.executable, '-c', script], stderr=sp.STDOUT) - self.assertEqual(out, b'', out.decode('ascii')) + script = ("""\ +import sys +sys.path.insert(0, %(dir)r) +import time +import threading +import multiprocessing +import mptest + +t = threading.Thread(target=mptest.thread, name='mythread') +t.start() +time.sleep(0.2) +multiprocessing.Process(target=mptest.process, name='myprocess').start() +t.join() +""" % {'dir': dir}) + + out = sp.check_output( + [sys.executable, '-c', script], stderr=sp.STDOUT) + self.assertEqual(out, b'', out) + finally: + shutil.rmtree(dir, ignore_errors=True) class ParseDsnTestCase(ConnectingTestCase):