Added support for with statement for connection and cursor

The implementation should be conform to the DBAPI, although the "with"
extension has not been released yet.
This commit is contained in:
Daniele Varrazzo 2012-12-03 02:50:24 +00:00
parent 9f06df1820
commit cc605032f5
4 changed files with 250 additions and 0 deletions

View File

@ -389,6 +389,50 @@ psyco_conn_tpc_recover(connectionObject *self)
}
#define psyco_conn_enter_doc \
"__enter__ -> self"
static PyObject *
psyco_conn_enter(connectionObject *self)
{
EXC_IF_CONN_CLOSED(self);
Py_INCREF(self);
return (PyObject *)self;
}
#define psyco_conn_exit_doc \
"__exit__ -- commit if no exception, else roll back"
static PyObject *
psyco_conn_exit(connectionObject *self, PyObject *args)
{
PyObject *type, *name, *tb;
PyObject *tmp = NULL;
PyObject *rv = NULL;
if (!PyArg_ParseTuple(args, "OOO", &type, &name, &tb)) {
goto exit;
}
if (type == Py_None) {
if (!(tmp = psyco_conn_commit(self))) { goto exit; }
} else {
if (!(tmp = psyco_conn_rollback(self))) { goto exit; }
}
/* success (of the commit or rollback, there may have been an exception in
* the block). Return None to avoid swallowing the exception */
rv = Py_None;
Py_INCREF(rv);
exit:
Py_XDECREF(tmp);
return rv;
}
#ifdef PSYCOPG_EXTENSIONS
@ -924,6 +968,10 @@ static struct PyMethodDef connectionObject_methods[] = {
METH_VARARGS, psyco_conn_tpc_rollback_doc},
{"tpc_recover", (PyCFunction)psyco_conn_tpc_recover,
METH_NOARGS, psyco_conn_tpc_recover_doc},
{"__enter__", (PyCFunction)psyco_conn_enter,
METH_NOARGS, psyco_conn_enter_doc},
{"__exit__", (PyCFunction)psyco_conn_exit,
METH_VARARGS, psyco_conn_exit_doc},
#ifdef PSYCOPG_EXTENSIONS
{"set_session", (PyCFunction)psyco_conn_set_session,
METH_VARARGS|METH_KEYWORDS, psyco_conn_set_session_doc},

View File

