diff --git a/ChangeLog b/ChangeLog index f2775eb4..6f138b04 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,5 +1,12 @@ 2008-01-19 James Henstridge + * tests/test_connection.py (ConnectionTests): add simple tests for + the Connection and Cursor "closed" attributes. + + * psycopg/cursor_type.c (psyco_curs_get_closed): add a "closed" + attribute to cursors. It will be True if either the cursor or its + associated connection are closed. This fixes bug #164. + * psycopg/pqpath.c (pq_raise): remove unused arguments to function, and simplify. (pq_resolve_critical): make function static, since it isn't being diff --git a/psycopg/cursor_type.c b/psycopg/cursor_type.c index 225b299e..083826f1 100644 --- a/psycopg/cursor_type.c +++ b/psycopg/cursor_type.c @@ -1431,6 +1431,22 @@ psyco_curs_isready(cursorObject *self, PyObject *args) } } +/* extension: closed - return true if cursor is closed*/ + +#define psyco_curs_closed_doc \ +"True if cursor is closed, False if cursor is open" + +static PyObject * +psyco_curs_get_closed(cursorObject *self, void *closure) +{ + PyObject *closed; + + closed = (self->closed || (self->conn && self->conn->closed)) ? + Py_True : Py_False; + Py_INCREF(closed); + return closed; +} + #endif @@ -1542,6 +1558,15 @@ static struct PyMemberDef cursorObject_members[] = { {NULL} }; +/* object calculated member list */ +static struct PyGetSetDef cursorObject_getsets[] = { +#ifdef PSYCOPG_EXTENSIONS + { "closed", (getter)psyco_curs_get_closed, NULL, + psyco_curs_closed_doc, NULL }, +#endif + {NULL} +}; + /* initialization and finalization methods */ static int @@ -1703,7 +1728,7 @@ PyTypeObject cursorType = { cursorObject_methods, /*tp_methods*/ cursorObject_members, /*tp_members*/ - 0, /*tp_getset*/ + cursorObject_getsets, /*tp_getset*/ 0, /*tp_base*/ 0, /*tp_dict*/ diff --git a/tests/__init__.py b/tests/__init__.py index 1b8adbe7..89dbfdaa 100755 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -9,6 +9,7 @@ import extras_dictcursor import test_dates import test_psycopg2_dbapi20 import test_quote +import test_connection import test_transaction import types_basic @@ -19,6 +20,7 @@ def test_suite(): suite.addTest(test_dates.test_suite()) suite.addTest(test_psycopg2_dbapi20.test_suite()) suite.addTest(test_quote.test_suite()) + suite.addTest(test_connection.test_suite()) suite.addTest(test_transaction.test_suite()) suite.addTest(types_basic.test_suite()) return suite diff --git a/tests/test_connection.py b/tests/test_connection.py new file mode 100644 index 00000000..0fe73a91 --- /dev/null +++ b/tests/test_connection.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python +import unittest + +import psycopg2 +import tests + + +class ConnectionTests(unittest.TestCase): + + def connect(self): + return psycopg2.connect("dbname=%s" % tests.dbname) + + def test_closed_attribute(self): + conn = self.connect() + self.assertEqual(conn.closed, False) + conn.close() + self.assertEqual(conn.closed, True) + + def test_cursor_closed_attribute(self): + conn = self.connect() + curs = conn.cursor() + self.assertEqual(curs.closed, False) + curs.close() + self.assertEqual(curs.closed, True) + + # Closing the connection closes the cursor: + curs = conn.cursor() + conn.close() + self.assertEqual(curs.closed, True) + + +def test_suite(): + return unittest.TestLoader().loadTestsFromName(__name__) + diff --git a/tests/test_transaction.py b/tests/test_transaction.py index bd96ce49..5145ace0 100755 --- a/tests/test_transaction.py +++ b/tests/test_transaction.py @@ -3,13 +3,12 @@ import threading import unittest import psycopg2 -import psycopg2 -import tests - from psycopg2.extensions import ( ISOLATION_LEVEL_SERIALIZABLE, STATUS_BEGIN, STATUS_READY) +import tests -class TransactionTestCase(unittest.TestCase): + +class TransactionTests(unittest.TestCase): def setUp(self): self.conn = psycopg2.connect("dbname=%s" % tests.dbname) @@ -72,7 +71,7 @@ class TransactionTestCase(unittest.TestCase): self.assertEqual(curs.fetchone()[0], 1) -class DeadlockSerializationTestCase(unittest.TestCase): +class DeadlockSerializationTests(unittest.TestCase): """Test deadlock and serialization failure errors.""" def connect(self): @@ -223,6 +222,3 @@ class QueryCancelationTests(unittest.TestCase): def test_suite(): return unittest.TestLoader().loadTestsFromName(__name__) -if __name__ == "__main__": - unittest.main() -