mirror of
				https://github.com/psycopg/psycopg2.git
				synced 2025-11-04 09:47:30 +03:00 
			
		
		
		
	Make sure to call subclasses methods on context exit
This commit is contained in:
		
							parent
							
								
									c2f284cd3b
								
							
						
					
					
						commit
						12645db754
					
				| 
						 | 
					@ -417,9 +417,13 @@ psyco_conn_exit(connectionObject *self, PyObject *args)
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (type == Py_None) {
 | 
					    if (type == Py_None) {
 | 
				
			||||||
        if (!(tmp = psyco_conn_commit(self))) { goto exit; }
 | 
					        if (!(tmp = PyObject_CallMethod((PyObject *)self, "commit", ""))) {
 | 
				
			||||||
 | 
					            goto exit;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
    } else {
 | 
					    } else {
 | 
				
			||||||
        if (!(tmp = psyco_conn_rollback(self))) { goto exit; }
 | 
					        if (!(tmp = PyObject_CallMethod((PyObject *)self, "rollback", ""))) {
 | 
				
			||||||
 | 
					            goto exit;
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
    }
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /* success (of the commit or rollback, there may have been an exception in
 | 
					    /* success (of the commit or rollback, there may have been an exception in
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1222,7 +1222,9 @@ psyco_curs_exit(cursorObject *self, PyObject *args)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /* don't care about the arguments here: don't need to parse them */
 | 
					    /* don't care about the arguments here: don't need to parse them */
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    if (!(tmp = psyco_curs_close(self))) { goto exit; }
 | 
					    if (!(tmp = PyObject_CallMethod((PyObject *)self, "close", ""))) {
 | 
				
			||||||
 | 
					        goto exit;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    /* success (of curs.close()).
 | 
					    /* success (of curs.close()).
 | 
				
			||||||
     * Return None to avoid swallowing the exception */
 | 
					     * Return None to avoid swallowing the exception */
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -115,6 +115,48 @@ class WithConnectionTestCase(TestMixin, unittest.TestCase):
 | 
				
			||||||
        self.conn.close()
 | 
					        self.conn.close()
 | 
				
			||||||
        self.assertRaises(psycopg2.InterfaceError, f)
 | 
					        self.assertRaises(psycopg2.InterfaceError, f)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_subclass_commit(self):
 | 
				
			||||||
 | 
					        commits = []
 | 
				
			||||||
 | 
					        class MyConn(ext.connection):
 | 
				
			||||||
 | 
					            def commit(self):
 | 
				
			||||||
 | 
					                commits.append(None)
 | 
				
			||||||
 | 
					                super(MyConn, self).commit()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with psycopg2.connect(dsn, connection_factory=MyConn) as conn:
 | 
				
			||||||
 | 
					            curs = conn.cursor()
 | 
				
			||||||
 | 
					            curs.execute("insert into test_with values (10)")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertEqual(conn.status, ext.STATUS_READY)
 | 
				
			||||||
 | 
					        self.assert_(commits)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        curs = self.conn.cursor()
 | 
				
			||||||
 | 
					        curs.execute("select * from test_with")
 | 
				
			||||||
 | 
					        self.assertEqual(curs.fetchall(), [(10,)])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_subclass_rollback(self):
 | 
				
			||||||
 | 
					        rollbacks = []
 | 
				
			||||||
 | 
					        class MyConn(ext.connection):
 | 
				
			||||||
 | 
					            def rollback(self):
 | 
				
			||||||
 | 
					                rollbacks.append(None)
 | 
				
			||||||
 | 
					                super(MyConn, self).rollback()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            with psycopg2.connect(dsn, connection_factory=MyConn) as conn:
 | 
				
			||||||
 | 
					                curs = conn.cursor()
 | 
				
			||||||
 | 
					                curs.execute("insert into test_with values (11)")
 | 
				
			||||||
 | 
					                1/0
 | 
				
			||||||
 | 
					        except ZeroDivisionError:
 | 
				
			||||||
 | 
					            pass
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.assert_("exception not raised")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertEqual(conn.status, ext.STATUS_READY)
 | 
				
			||||||
 | 
					        self.assert_(rollbacks)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        curs = conn.cursor()
 | 
				
			||||||
 | 
					        curs.execute("select * from test_with")
 | 
				
			||||||
 | 
					        self.assertEqual(curs.fetchall(), [])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class WithCursorTestCase(TestMixin, unittest.TestCase):
 | 
					class WithCursorTestCase(TestMixin, unittest.TestCase):
 | 
				
			||||||
    def test_with_ok(self):
 | 
					    def test_with_ok(self):
 | 
				
			||||||
| 
						 | 
					@ -149,6 +191,19 @@ class WithCursorTestCase(TestMixin, unittest.TestCase):
 | 
				
			||||||
        curs.execute("select * from test_with")
 | 
					        curs.execute("select * from test_with")
 | 
				
			||||||
        self.assertEqual(curs.fetchall(), [])
 | 
					        self.assertEqual(curs.fetchall(), [])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_subclass(self):
 | 
				
			||||||
 | 
					        closes = []
 | 
				
			||||||
 | 
					        class MyCurs(ext.cursor):
 | 
				
			||||||
 | 
					            def close(self):
 | 
				
			||||||
 | 
					                closes.append(None)
 | 
				
			||||||
 | 
					                super(MyCurs, self).close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        with self.conn.cursor(cursor_factory=MyCurs) as curs:
 | 
				
			||||||
 | 
					            self.assert_(isinstance(curs, MyCurs))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assert_(curs.closed)
 | 
				
			||||||
 | 
					        self.assert_(closes)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_suite():
 | 
					def test_suite():
 | 
				
			||||||
    return unittest.TestLoader().loadTestsFromName(__name__)
 | 
					    return unittest.TestLoader().loadTestsFromName(__name__)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user