From 12645db754f9f665f8959a6072628ddb094f9abf Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Mon, 3 Dec 2012 03:37:47 +0000 Subject: [PATCH] Make sure to call subclasses methods on context exit --- psycopg/connection_type.c | 8 ++++-- psycopg/cursor_type.c | 4 ++- tests/test_with.py | 55 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 64 insertions(+), 3 deletions(-) diff --git a/psycopg/connection_type.c b/psycopg/connection_type.c index c1d6176e..b5fd0789 100644 --- a/psycopg/connection_type.c +++ b/psycopg/connection_type.c @@ -417,9 +417,13 @@ psyco_conn_exit(connectionObject *self, PyObject *args) } if (type == Py_None) { - if (!(tmp = psyco_conn_commit(self))) { goto exit; } + if (!(tmp = PyObject_CallMethod((PyObject *)self, "commit", ""))) { + goto exit; + } } else { - if (!(tmp = psyco_conn_rollback(self))) { goto exit; } + if (!(tmp = PyObject_CallMethod((PyObject *)self, "rollback", ""))) { + goto exit; + } } /* success (of the commit or rollback, there may have been an exception in diff --git a/psycopg/cursor_type.c b/psycopg/cursor_type.c index 5e17bff1..9570f914 100644 --- a/psycopg/cursor_type.c +++ b/psycopg/cursor_type.c @@ -1222,7 +1222,9 @@ psyco_curs_exit(cursorObject *self, PyObject *args) /* don't care about the arguments here: don't need to parse them */ - if (!(tmp = psyco_curs_close(self))) { goto exit; } + if (!(tmp = PyObject_CallMethod((PyObject *)self, "close", ""))) { + goto exit; + } /* success (of curs.close()). * Return None to avoid swallowing the exception */ diff --git a/tests/test_with.py b/tests/test_with.py index 51889270..f43e6db4 100755 --- a/tests/test_with.py +++ b/tests/test_with.py @@ -115,6 +115,48 @@ class WithConnectionTestCase(TestMixin, unittest.TestCase): self.conn.close() self.assertRaises(psycopg2.InterfaceError, f) + def test_subclass_commit(self): + commits = [] + class MyConn(ext.connection): + def commit(self): + commits.append(None) + super(MyConn, self).commit() + + with psycopg2.connect(dsn, connection_factory=MyConn) as conn: + curs = conn.cursor() + curs.execute("insert into test_with values (10)") + + self.assertEqual(conn.status, ext.STATUS_READY) + self.assert_(commits) + + curs = self.conn.cursor() + curs.execute("select * from test_with") + self.assertEqual(curs.fetchall(), [(10,)]) + + def test_subclass_rollback(self): + rollbacks = [] + class MyConn(ext.connection): + def rollback(self): + rollbacks.append(None) + super(MyConn, self).rollback() + + try: + with psycopg2.connect(dsn, connection_factory=MyConn) as conn: + curs = conn.cursor() + curs.execute("insert into test_with values (11)") + 1/0 + except ZeroDivisionError: + pass + else: + self.assert_("exception not raised") + + self.assertEqual(conn.status, ext.STATUS_READY) + self.assert_(rollbacks) + + curs = conn.cursor() + curs.execute("select * from test_with") + self.assertEqual(curs.fetchall(), []) + class WithCursorTestCase(TestMixin, unittest.TestCase): def test_with_ok(self): @@ -149,6 +191,19 @@ class WithCursorTestCase(TestMixin, unittest.TestCase): curs.execute("select * from test_with") self.assertEqual(curs.fetchall(), []) + def test_subclass(self): + closes = [] + class MyCurs(ext.cursor): + def close(self): + closes.append(None) + super(MyCurs, self).close() + + with self.conn.cursor(cursor_factory=MyCurs) as curs: + self.assert_(isinstance(curs, MyCurs)) + + self.assert_(curs.closed) + self.assert_(closes) + def test_suite(): return unittest.TestLoader().loadTestsFromName(__name__)