Make sure to call subclasses methods on context exit

This commit is contained in:
Daniele Varrazzo 2012-12-03 03:37:47 +00:00
parent c2f284cd3b
commit 12645db754
3 changed files with 64 additions and 3 deletions

View File

@ -417,9 +417,13 @@ psyco_conn_exit(connectionObject *self, PyObject *args)
}
if (type == Py_None) {
if (!(tmp = psyco_conn_commit(self))) { goto exit; }
if (!(tmp = PyObject_CallMethod((PyObject *)self, "commit", ""))) {
goto exit;
}
} else {
if (!(tmp = psyco_conn_rollback(self))) { goto exit; }
if (!(tmp = PyObject_CallMethod((PyObject *)self, "rollback", ""))) {
goto exit;
}
}
/* success (of the commit or rollback, there may have been an exception in

View File

@ -1222,7 +1222,9 @@ psyco_curs_exit(cursorObject *self, PyObject *args)
/* don't care about the arguments here: don't need to parse them */
if (!(tmp = psyco_curs_close(self))) { goto exit; }
if (!(tmp = PyObject_CallMethod((PyObject *)self, "close", ""))) {
goto exit;
}
/* success (of curs.close()).
* Return None to avoid swallowing the exception */

View File

@ -115,6 +115,48 @@ class WithConnectionTestCase(TestMixin, unittest.TestCase):
self.conn.close()
self.assertRaises(psycopg2.InterfaceError, f)
def test_subclass_commit(self):
commits = []
class MyConn(ext.connection):
def commit(self):
commits.append(None)
super(MyConn, self).commit()
with psycopg2.connect(dsn, connection_factory=MyConn) as conn:
curs = conn.cursor()
curs.execute("insert into test_with values (10)")
self.assertEqual(conn.status, ext.STATUS_READY)
self.assert_(commits)
curs = self.conn.cursor()
curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), [(10,)])
def test_subclass_rollback(self):
rollbacks = []
class MyConn(ext.connection):
def rollback(self):
rollbacks.append(None)
super(MyConn, self).rollback()
try:
with psycopg2.connect(dsn, connection_factory=MyConn) as conn:
curs = conn.cursor()
curs.execute("insert into test_with values (11)")
1/0
except ZeroDivisionError:
pass
else:
self.assert_("exception not raised")
self.assertEqual(conn.status, ext.STATUS_READY)
self.assert_(rollbacks)
curs = conn.cursor()
curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), [])
class WithCursorTestCase(TestMixin, unittest.TestCase):
def test_with_ok(self):
@ -149,6 +191,19 @@ class WithCursorTestCase(TestMixin, unittest.TestCase):
curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), [])
def test_subclass(self):
closes = []
class MyCurs(ext.cursor):
def close(self):
closes.append(None)
super(MyCurs, self).close()
with self.conn.cursor(cursor_factory=MyCurs) as curs:
self.assert_(isinstance(curs, MyCurs))
self.assert_(curs.closed)
self.assert_(closes)
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)