From 6108e4dc9258f409a8516af50dde76c9ed5553f2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jan=20Urba=C5=84ski?= Date: Sat, 10 Apr 2010 18:54:49 +0200 Subject: [PATCH] Make the first poll() of an asynchronous connection return POLL_WRITE. This hides from the user the libpq's implementation detail of requiring the first select() to wait for the connection socket to become writable and makes it possible to have a uniform select loop for both cursors and connections, in which you always start by polling the object and then acting according to the result from poll(). Idea and implementation by Daniele Varrazzo. --- psycopg/connection.h | 1 + psycopg/connection_type.c | 8 ++++- tests/test_async.py | 69 +++++++++++++++++++-------------------- 3 files changed, 41 insertions(+), 37 deletions(-) diff --git a/psycopg/connection.h b/psycopg/connection.h index ec8e8691..f2296a9d 100644 --- a/psycopg/connection.h +++ b/psycopg/connection.h @@ -37,6 +37,7 @@ extern "C" { #endif /* connection status */ +#define CONN_STATUS_SETUP 0 #define CONN_STATUS_READY 1 #define CONN_STATUS_BEGIN 2 #define CONN_STATUS_SYNC 3 diff --git a/psycopg/connection_type.c b/psycopg/connection_type.c index 3fd08a48..ab5b23ed 100644 --- a/psycopg/connection_type.c +++ b/psycopg/connection_type.c @@ -425,6 +425,12 @@ psyco_conn_poll(connectionObject *self) switch (self->status) { + case CONN_STATUS_SETUP: + /* according to libpq documentation the user should start by waiting + for the socket to become writable */ + self->status = CONN_STATUS_ASYNC; + return PyInt_FromLong(PSYCO_POLL_WRITE); + case CONN_STATUS_SEND_DATESTYLE: case CONN_STATUS_SENT_DATESTYLE: case CONN_STATUS_SEND_CLIENT_ENCODING: @@ -686,7 +692,7 @@ connection_setup(connectionObject *self, const char *dsn, long int async) self->notifies = PyList_New(0); self->closed = 0; self->async = async; - self->status = async ? CONN_STATUS_ASYNC : CONN_STATUS_READY; + self->status = async ? CONN_STATUS_SETUP : CONN_STATUS_READY; self->critical = NULL; self->async_cursor = NULL; self->pgconn = NULL; diff --git a/tests/test_async.py b/tests/test_async.py index 6cb91490..de853e7e 100755 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -35,33 +35,30 @@ class AsyncTests(unittest.TestCase): self.sync_conn = psycopg2.connect(tests.dsn) self.conn = psycopg2.connect(tests.dsn, async=True) - state = psycopg2.extensions.POLL_WRITE - while state != psycopg2.extensions.POLL_OK: - if state == psycopg2.extensions.POLL_WRITE: - select.select([], [self.conn.fileno()], []) - elif state == psycopg2.extensions.POLL_READ: - select.select([self.conn.fileno()], [], []) - state = self.conn.poll() + self.wait(self.conn) curs = self.conn.cursor() curs.execute(''' CREATE TEMPORARY TABLE table1 ( id int PRIMARY KEY )''') - self.wait_for_query(curs) + self.wait(curs) def tearDown(self): self.sync_conn.close() self.conn.close() - def wait_for_query(self, cur): - state = cur.poll() - while state != psycopg2.extensions.POLL_OK: - if state == psycopg2.extensions.POLL_READ: - select.select([cur.fileno()], [], []) + def wait(self, pollable): + while True: + state = pollable.poll() + if state == psycopg2.extensions.POLL_OK: + break + elif state == psycopg2.extensions.POLL_READ: + select.select([pollable], [], []) elif state == psycopg2.extensions.POLL_WRITE: - select.select([], [cur.fileno()], []) - state = cur.poll() + select.select([], [pollable], []) + else: + raise Exception("Unexpected result from poll: %r", state) def test_connection_setup(self): cur = self.conn.cursor() @@ -83,7 +80,7 @@ class AsyncTests(unittest.TestCase): cur.execute("select 'a'") self.assertTrue(self.conn.executing()) - self.wait_for_query(cur) + self.wait(cur) self.assertFalse(self.conn.executing()) self.assertEquals(cur.fetchone()[0], "a") @@ -97,7 +94,7 @@ class AsyncTests(unittest.TestCase): return self.assertTrue(self.conn.executing()) - self.wait_for_query(cur) + self.wait(cur) self.assertFalse(self.conn.executing()) self.assertEquals(cur.fetchall()[0][0], '') @@ -114,17 +111,17 @@ class AsyncTests(unittest.TestCase): self.assertRaises(psycopg2.ProgrammingError, cur.callproc, "version") # but after you've waited it should be good - self.wait_for_query(cur) + self.wait(cur) cur.execute("select * from table1") - self.wait_for_query(cur) + self.wait(cur) self.assertEquals(cur.fetchall()[0][0], 1) cur.execute("delete from table1") - self.wait_for_query(cur) + self.wait(cur) cur.execute("select * from table1") - self.wait_for_query(cur) + self.wait(cur) self.assertEquals(cur.fetchone(), None) @@ -136,7 +133,7 @@ class AsyncTests(unittest.TestCase): self.assertRaises(psycopg2.ProgrammingError, cur.fetchall) # but after waiting it should work - self.wait_for_query(cur) + self.wait(cur) self.assertEquals(cur.fetchall()[0][0], "a") def test_rollback_while_async(self): @@ -151,7 +148,7 @@ class AsyncTests(unittest.TestCase): cur = self.conn.cursor() cur.execute("begin") - self.wait_for_query(cur) + self.wait(cur) cur.execute("insert into table1 values (1)") @@ -160,19 +157,19 @@ class AsyncTests(unittest.TestCase): self.assertTrue(self.conn.executing()) # but a manual commit should - self.wait_for_query(cur) + self.wait(cur) cur.execute("commit") - self.wait_for_query(cur) + self.wait(cur) cur.execute("select * from table1") - self.wait_for_query(cur) + self.wait(cur) self.assertEquals(cur.fetchall()[0][0], 1) cur.execute("delete from table1") - self.wait_for_query(cur) + self.wait(cur) cur.execute("select * from table1") - self.wait_for_query(cur) + self.wait(cur) self.assertEquals(cur.fetchone(), None) def test_set_parameters_while_async(self): @@ -206,16 +203,16 @@ class AsyncTests(unittest.TestCase): cur = self.conn.cursor() cur.execute("begin") - self.wait_for_query(cur) + self.wait(cur) cur.execute("insert into table1 values (1), (2), (3)") - self.wait_for_query(cur) + self.wait(cur) cur.execute("select id from table1 order by id") # iteration fails if a query is underway self.assertRaises(psycopg2.ProgrammingError, list, cur) # but after it's done it should work - self.wait_for_query(cur) + self.wait(cur) self.assertEquals(list(cur), [(1, ), (2, ), (3, )]) self.assertFalse(self.conn.executing()) @@ -242,7 +239,7 @@ class AsyncTests(unittest.TestCase): def test_async_scroll(self): cur = self.conn.cursor() cur.execute("insert into table1 values (1), (2), (3)") - self.wait_for_query(cur) + self.wait(cur) cur.execute("select id from table1 order by id") # scroll should fail if a query is underway @@ -250,13 +247,13 @@ class AsyncTests(unittest.TestCase): self.assertTrue(self.conn.executing()) # but after it's done it should work - self.wait_for_query(cur) + self.wait(cur) cur.scroll(1) self.assertEquals(cur.fetchall(), [(2, ), (3, )]) cur = self.conn.cursor() cur.execute("select id from table1 order by id") - self.wait_for_query(cur) + self.wait(cur) cur2 = self.conn.cursor() self.assertRaises(psycopg2.ProgrammingError, cur2.scroll, 1) @@ -265,7 +262,7 @@ class AsyncTests(unittest.TestCase): cur = self.conn.cursor() cur.execute("select id from table1 order by id") - self.wait_for_query(cur) + self.wait(cur) cur.scroll(2) cur.scroll(-1) self.assertEquals(cur.fetchall(), [(2, ), (3, )]) @@ -284,7 +281,7 @@ class AsyncTests(unittest.TestCase): cur.execute("select repeat('a', 10000); select repeat('b', 10000)") # fetch the result - self.wait_for_query(cur) + self.wait(cur) # it should be the result of the second query self.assertEquals(cur.fetchone()[0], "b" * 10000)