diff --git a/psycopg/connection.h b/psycopg/connection.h index 72437492..529a211e 100644 --- a/psycopg/connection.h +++ b/psycopg/connection.h @@ -103,7 +103,7 @@ typedef struct { PGconn *pgconn; /* the postgresql connection */ PGcancel *cancel; /* the cancellation structure */ - PyObject *async_cursor; /* a cursor executing an asynchronous query */ + PyObject *async_cursor; /* weakref to a cursor executing an asynchronous query */ int async_status; /* asynchronous execution status */ /* notice processing */ diff --git a/psycopg/connection_int.c b/psycopg/connection_int.c index 8451b453..9ccc11d5 100644 --- a/psycopg/connection_int.c +++ b/psycopg/connection_int.c @@ -752,7 +752,17 @@ conn_poll(connectionObject *self) if (res == PSYCO_POLL_OK && self->async_cursor) { /* An async query has just finished: parse the tuple in the * target cursor. */ - cursorObject *curs = (cursorObject *)self->async_cursor; + cursorObject *curs; + PyObject *py_curs = PyWeakref_GetObject(self->async_cursor); + if (Py_None == py_curs) { + pq_clear_async(self); + PyErr_SetString(InterfaceError, + "the asynchronous cursor has disappeared"); + res = PSYCO_POLL_ERROR; + break; + } + + curs = (cursorObject *)py_curs; IFCLEARPGRES(curs->pgres); curs->pgres = pq_get_last_result(self); @@ -764,8 +774,7 @@ conn_poll(connectionObject *self) } /* We have finished with our async_cursor */ - Py_XDECREF(self->async_cursor); - self->async_cursor = NULL; + Py_CLEAR(self->async_cursor); } break; diff --git a/psycopg/cursor_type.c b/psycopg/cursor_type.c index be22e34f..5a6722c7 100644 --- a/psycopg/cursor_type.c +++ b/psycopg/cursor_type.c @@ -792,7 +792,8 @@ psyco_curs_fetchone(cursorObject *self, PyObject *args) /* if the query was async aggresively free pgres, to allow successive requests to reallocate it */ if (self->row >= self->rowcount - && self->conn->async_cursor == (PyObject*)self) + && self->conn->async_cursor + && PyWeakref_GetObject(self->conn->async_cursor) == (PyObject*)self) IFCLEARPGRES(self->pgres); return res; @@ -868,7 +869,8 @@ psyco_curs_fetchmany(cursorObject *self, PyObject *args, PyObject *kwords) /* if the query was async aggresively free pgres, to allow successive requests to reallocate it */ if (self->row >= self->rowcount - && self->conn->async_cursor == (PyObject*)self) + && self->conn->async_cursor + && PyWeakref_GetObject(self->conn->async_cursor) == (PyObject*)self) IFCLEARPGRES(self->pgres); return list; @@ -932,7 +934,8 @@ psyco_curs_fetchall(cursorObject *self, PyObject *args) /* if the query was async aggresively free pgres, to allow successive requests to reallocate it */ if (self->row >= self->rowcount - && self->conn->async_cursor == (PyObject*)self) + && self->conn->async_cursor + && PyWeakref_GetObject(self->conn->async_cursor) == (PyObject*)self) IFCLEARPGRES(self->pgres); return list; diff --git a/psycopg/pqpath.c b/psycopg/pqpath.c index f334b636..27e958e5 100644 --- a/psycopg/pqpath.c +++ b/psycopg/pqpath.c @@ -279,8 +279,7 @@ pq_clear_async(connectionObject *conn) Dprintf("pq_clear_async: clearing PGresult at %p", pgres); CLEARPGRES(pgres); } - Py_XDECREF(conn->async_cursor); - conn->async_cursor = NULL; + Py_CLEAR(conn->async_cursor); } @@ -824,8 +823,11 @@ pq_execute(cursorObject *curs, const char *query, int async) } else { curs->conn->async_status = async_status; - Py_INCREF(curs); - curs->conn->async_cursor = (PyObject*)curs; + curs->conn->async_cursor = PyWeakref_NewRef((PyObject *)curs, NULL); + if (!curs->conn->async_cursor) { + /* weakref creation failed */ + return -1; + } } return 1-async; diff --git a/tests/test_async.py b/tests/test_async.py index 96d7a2cc..d4854fc5 100755 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -415,6 +415,18 @@ class AsyncTests(unittest.TestCase): self.assertEqual("CREATE TABLE", cur.statusmessage) self.assert_(self.conn.notices) + def test_async_cursor_gone(self): + cur = self.conn.cursor() + cur.execute("select 42;"); + del cur + self.assertRaises(psycopg2.InterfaceError, self.wait, self.conn) + + # The connection is still usable + cur = self.conn.cursor() + cur.execute("select 42;"); + self.wait(self.conn) + self.assertEqual(cur.fetchone(), (42,)) + def test_suite(): return unittest.TestLoader().loadTestsFromName(__name__)