From cc605032f5c0c00bee21ab34e5e13b9d866a795e Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Mon, 3 Dec 2012 02:50:24 +0000 Subject: [PATCH] Added support for with statement for connection and cursor The implementation should be conform to the DBAPI, although the "with" extension has not been released yet. --- psycopg/connection_type.c | 48 ++++++++++++ psycopg/cursor_type.c | 38 +++++++++ tests/__init__.py | 7 ++ tests/test_with.py | 157 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 250 insertions(+) create mode 100755 tests/test_with.py diff --git a/psycopg/connection_type.c b/psycopg/connection_type.c index 69ab7c83..c1d6176e 100644 --- a/psycopg/connection_type.c +++ b/psycopg/connection_type.c @@ -389,6 +389,50 @@ psyco_conn_tpc_recover(connectionObject *self) } +#define psyco_conn_enter_doc \ +"__enter__ -> self" + +static PyObject * +psyco_conn_enter(connectionObject *self) +{ + EXC_IF_CONN_CLOSED(self); + + Py_INCREF(self); + return (PyObject *)self; +} + + +#define psyco_conn_exit_doc \ +"__exit__ -- commit if no exception, else roll back" + +static PyObject * +psyco_conn_exit(connectionObject *self, PyObject *args) +{ + PyObject *type, *name, *tb; + PyObject *tmp = NULL; + PyObject *rv = NULL; + + if (!PyArg_ParseTuple(args, "OOO", &type, &name, &tb)) { + goto exit; + } + + if (type == Py_None) { + if (!(tmp = psyco_conn_commit(self))) { goto exit; } + } else { + if (!(tmp = psyco_conn_rollback(self))) { goto exit; } + } + + /* success (of the commit or rollback, there may have been an exception in + * the block). Return None to avoid swallowing the exception */ + rv = Py_None; + Py_INCREF(rv); + +exit: + Py_XDECREF(tmp); + return rv; +} + + #ifdef PSYCOPG_EXTENSIONS @@ -924,6 +968,10 @@ static struct PyMethodDef connectionObject_methods[] = { METH_VARARGS, psyco_conn_tpc_rollback_doc}, {"tpc_recover", (PyCFunction)psyco_conn_tpc_recover, METH_NOARGS, psyco_conn_tpc_recover_doc}, + {"__enter__", (PyCFunction)psyco_conn_enter, + METH_NOARGS, psyco_conn_enter_doc}, + {"__exit__", (PyCFunction)psyco_conn_exit, + METH_VARARGS, psyco_conn_exit_doc}, #ifdef PSYCOPG_EXTENSIONS {"set_session", (PyCFunction)psyco_conn_set_session, METH_VARARGS|METH_KEYWORDS, psyco_conn_set_session_doc}, diff --git a/psycopg/cursor_type.c b/psycopg/cursor_type.c index b2c55aa3..5e17bff1 100644 --- a/psycopg/cursor_type.c +++ b/psycopg/cursor_type.c @@ -1201,6 +1201,40 @@ psyco_curs_scroll(cursorObject *self, PyObject *args, PyObject *kwargs) } +#define psyco_curs_enter_doc \ +"__enter__ -> self" + +static PyObject * +psyco_curs_enter(cursorObject *self) +{ + Py_INCREF(self); + return (PyObject *)self; +} + +#define psyco_curs_exit_doc \ +"__exit__ -- close the cursor" + +static PyObject * +psyco_curs_exit(cursorObject *self, PyObject *args) +{ + PyObject *tmp = NULL; + PyObject *rv = NULL; + + /* don't care about the arguments here: don't need to parse them */ + + if (!(tmp = psyco_curs_close(self))) { goto exit; } + + /* success (of curs.close()). + * Return None to avoid swallowing the exception */ + rv = Py_None; + Py_INCREF(rv); + +exit: + Py_XDECREF(tmp); + return rv; +} + + #ifdef PSYCOPG_EXTENSIONS /* Return a newly allocated buffer containing the list of columns to be @@ -1716,6 +1750,10 @@ static struct PyMethodDef cursorObject_methods[] = { /* DBAPI-2.0 extensions */ {"scroll", (PyCFunction)psyco_curs_scroll, METH_VARARGS|METH_KEYWORDS, psyco_curs_scroll_doc}, + {"__enter__", (PyCFunction)psyco_curs_enter, + METH_NOARGS, psyco_curs_enter_doc}, + {"__exit__", (PyCFunction)psyco_curs_exit, + METH_VARARGS, psyco_curs_exit_doc}, /* psycopg extensions */ #ifdef PSYCOPG_EXTENSIONS {"cast", (PyCFunction)psyco_curs_cast, diff --git a/tests/__init__.py b/tests/__init__.py index df8e8cd1..3e677d85 100755 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -45,6 +45,11 @@ import test_transaction import test_types_basic import test_types_extras +if sys.version_info[:2] >= (2, 5): + import test_with +else: + test_with = None + def test_suite(): # If connection to test db fails, bail out early. import psycopg2 @@ -76,6 +81,8 @@ def test_suite(): suite.addTest(test_transaction.test_suite()) suite.addTest(test_types_basic.test_suite()) suite.addTest(test_types_extras.test_suite()) + if test_with: + suite.addTest(test_with.test_suite()) return suite if __name__ == '__main__': diff --git a/tests/test_with.py b/tests/test_with.py new file mode 100755 index 00000000..51889270 --- /dev/null +++ b/tests/test_with.py @@ -0,0 +1,157 @@ +#!/usr/bin/env python + +# test_ctxman.py - unit test for connection and cursor used as context manager +# +# Copyright (C) 2012 Daniele Varrazzo +# +# psycopg2 is free software: you can redistribute it and/or modify it +# under the terms of the GNU Lesser General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# In addition, as a special exception, the copyright holders give +# permission to link this program with the OpenSSL library (or with +# modified versions of OpenSSL that use the same license as OpenSSL), +# and distribute linked combinations including the two. +# +# You must obey the GNU Lesser General Public License in all respects for +# all of the code used other than OpenSSL. +# +# psycopg2 is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public +# License for more details. + + +from __future__ import with_statement + +import psycopg2 +import psycopg2.extensions as ext + +from testconfig import dsn +from testutils import unittest + +class TestMixin(object): + def setUp(self): + self.conn = conn = psycopg2.connect(dsn) + curs = conn.cursor() + try: + curs.execute("delete from test_with") + conn.commit() + except psycopg2.ProgrammingError: + # assume table doesn't exist + conn.rollback() + curs.execute("create table test_with (id integer primary key)") + conn.commit() + + def tearDown(self): + self.conn.close() + + +class WithConnectionTestCase(TestMixin, unittest.TestCase): + def test_with_ok(self): + with self.conn as conn: + self.assert_(self.conn is conn) + self.assertEqual(conn.status, ext.STATUS_READY) + curs = conn.cursor() + curs.execute("insert into test_with values (1)") + self.assertEqual(conn.status, ext.STATUS_BEGIN) + + self.assertEqual(self.conn.status, ext.STATUS_READY) + self.assert_(not self.conn.closed) + + curs = self.conn.cursor() + curs.execute("select * from test_with") + self.assertEqual(curs.fetchall(), [(1,)]) + + def test_with_connect_idiom(self): + with psycopg2.connect(dsn) as conn: + self.assertEqual(conn.status, ext.STATUS_READY) + curs = conn.cursor() + curs.execute("insert into test_with values (2)") + self.assertEqual(conn.status, ext.STATUS_BEGIN) + + self.assertEqual(self.conn.status, ext.STATUS_READY) + self.assert_(not self.conn.closed) + + curs = self.conn.cursor() + curs.execute("select * from test_with") + self.assertEqual(curs.fetchall(), [(2,)]) + + def test_with_error_db(self): + def f(): + with self.conn as conn: + curs = conn.cursor() + curs.execute("insert into test_with values ('a')") + + self.assertRaises(psycopg2.DataError, f) + self.assertEqual(self.conn.status, ext.STATUS_READY) + self.assert_(not self.conn.closed) + + curs = self.conn.cursor() + curs.execute("select * from test_with") + self.assertEqual(curs.fetchall(), []) + + def test_with_error_python(self): + def f(): + with self.conn as conn: + curs = conn.cursor() + curs.execute("insert into test_with values (3)") + 1/0 + + self.assertRaises(ZeroDivisionError, f) + self.assertEqual(self.conn.status, ext.STATUS_READY) + self.assert_(not self.conn.closed) + + curs = self.conn.cursor() + curs.execute("select * from test_with") + self.assertEqual(curs.fetchall(), []) + + def test_with_closed(self): + def f(): + with self.conn: + pass + + self.conn.close() + self.assertRaises(psycopg2.InterfaceError, f) + + +class WithCursorTestCase(TestMixin, unittest.TestCase): + def test_with_ok(self): + with self.conn as conn: + with conn.cursor() as curs: + curs.execute("insert into test_with values (4)") + self.assert_(not curs.closed) + self.assertEqual(self.conn.status, ext.STATUS_BEGIN) + self.assert_(curs.closed) + + self.assertEqual(self.conn.status, ext.STATUS_READY) + self.assert_(not self.conn.closed) + + curs = self.conn.cursor() + curs.execute("select * from test_with") + self.assertEqual(curs.fetchall(), [(4,)]) + + def test_with_error(self): + try: + with self.conn as conn: + with conn.cursor() as curs: + curs.execute("insert into test_with values (5)") + 1/0 + except ZeroDivisionError: + pass + + self.assertEqual(self.conn.status, ext.STATUS_READY) + self.assert_(not self.conn.closed) + self.assert_(curs.closed) + + curs = self.conn.cursor() + curs.execute("select * from test_with") + self.assertEqual(curs.fetchall(), []) + + +def test_suite(): + return unittest.TestLoader().loadTestsFromName(__name__) + +if __name__ == "__main__": + unittest.main()