Correctly flush async queries in 'green' mode.

This commit is contained in:
Daniele Varrazzo 2010-04-04 00:30:26 +01:00
parent 8ba0f00d21
commit 0dd5d3f1d9
4 changed files with 93 additions and 6 deletions

View File

@ -701,6 +701,7 @@ conn_poll_green(connectionObject *self)
Dprintf("conn_poll: status = CONN_STATUS_READY/BEGIN");
switch (self->async_status) {
case ASYNC_READ:
Dprintf("conn_poll: async_status = ASYNC_READ");
if (0 == PQconsumeInput(self->pgconn)) {
PyErr_SetString(OperationalError, PQerrorMessage(self->pgconn));
res = PSYCO_POLL_ERROR;
@ -712,6 +713,22 @@ conn_poll_green(connectionObject *self)
}
break;
case ASYNC_WRITE:
Dprintf("conn_poll: async_status = ASYNC_WRITE");
switch (PQflush(self->pgconn)) {
case 0: /* success */
res = PSYCO_POLL_OK;
break;
case 1: /* would block */
res = PSYCO_POLL_WRITE;
break;
case -1: /* error */
PyErr_SetString(OperationalError, PQerrorMessage(self->pgconn));
res = PSYCO_POLL_ERROR;
break;
}
break;
default:
Dprintf("conn_poll: in unexpected async status: %d",
self->async_status);

View File

@ -144,7 +144,7 @@ psyco_exec_green(connectionObject *conn, const char *command)
{
PGconn *pgconn = conn->pgconn;
PGresult *result = NULL, *res;
PyObject *cb;
PyObject *cb, *pyrv;
if (!(cb = have_wait_callback())) {
goto end;
@ -153,18 +153,26 @@ psyco_exec_green(connectionObject *conn, const char *command)
/* Send the query asynchronously */
Dprintf("psyco_exec_green: sending query async");
if (0 == PQsendQuery(pgconn, command)) {
/* TODO: not handling the case of block during send */
Dprintf("psyco_exec_green: PQsendQuery returned 0");
goto clear;
}
/* Loop reading data using the user-provided wait function */
conn->async_status = ASYNC_READ;
PyObject *pyrv;
/* Ensure the query reached the server. */
conn->async_status = ASYNC_WRITE;
pyrv = PyObject_CallFunctionObjArgs(cb, conn, NULL, NULL);
if (!pyrv) {
Dprintf("psyco_exec_green: error in callback");
Dprintf("psyco_exec_green: error in callback sending query");
goto clear;
}
Py_DECREF(pyrv);
/* Loop reading data using the user-provided wait function */
conn->async_status = ASYNC_READ;
pyrv = PyObject_CallFunctionObjArgs(cb, conn, NULL, NULL);
if (!pyrv) {
Dprintf("psyco_exec_green: error in callback reading result");
goto clear;
}
Py_DECREF(pyrv);

View File

@ -37,6 +37,7 @@ import test_lobject
import test_copy
import test_notify
import test_async
import test_green
def test_suite():
suite = unittest.TestSuite()
@ -53,6 +54,7 @@ def test_suite():
suite.addTest(test_copy.test_suite())
suite.addTest(test_notify.test_suite())
suite.addTest(test_async.test_suite())
suite.addTest(test_green.test_suite())
return suite
if __name__ == '__main__':

60
tests/test_green.py Normal file
View File

@ -0,0 +1,60 @@
#!/usr/bin/env python
import unittest
import psycopg2
import psycopg2.extensions
import psycopg2.extras
import tests
class ConnectionStub(object):
"""A `connection` wrapper allowing analysis of the `poll()` calls."""
def __init__(self, conn):
self.conn = conn
self.polls = []
def fileno(self):
return self.conn.fileno()
def poll(self):
rv = self.conn.poll()
self.polls.append(rv)
return rv
class GreenTests(unittest.TestCase):
def connect(self):
return psycopg2.connect(tests.dsn)
def setUp(self):
self._cb = psycopg2.extensions.get_wait_callback()
psycopg2.extensions.set_wait_callback(psycopg2.extras.wait_select)
def tearDown(self):
psycopg2.extensions.set_wait_callback(self._cb)
def set_stub_wait_callback(self, conn):
stub = ConnectionStub(conn)
psycopg2.extensions.set_wait_callback(
lambda conn: psycopg2.extras.wait_select(stub))
return stub
def test_flush_on_write(self):
# a very large query requires a flush loop to be sent to the backend
conn = self.connect()
stub = self.set_stub_wait_callback(conn)
curs = conn.cursor()
for mb in 1, 5, 10, 20, 50:
size = mb * 1024 * 1024
del stub.polls[:]
curs.execute("select %s;", ('x' * size,))
self.assertEqual(size, len(curs.fetchone()[0]))
if stub.polls.count(psycopg2.extensions.POLL_WRITE) > 1:
return
self.fail("sending a large query didn't trigger block on write.")
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()