From 47f5e97759879543edd8ee8ad9032ef67ec0567a Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Wed, 5 Apr 2017 14:40:12 +0100 Subject: [PATCH] Added test to verify #410 The 'unknown error' happens on query. --- tests/test_green.py | 68 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/tests/test_green.py b/tests/test_green.py index 6d1571d4..b8fe4e13 100755 --- a/tests/test_green.py +++ b/tests/test_green.py @@ -112,6 +112,74 @@ class GreenTestCase(ConnectingTestCase): self.assertEqual(curs.fetchone()[0], 1) +class CallbackErrorTestCase(ConnectingTestCase): + def setUp(self): + self._cb = psycopg2.extensions.get_wait_callback() + psycopg2.extensions.set_wait_callback(self.crappy_callback) + ConnectingTestCase.setUp(self) + self.to_error = None + + def tearDown(self): + ConnectingTestCase.tearDown(self) + psycopg2.extensions.set_wait_callback(self._cb) + + def crappy_callback(self, conn): + """green callback failing after `self.to_error` time it is called""" + import select + from psycopg2.extensions import POLL_OK, POLL_READ, POLL_WRITE + + while 1: + if self.to_error is not None: + self.to_error -= 1 + if self.to_error <= 0: + raise ZeroDivisionError("I accidentally the connection") + try: + state = conn.poll() + if state == POLL_OK: + break + elif state == POLL_READ: + select.select([conn.fileno()], [], []) + elif state == POLL_WRITE: + select.select([], [conn.fileno()], []) + else: + raise conn.OperationalError("bad state from poll: %s" % state) + except KeyboardInterrupt: + conn.cancel() + # the loop will be broken by a server error + continue + + def test_errors_on_connection(self): + # Test error propagation in the different stages of the connection + for i in range(100): + self.to_error = i + try: + self.connect() + except ZeroDivisionError: + pass + else: + # We managed to connect + return + + self.fail("you should have had a success or an error by now") + + def test_errors_on_query(self): + for i in range(100): + self.to_error = None + cnn = self.connect() + cur = cnn.cursor() + self.to_error = i + try: + cur.execute("select 1") + cur.fetchone() + except ZeroDivisionError: + pass + else: + # The query completed + return + + self.fail("you should have had a success or an error by now") + + def test_suite(): return unittest.TestLoader().loadTestsFromName(__name__)