@ -1201,6 +1201,40 @@ psyco_curs_scroll(cursorObject *self, PyObject *args, PyObject *kwargs)
}
#define psyco_curs_enter_doc \
"__enter__ -> self"
static PyObject *
psyco_curs_enter(cursorObject *self)
{
Py_INCREF(self);
return (PyObject *)self;
}
#define psyco_curs_exit_doc \
"__exit__ -- close the cursor"
static PyObject *
psyco_curs_exit(cursorObject *self, PyObject *args)
{
PyObject *tmp = NULL;
PyObject *rv = NULL;
/* don't care about the arguments here: don't need to parse them */
if (!(tmp = psyco_curs_close(self))) { goto exit; }
/* success (of curs.close()).
* Return None to avoid swallowing the exception */
rv = Py_None;
Py_INCREF(rv);
exit:
Py_XDECREF(tmp);
return rv;
}
#ifdef PSYCOPG_EXTENSIONS
/* Return a newly allocated buffer containing the list of columns to be
@ -1716,6 +1750,10 @@ static struct PyMethodDef cursorObject_methods[] = {
/* DBAPI-2.0 extensions */
{"scroll", (PyCFunction)psyco_curs_scroll,
METH_VARARGS|METH_KEYWORDS, psyco_curs_scroll_doc},
{"__enter__", (PyCFunction)psyco_curs_enter,
METH_NOARGS, psyco_curs_enter_doc},
{"__exit__", (PyCFunction)psyco_curs_exit,
METH_VARARGS, psyco_curs_exit_doc},
/* psycopg extensions */
#ifdef PSYCOPG_EXTENSIONS
{"cast", (PyCFunction)psyco_curs_cast,

View File

@ -45,6 +45,11 @@ import test_transaction
import test_types_basic
import test_types_extras
if sys.version_info[:2] >= (2, 5):
import test_with
else:
test_with = None
def test_suite():
# If connection to test db fails, bail out early.
import psycopg2
@ -76,6 +81,8 @@ def test_suite():
suite.addTest(test_transaction.test_suite())
suite.addTest(test_types_basic.test_suite())
suite.addTest(test_types_extras.test_suite())
if test_with:
suite.addTest(test_with.test_suite())
return suite
if __name__ == '__main__':

157
tests/test_with.py Executable file
View File

@ -0,0 +1,157 @@
#!/usr/bin/env python
# test_ctxman.py - unit test for connection and cursor used as context manager
#
# Copyright (C) 2012 Daniele Varrazzo <daniele.varrazzo@gmail.com>
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
from __future__ import with_statement
import psycopg2
import psycopg2.extensions as ext
from testconfig import dsn
from testutils import unittest
class TestMixin(object):
def setUp(self):
self.conn = conn = psycopg2.connect(dsn)
curs = conn.cursor()
try:
curs.execute("delete from test_with")
conn.commit()
except psycopg2.ProgrammingError:
# assume table doesn't exist
conn.rollback()
curs.execute("create table test_with (id integer primary key)")
conn.commit()
def tearDown(self):
self.conn.close()
class WithConnectionTestCase(TestMixin, unittest.TestCase):
def test_with_ok(self):
with self.conn as conn:
self.assert_(self.conn is conn)
self.assertEqual(conn.status, ext.STATUS_READY)
curs = conn.cursor()
curs.execute("insert into test_with values (1)")
self.assertEqual(conn.status, ext.STATUS_BEGIN)
self.assertEqual(self.conn.status, ext.STATUS_READY)
self.assert_(not self.conn.closed)
curs = self.conn.cursor()
curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), [(1,)])
def test_with_connect_idiom(self):
with psycopg2.connect(dsn) as conn:
self.assertEqual(conn.status, ext.STATUS_READY)
curs = conn.cursor()
curs.execute("insert into test_with values (2)")
self.assertEqual(conn.status, ext.STATUS_BEGIN)
self.assertEqual(self.conn.status, ext.STATUS_READY)
self.assert_(not self.conn.closed)
curs = self.conn.cursor()
curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), [(2,)])
def test_with_error_db(self):
def f():
with self.conn as conn:
curs = conn.cursor()
curs.execute("insert into test_with values ('a')")
self.assertRaises(psycopg2.DataError, f)
self.assertEqual(self.conn.status, ext.STATUS_READY)
self.assert_(not self.conn.closed)
curs = self.conn.cursor()
curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), [])
def test_with_error_python(self):
def f():
with self.conn as conn:
curs = conn.cursor()
curs.execute("insert into test_with values (3)")
1/0
self.assertRaises(ZeroDivisionError, f)
self.assertEqual(self.conn.status, ext.STATUS_READY)
self.assert_(not self.conn.closed)
curs = self.conn.cursor()
curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), [])
def test_with_closed(self):
def f():
with self.conn:
pass
self.conn.close()
self.assertRaises(psycopg2.InterfaceError, f)
class WithCursorTestCase(TestMixin, unittest.TestCase):
def test_with_ok(self):
with self.conn as conn:
with conn.cursor() as curs:
curs.execute("insert into test_with values (4)")
self.assert_(not curs.closed)
self.assertEqual(self.conn.status, ext.STATUS_BEGIN)
self.assert_(curs.closed)
self.assertEqual(self.conn.status, ext.STATUS_READY)
self.assert_(not self.conn.closed)
curs = self.conn.cursor()
curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), [(4,)])
def test_with_error(self):
try:
with self.conn as conn:
with conn.cursor() as curs:
curs.execute("insert into test_with values (5)")
1/0
except ZeroDivisionError:
pass
self.assertEqual(self.conn.status, ext.STATUS_READY)
self.assert_(not self.conn.closed)
self.assert_(curs.closed)
curs = self.conn.cursor()
curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), [])
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()