mirror of
https://github.com/psycopg/psycopg2.git
synced 2024-11-22 17:06:33 +03:00
Make sure to call subclasses methods on context exit
This commit is contained in:
parent
c2f284cd3b
commit
12645db754
|
@ -417,9 +417,13 @@ psyco_conn_exit(connectionObject *self, PyObject *args)
|
||||||
}
|
}
|
||||||
|
|
||||||
if (type == Py_None) {
|
if (type == Py_None) {
|
||||||
if (!(tmp = psyco_conn_commit(self))) { goto exit; }
|
if (!(tmp = PyObject_CallMethod((PyObject *)self, "commit", ""))) {
|
||||||
|
goto exit;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
if (!(tmp = psyco_conn_rollback(self))) { goto exit; }
|
if (!(tmp = PyObject_CallMethod((PyObject *)self, "rollback", ""))) {
|
||||||
|
goto exit;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* success (of the commit or rollback, there may have been an exception in
|
/* success (of the commit or rollback, there may have been an exception in
|
||||||
|
|
|
@ -1222,7 +1222,9 @@ psyco_curs_exit(cursorObject *self, PyObject *args)
|
||||||
|
|
||||||
/* don't care about the arguments here: don't need to parse them */
|
/* don't care about the arguments here: don't need to parse them */
|
||||||
|
|
||||||
if (!(tmp = psyco_curs_close(self))) { goto exit; }
|
if (!(tmp = PyObject_CallMethod((PyObject *)self, "close", ""))) {
|
||||||
|
goto exit;
|
||||||
|
}
|
||||||
|
|
||||||
/* success (of curs.close()).
|
/* success (of curs.close()).
|
||||||
* Return None to avoid swallowing the exception */
|
* Return None to avoid swallowing the exception */
|
||||||
|
|
|
@ -115,6 +115,48 @@ class WithConnectionTestCase(TestMixin, unittest.TestCase):
|
||||||
self.conn.close()
|
self.conn.close()
|
||||||
self.assertRaises(psycopg2.InterfaceError, f)
|
self.assertRaises(psycopg2.InterfaceError, f)
|
||||||
|
|
||||||
|
def test_subclass_commit(self):
|
||||||
|
commits = []
|
||||||
|
class MyConn(ext.connection):
|
||||||
|
def commit(self):
|
||||||
|
commits.append(None)
|
||||||
|
super(MyConn, self).commit()
|
||||||
|
|
||||||
|
with psycopg2.connect(dsn, connection_factory=MyConn) as conn:
|
||||||
|
curs = conn.cursor()
|
||||||
|
curs.execute("insert into test_with values (10)")
|
||||||
|
|
||||||
|
self.assertEqual(conn.status, ext.STATUS_READY)
|
||||||
|
self.assert_(commits)
|
||||||
|
|
||||||
|
curs = self.conn.cursor()
|
||||||
|
curs.execute("select * from test_with")
|
||||||
|
self.assertEqual(curs.fetchall(), [(10,)])
|
||||||
|
|
||||||
|
def test_subclass_rollback(self):
|
||||||
|
rollbacks = []
|
||||||
|
class MyConn(ext.connection):
|
||||||
|
def rollback(self):
|
||||||
|
rollbacks.append(None)
|
||||||
|
super(MyConn, self).rollback()
|
||||||
|
|
||||||
|
try:
|
||||||
|
with psycopg2.connect(dsn, connection_factory=MyConn) as conn:
|
||||||
|
curs = conn.cursor()
|
||||||
|
curs.execute("insert into test_with values (11)")
|
||||||
|
1/0
|
||||||
|
except ZeroDivisionError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
self.assert_("exception not raised")
|
||||||
|
|
||||||
|
self.assertEqual(conn.status, ext.STATUS_READY)
|
||||||
|
self.assert_(rollbacks)
|
||||||
|
|
||||||
|
curs = conn.cursor()
|
||||||
|
curs.execute("select * from test_with")
|
||||||
|
self.assertEqual(curs.fetchall(), [])
|
||||||
|
|
||||||
|
|
||||||
class WithCursorTestCase(TestMixin, unittest.TestCase):
|
class WithCursorTestCase(TestMixin, unittest.TestCase):
|
||||||
def test_with_ok(self):
|
def test_with_ok(self):
|
||||||
|
@ -149,6 +191,19 @@ class WithCursorTestCase(TestMixin, unittest.TestCase):
|
||||||
curs.execute("select * from test_with")
|
curs.execute("select * from test_with")
|
||||||
self.assertEqual(curs.fetchall(), [])
|
self.assertEqual(curs.fetchall(), [])
|
||||||
|
|
||||||
|
def test_subclass(self):
|
||||||
|
closes = []
|
||||||
|
class MyCurs(ext.cursor):
|
||||||
|
def close(self):
|
||||||
|
closes.append(None)
|
||||||
|
super(MyCurs, self).close()
|
||||||
|
|
||||||
|
with self.conn.cursor(cursor_factory=MyCurs) as curs:
|
||||||
|
self.assert_(isinstance(curs, MyCurs))
|
||||||
|
|
||||||
|
self.assert_(curs.closed)
|
||||||
|
self.assert_(closes)
|
||||||
|
|
||||||
|
|
||||||
def test_suite():
|
def test_suite():
|
||||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user