diff --git a/doc/src/extras.rst b/doc/src/extras.rst index 356e10e0..a9ba52fc 100644 --- a/doc/src/extras.rst +++ b/doc/src/extras.rst @@ -333,9 +333,9 @@ The individual messages in the replication stream are presented by Start replication on the connection using provided ``START_REPLICATION`` command. - .. method:: consume_replication_stream(consumer, decode=False, keepalive_interval=10) + .. method:: consume_replication_stream(consume, decode=False, keepalive_interval=10) - :param consumer: an object providing ``consume()`` method + :param consume: a callable object with signature ``consume(msg)`` :param decode: a flag indicating that unicode conversion should be performed on the messages received from the server :param keepalive_interval: interval (in seconds) to send keepalive @@ -348,10 +348,9 @@ The individual messages in the replication stream are presented by `start_replication()` first. When called, this method enters an endless loop, reading messages from - the server and passing them to ``consume()`` method of the *consumer* - object. In order to make this method break out of the loop and - return, the ``consume()`` method can call `stop_replication()` on the - cursor or it can throw an exception. + the server and passing them to ``consume()``. In order to make this + method break out of the loop and return, ``consume()`` can call + `stop_replication()` on the cursor or it can throw an exception. If *decode* is set to `!True`, the messages read from the server are converted according to the connection `~connection.encoding`. This @@ -362,12 +361,12 @@ The individual messages in the replication stream are presented by *keepalive_interval* (in seconds). The value of this parameter must be equal to at least 1 second, but it can have a fractional part. - The following example is a sketch implementation of *consumer* object - for logical replication:: + The following example is a sketch implementation of ``consume()`` + callable for logical replication:: class LogicalStreamConsumer(object): - def consume(self, msg): + def __call__(self, msg): self.store_message_data(msg.payload) if self.should_report_to_the_server_now(msg): @@ -376,7 +375,7 @@ The individual messages in the replication stream are presented by consumer = LogicalStreamConsumer() cur.consume_replication_stream(consumer, decode=True) - The *msg* objects passed to the ``consume()`` method are instances of + The *msg* objects passed to ``consume()`` are instances of `ReplicationMessage` class. After storing certain amount of messages' data reliably, the client @@ -401,11 +400,10 @@ The individual messages in the replication stream are presented by .. method:: stop_replication() - This method can be called on synchronous connections from the - ``consume()`` method of a ``consumer`` object in order to break out of - the endless loop in `consume_replication_stream()`. If called on - asynchronous connection or outside of the consume loop, this method - raises an error. + This method can be called on synchronous connection from the + ``consume()`` callable in order to break out of the endless loop in + `consume_replication_stream()`. If called on asynchronous connection + or when replication is not in progress, this method raises an error. .. method:: send_replication_feedback(write_lsn=0, flush_lsn=0, apply_lsn=0, reply=False) @@ -490,11 +488,14 @@ The individual messages in the replication stream are presented by An actual example of asynchronous operation might look like this:: + def consume(msg): + ... + keepalive_interval = 10.0 while True: msg = cur.read_replication_message() if msg: - consumer.consume(msg) + consume(msg) else: timeout = keepalive_interval - (datetime.now() - cur.replication_io_timestamp).total_seconds() if timeout > 0: diff --git a/psycopg/cursor_type.c b/psycopg/cursor_type.c index 5dd08cc9..a4581495 100644 --- a/psycopg/cursor_type.c +++ b/psycopg/cursor_type.c @@ -1622,13 +1622,15 @@ psyco_curs_start_replication_expert(cursorObject *self, PyObject *args, PyObject 1 /* no_result */, 1 /* no_begin */) >= 0) { res = Py_None; Py_INCREF(res); + + self->repl_started = 1; } return res; } #define psyco_curs_stop_replication_doc \ -"stop_replication() -- Set flag to break out of endless loop in start_replication() on sync connection." +"stop_replication() -- Set flag to break out of the endless loop in consume_replication_stream()." static PyObject * psyco_curs_stop_replication(cursorObject *self) @@ -1652,13 +1654,13 @@ psyco_curs_stop_replication(cursorObject *self) static PyObject * psyco_curs_consume_replication_stream(cursorObject *self, PyObject *args, PyObject *kwargs) { - PyObject *consumer = NULL, *res = NULL; + PyObject *consume = NULL, *res = NULL; int decode = 0; double keepalive_interval = 10; - static char *kwlist[] = {"consumer", "decode", "keepalive_interval", NULL}; + static char *kwlist[] = {"consume", "decode", "keepalive_interval", NULL}; if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O|id", kwlist, - &consumer, &decode, &keepalive_interval)) { + &consume, &decode, &keepalive_interval)) { return NULL; } @@ -1674,9 +1676,7 @@ psyco_curs_consume_replication_stream(cursorObject *self, PyObject *args, PyObje return NULL; } - self->repl_started = 1; - - if (pq_copy_both(self, consumer, decode, keepalive_interval) >= 0) { + if (pq_copy_both(self, consume, decode, keepalive_interval) >= 0) { res = Py_None; Py_INCREF(res); } @@ -1709,7 +1709,7 @@ static PyObject * curs_flush_replication_feedback(cursorObject *self, int reply) { if (!(self->repl_feedback_pending || reply)) - Py_RETURN_FALSE; + Py_RETURN_TRUE; if (pq_send_replication_feedback(self, reply)) { self->repl_feedback_pending = 0; diff --git a/psycopg/pqpath.c b/psycopg/pqpath.c index 4f1427de..a42c9a1a 100644 --- a/psycopg/pqpath.c +++ b/psycopg/pqpath.c @@ -1723,7 +1723,7 @@ pq_send_replication_feedback(cursorObject* curs, int reply_requested) manages to send keepalive messages to the server as needed. */ int -pq_copy_both(cursorObject *curs, PyObject *consumer, int decode, double keepalive_interval) +pq_copy_both(cursorObject *curs, PyObject *consume, int decode, double keepalive_interval) { PyObject *msg, *tmp = NULL; PyObject *consume_func = NULL; @@ -1732,8 +1732,8 @@ pq_copy_both(cursorObject *curs, PyObject *consumer, int decode, double keepaliv fd_set fds; struct timeval keep_intr, curr_time, ping_time, timeout; - if (!(consume_func = PyObject_GetAttrString(consumer, "consume"))) { - Dprintf("pq_copy_both: can't get o.consume"); + if (!(consume_func = PyObject_GetAttrString(consume, "__call__"))) { + Dprintf("pq_copy_both: expected callable consume object"); goto exit; } @@ -1743,7 +1743,7 @@ pq_copy_both(cursorObject *curs, PyObject *consumer, int decode, double keepaliv keep_intr.tv_sec = (int)keepalive_interval; keep_intr.tv_usec = (keepalive_interval - keep_intr.tv_sec)*1.0e6; - while (1) { + while (!curs->repl_stop) { msg = pq_read_replication_message(curs, decode); if (!msg) { goto exit; diff --git a/tests/test_connection.py b/tests/test_connection.py index 18f1ff3e..e2b0da30 100755 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1206,7 +1206,11 @@ class ReplicationTest(ConnectingTestCase): self.assertRaises(psycopg2.ProgrammingError, cur.stop_replication) cur.start_replication() - self.assertRaises(psycopg2.ProgrammingError, cur.stop_replication) + cur.stop_replication() # doesn't raise now + + def consume(msg): + pass + cur.consume_replication_stream(consume) # should return at once def test_suite():