mirror of
https://github.com/psycopg/psycopg2.git
synced 2024-11-22 08:56:34 +03:00
Make sure to call subclasses methods on context exit
This commit is contained in:
parent
c2f284cd3b
commit
12645db754
|
@ -417,9 +417,13 @@ psyco_conn_exit(connectionObject *self, PyObject *args)
|
|||
}
|
||||
|
||||
if (type == Py_None) {
|
||||
if (!(tmp = psyco_conn_commit(self))) { goto exit; }
|
||||
if (!(tmp = PyObject_CallMethod((PyObject *)self, "commit", ""))) {
|
||||
goto exit;
|
||||
}
|
||||
} else {
|
||||
if (!(tmp = psyco_conn_rollback(self))) { goto exit; }
|
||||
if (!(tmp = PyObject_CallMethod((PyObject *)self, "rollback", ""))) {
|
||||
goto exit;
|
||||
}
|
||||
}
|
||||
|
||||
/* success (of the commit or rollback, there may have been an exception in
|
||||
|
|
|
@ -1222,7 +1222,9 @@ psyco_curs_exit(cursorObject *self, PyObject *args)
|
|||
|
||||
/* don't care about the arguments here: don't need to parse them */
|
||||
|
||||
if (!(tmp = psyco_curs_close(self))) { goto exit; }
|
||||
if (!(tmp = PyObject_CallMethod((PyObject *)self, "close", ""))) {
|
||||
goto exit;
|
||||
}
|
||||
|
||||
/* success (of curs.close()).
|
||||
* Return None to avoid swallowing the exception */
|
||||
|
|
|
@ -115,6 +115,48 @@ class WithConnectionTestCase(TestMixin, unittest.TestCase):
|
|||
self.conn.close()
|
||||
self.assertRaises(psycopg2.InterfaceError, f)
|
||||
|
||||
def test_subclass_commit(self):
|
||||
commits = []
|
||||
class MyConn(ext.connection):
|
||||
def commit(self):
|
||||
commits.append(None)
|
||||
super(MyConn, self).commit()
|
||||
|
||||
with psycopg2.connect(dsn, connection_factory=MyConn) as conn:
|
||||
curs = conn.cursor()
|
||||
curs.execute("insert into test_with values (10)")
|
||||
|
||||
self.assertEqual(conn.status, ext.STATUS_READY)
|
||||
self.assert_(commits)
|
||||
|
||||
curs = self.conn.cursor()
|
||||
curs.execute("select * from test_with")
|
||||
self.assertEqual(curs.fetchall(), [(10,)])
|
||||
|
||||
def test_subclass_rollback(self):
|
||||
rollbacks = []
|
||||
class MyConn(ext.connection):
|
||||
def rollback(self):
|
||||
rollbacks.append(None)
|
||||
super(MyConn, self).rollback()
|
||||
|
||||
try:
|
||||
with psycopg2.connect(dsn, connection_factory=MyConn) as conn:
|
||||
curs = conn.cursor()
|
||||
curs.execute("insert into test_with values (11)")
|
||||
1/0
|
||||
except ZeroDivisionError:
|
||||
pass
|
||||
else:
|
||||
self.assert_("exception not raised")
|
||||
|
||||
self.assertEqual(conn.status, ext.STATUS_READY)
|
||||
self.assert_(rollbacks)
|
||||
|
||||
curs = conn.cursor()
|
||||
curs.execute("select * from test_with")
|
||||
self.assertEqual(curs.fetchall(), [])
|
||||
|
||||
|
||||
class WithCursorTestCase(TestMixin, unittest.TestCase):
|
||||
def test_with_ok(self):
|
||||
|
@ -149,6 +191,19 @@ class WithCursorTestCase(TestMixin, unittest.TestCase):
|
|||
curs.execute("select * from test_with")
|
||||
self.assertEqual(curs.fetchall(), [])
|
||||
|
||||
def test_subclass(self):
|
||||
closes = []
|
||||
class MyCurs(ext.cursor):
|
||||
def close(self):
|
||||
closes.append(None)
|
||||
super(MyCurs, self).close()
|
||||
|
||||
with self.conn.cursor(cursor_factory=MyCurs) as curs:
|
||||
self.assert_(isinstance(curs, MyCurs))
|
||||
|
||||
self.assert_(curs.closed)
|
||||
self.assert_(closes)
|
||||
|
||||
|
||||
def test_suite():
|
||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||
|
|
Loading…
Reference in New Issue
Block a user