From de843ef7567af9962903e18c4d69a99c7ceafce7 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Thu, 15 Jun 2017 17:22:32 +0100 Subject: [PATCH] Added test to reproduce bug #551 --- tests/test_connection.py | 58 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 57 insertions(+), 1 deletion(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index f61b099d..38303247 100755 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -26,6 +26,7 @@ import os import sys import time import threading +import subprocess as sp from operator import attrgetter import psycopg2 @@ -33,7 +34,7 @@ import psycopg2.errorcodes from psycopg2 import extensions as ext from testutils import ( - unittest, decorate_all_tests, skip_if_no_superuser, + script_to_py3, unittest, decorate_all_tests, skip_if_no_superuser, skip_before_postgres, skip_after_postgres, skip_before_libpq, ConnectingTestCase, skip_if_tpc_disabled, skip_if_windows, slow) @@ -1516,6 +1517,61 @@ class PasswordLeakTestCase(ConnectingTestCase): "user=someone password=xxx host=localhost dbname=nosuch") +class SignalTestCase(ConnectingTestCase): + @slow + def test_bug_551_returning(self): + # Raise an exception trying to decode 'id' + self._test_bug_551(query=""" + INSERT INTO test551 (num) VALUES (%s) RETURNING id + """) + + @slow + def test_bug_551_no_returning(self): + # Raise an exception trying to decode 'INSERT 0 1' + self._test_bug_551(query=""" + INSERT INTO test551 (num) VALUES (%s) + """) + + def _test_bug_551(self, query): + script = ("""\ +import os +import sys +import time +import signal +import threading + +import psycopg2 + +def handle_sigabort(sig, frame): + sys.exit(1) + +def killer(): + time.sleep(0.5) + os.kill(os.getpid(), signal.SIGABRT) + +signal.signal(signal.SIGABRT, handle_sigabort) + +conn = psycopg2.connect(%(dsn)r) + +cur = conn.cursor() + +cur.execute("create table test551 (id serial, num varchar(50))") + +t = threading.Thread(target=killer) +t.daemon = True +t.start() + +while True: + cur.execute(%(query)r, ("Hello, world!",)) +""" % {'dsn': dsn, 'query': query}) + + proc = sp.Popen([sys.executable, '-c', script_to_py3(script)], + stdout=sp.PIPE, stderr=sp.PIPE) + (out, err) = proc.communicate() + self.assertEqual(1, proc.returncode) + self.assert_(not err, err) + + def test_suite(): return unittest.TestLoader().loadTestsFromName(__name__)