diff --git a/psycopg/connection_type.c b/psycopg/connection_type.c index 25299fab..5caf644d 100644 --- a/psycopg/connection_type.c +++ b/psycopg/connection_type.c @@ -59,7 +59,7 @@ static PyObject * psyco_conn_cursor(connectionObject *self, PyObject *args, PyObject *kwargs) { PyObject *obj = NULL; - PyObject *rv = NULL; + cursorObject *curs = NULL; PyObject *name = Py_None; PyObject *factory = Py_None; PyObject *withhold = Py_False; @@ -110,14 +110,17 @@ psyco_conn_cursor(connectionObject *self, PyObject *args, PyObject *kwargs) if (PyObject_IsInstance(obj, (PyObject *)&cursorType) == 0) { PyErr_SetString(PyExc_TypeError, "cursor factory must be subclass of psycopg2.extensions.cursor"); + Py_DECREF(obj); + obj = NULL; goto exit; } - if (0 > curs_withhold_set((cursorObject *)obj, withhold)) { - goto exit; + curs = (cursorObject *)obj; + if (0 > curs_withhold_set(curs, withhold)) { + goto error; } - if (0 > curs_scrollable_set((cursorObject *)obj, scrollable)) { - goto exit; + if (0 > curs_scrollable_set(curs, scrollable)) { + goto error; } Dprintf("psyco_conn_cursor: new cursor at %p: refcnt = " @@ -125,12 +128,26 @@ psyco_conn_cursor(connectionObject *self, PyObject *args, PyObject *kwargs) obj, Py_REFCNT(obj) ); - rv = obj; obj = NULL; + goto exit; + +error: + { + PyObject *error_type, *error_value, *error_traceback; + PyObject *close; + curs = NULL; + PyErr_Fetch(&error_type, &error_value, &error_traceback); + close = PyObject_CallMethod(obj, "close", NULL); + if (close) + Py_DECREF(close); + else + PyErr_WriteUnraisable(obj); + PyErr_Restore(error_type, error_value, error_traceback); + } exit: Py_XDECREF(obj); - return rv; + return (PyObject *)curs; } @@ -1366,6 +1383,9 @@ connection_dealloc(PyObject* obj) { connectionObject *self = (connectionObject *)obj; + if (PyObject_CallFinalizerFromDealloc(obj) < 0) + return; + /* Make sure to untrack the connection before calling conn_close, which may * allow a different thread to try and dealloc the connection again, * resulting in a double-free segfault (ticket #166). */ @@ -1405,6 +1425,31 @@ connection_dealloc(PyObject* obj) Py_TYPE(obj)->tp_free(obj); } +#if PY_3 +static void +connection_finalize(PyObject *obj) +{ + connectionObject *self = (connectionObject *)obj; + +#ifdef CONN_CHECK_PID + if (self->procpid == getpid()) +#endif + { + if (!self->closed) { + PyObject *error_type, *error_value, *error_traceback; + /* Save the current exception, if any. */ + PyErr_Fetch(&error_type, &error_value, &error_traceback); + + if (PyErr_WarnFormat(PyExc_ResourceWarning, 1, "unclosed connection %R", obj)) + PyErr_WriteUnraisable(obj); + + /* Restore the saved exception. */ + PyErr_Restore(error_type, error_value, error_traceback); + } + } +} +#endif + static int connection_init(PyObject *obj, PyObject *args, PyObject *kwds) { @@ -1479,7 +1524,7 @@ PyTypeObject connectionType = { 0, /*tp_setattro*/ 0, /*tp_as_buffer*/ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC | - Py_TPFLAGS_HAVE_WEAKREFS, + Py_TPFLAGS_HAVE_WEAKREFS | Py_TPFLAGS_HAVE_FINALIZE, /*tp_flags*/ connectionType_doc, /*tp_doc*/ (traverseproc)connection_traverse, /*tp_traverse*/ @@ -1499,4 +1544,16 @@ PyTypeObject connectionType = { connection_init, /*tp_init*/ 0, /*tp_alloc*/ connection_new, /*tp_new*/ + 0, /* tp_free */ + 0, /* tp_is_gc */ + 0, /* tp_bases */ + 0, /* tp_mro */ + 0, /* tp_cache */ + 0, /* tp_subclasses */ + 0, /* tp_weaklist */ +#if PY_3 + 0, /* tp_del */ + 0, /* tp_version_tag */ + connection_finalize, /* tp_finalize */ +#endif }; diff --git a/psycopg/cursor_type.c b/psycopg/cursor_type.c index a7bd11b4..4ccbc135 100644 --- a/psycopg/cursor_type.c +++ b/psycopg/cursor_type.c @@ -39,6 +39,7 @@ #include +#define cursor_closed(self) ((self)->closed || ((self)->conn && (self)->conn->closed)) /** DBAPI methods **/ @@ -1637,7 +1638,7 @@ exit: static PyObject * curs_closed_get(cursorObject *self, void *closure) { - return PyBool_FromLong(self->closed || (self->conn && self->conn->closed)); + return PyBool_FromLong(cursor_closed(self)); } /* extension: withhold - get or set "WITH HOLD" for named cursors */ @@ -1945,6 +1946,9 @@ cursor_dealloc(PyObject* obj) { cursorObject *self = (cursorObject *)obj; + if (PyObject_CallFinalizerFromDealloc(obj) < 0) + return; + PyObject_GC_UnTrack(self); if (self->weakreflist) { @@ -1965,6 +1969,28 @@ cursor_dealloc(PyObject* obj) Py_TYPE(obj)->tp_free(obj); } +#if PY_3 +static void +cursor_finalize(PyObject *obj) +{ + cursorObject *self = (cursorObject *)obj; + + if (!cursor_closed(self)) { + PyObject *error_type, *error_value, *error_traceback; + /* Save the current exception, if any. */ + PyErr_Fetch(&error_type, &error_value, &error_traceback); + + if (PyErr_WarnFormat(PyExc_ResourceWarning, 1, + "unclosed cursor %R for connection %R", + obj, (PyObject *)self->conn)) + PyErr_WriteUnraisable(obj); + + /* Restore the saved exception. */ + PyErr_Restore(error_type, error_value, error_traceback); + } +} +#endif + static int cursor_init(PyObject *obj, PyObject *args, PyObject *kwargs) { @@ -2056,7 +2082,7 @@ PyTypeObject cursorType = { 0, /*tp_setattro*/ 0, /*tp_as_buffer*/ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_ITER | - Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_HAVE_WEAKREFS , + Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_HAVE_WEAKREFS | Py_TPFLAGS_HAVE_FINALIZE, /*tp_flags*/ cursorType_doc, /*tp_doc*/ (traverseproc)cursor_traverse, /*tp_traverse*/ @@ -2076,4 +2102,16 @@ PyTypeObject cursorType = { cursor_init, /*tp_init*/ 0, /*tp_alloc*/ cursor_new, /*tp_new*/ + 0, /* tp_free */ + 0, /* tp_is_gc */ + 0, /* tp_bases */ + 0, /* tp_mro */ + 0, /* tp_cache */ + 0, /* tp_subclasses */ + 0, /* tp_weaklist */ +#if PY_3 + 0, /* tp_del */ + 0, /* tp_version_tag */ + cursor_finalize, /* tp_finalize */ +#endif }; diff --git a/psycopg/python.h b/psycopg/python.h index 2a5f9d83..a38231fc 100644 --- a/psycopg/python.h +++ b/psycopg/python.h @@ -86,6 +86,8 @@ typedef unsigned long Py_uhash_t; #define Bytes_ConcatAndDel PyString_ConcatAndDel #define _Bytes_Resize _PyString_Resize +#define Py_TPFLAGS_HAVE_FINALIZE 0L + #define PyDateTime_DELTA_GET_DAYS(o) (((PyDateTime_Delta*)o)->days) #define PyDateTime_DELTA_GET_SECONDS(o) (((PyDateTime_Delta*)o)->seconds) #define PyDateTime_DELTA_GET_MICROSECONDS(o) (((PyDateTime_Delta*)o)->microseconds) @@ -97,6 +99,8 @@ typedef unsigned long Py_uhash_t; PyLong_FromUnsignedLong((unsigned long)(x)) : \ PyInt_FromLong((x))) +#define PyObject_CallFinalizerFromDealloc(obj) 0 + #endif /* PY_2 */ #if PY_3 diff --git a/tests/__init__.py b/tests/__init__.py index f5c422f4..a785f8bf 100755 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -56,6 +56,7 @@ from . import test_sql from . import test_transaction from . import test_types_basic from . import test_types_extras +from . import test_warnings from . import test_with if sys.version_info[:2] < (3, 6): @@ -101,6 +102,7 @@ def test_suite(): suite.addTest(test_transaction.test_suite()) suite.addTest(test_types_basic.test_suite()) suite.addTest(test_types_extras.test_suite()) + suite.addTest(test_warnings.test_suite()) suite.addTest(test_with.test_suite()) return suite diff --git a/tests/test_warnings.py b/tests/test_warnings.py new file mode 100644 index 00000000..0644fb10 --- /dev/null +++ b/tests/test_warnings.py @@ -0,0 +1,63 @@ +import unittest +import warnings + +import psycopg2 + +from .testconfig import dsn +from .testutils import skip_before_python + + +class WarningsTest(unittest.TestCase): + @skip_before_python(3) + def test_connection_not_closed(self): + def f(): + psycopg2.connect(dsn) + + msg = ( + "^unclosed connection $" + ) + with self.assertWarnsRegex(ResourceWarning, msg): + f() + + @skip_before_python(3) + def test_cursor_not_closed(self): + def f(): + conn = psycopg2.connect(dsn) + try: + conn.cursor() + finally: + conn.close() + + msg = ( + "^unclosed cursor for " + "connection $" + ) + with self.assertWarnsRegex(ResourceWarning, msg): + f() + + def test_cursor_factory_returns_non_cursor(self): + def bad_factory(*args, **kwargs): + return object() + + def f(): + conn = psycopg2.connect(dsn) + try: + conn.cursor(cursor_factory=bad_factory) + finally: + conn.close() + + with warnings.catch_warnings(record=True) as cm: + with self.assertRaises(TypeError): + f() + + # No warning as no cursor was instantiated. + self.assertEquals(cm, []) + + +def test_suite(): + return unittest.TestLoader().loadTestsFromName(__name__) + + +if __name__ == "__main__": + unittest.main()