diff --git a/lib/extras.py b/lib/extras.py index f411a4d0..dc2d5e65 100644 --- a/lib/extras.py +++ b/lib/extras.py @@ -449,7 +449,7 @@ class ReplicationConnectionBase(_connection): classes. Uses `ReplicationCursor` automatically. """ - def __init__(self, dsn, **kwargs): + def __init__(self, *args, **kwargs): """ Initializes a replication connection by adding appropriate parameters to the provided DSN and tweaking the connection @@ -466,7 +466,7 @@ class ReplicationConnectionBase(_connection): else: raise psycopg2.ProgrammingError("unrecognized replication type: %s" % self.replication_type) - items = _ext.parse_dsn(dsn) + items = _ext.parse_dsn(args[0]) # we add an appropriate replication keyword parameter, unless # user has specified one explicitly in the DSN @@ -475,7 +475,8 @@ class ReplicationConnectionBase(_connection): dsn = " ".join(["%s=%s" % (k, psycopg2._param_escape(str(v))) for (k, v) in items.iteritems()]) - super(ReplicationConnectionBase, self).__init__(dsn, **kwargs) + args = [dsn] + list(args[1:]) # async is the possible 2nd arg + super(ReplicationConnectionBase, self).__init__(*args, **kwargs) # prevent auto-issued BEGIN statements if not self.async: diff --git a/tests/test_async.py b/tests/test_async.py index d40b9c3e..e0bca7d5 100755 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -29,7 +29,6 @@ import psycopg2 from psycopg2 import extensions import time -import select import StringIO from testutils import ConnectingTestCase @@ -66,21 +65,6 @@ class AsyncTests(ConnectingTestCase): )''') self.wait(curs) - def wait(self, cur_or_conn): - pollable = cur_or_conn - if not hasattr(pollable, 'poll'): - pollable = cur_or_conn.connection - while True: - state = pollable.poll() - if state == psycopg2.extensions.POLL_OK: - break - elif state == psycopg2.extensions.POLL_READ: - select.select([pollable], [], [], 10) - elif state == psycopg2.extensions.POLL_WRITE: - select.select([], [pollable], [], 10) - else: - raise Exception("Unexpected result from poll: %r", state) - def test_connection_setup(self): cur = self.conn.cursor() sync_cur = self.sync_conn.cursor() diff --git a/tests/test_connection.py b/tests/test_connection.py index eeeaa845..568f09ed 100755 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1178,55 +1178,6 @@ class AutocommitTests(ConnectingTestCase): self.assertEqual(cur.fetchone()[0], 'on') -class ReplicationTest(ConnectingTestCase): - @skip_before_postgres(9, 0) - def test_physical_replication_connection(self): - import psycopg2.extras - conn = self.repl_connect(connection_factory=psycopg2.extras.PhysicalReplicationConnection) - if conn is None: return - cur = conn.cursor() - cur.execute("IDENTIFY_SYSTEM") - cur.fetchall() - - @skip_before_postgres(9, 4) - def test_logical_replication_connection(self): - import psycopg2.extras - conn = self.repl_connect(connection_factory=psycopg2.extras.LogicalReplicationConnection) - if conn is None: return - cur = conn.cursor() - cur.execute("IDENTIFY_SYSTEM") - cur.fetchall() - - @skip_before_postgres(9, 0) - def test_stop_replication_raises(self): - import psycopg2.extras - conn = self.repl_connect(connection_factory=psycopg2.extras.PhysicalReplicationConnection) - if conn is None: return - cur = conn.cursor() - self.assertRaises(psycopg2.ProgrammingError, cur.stop_replication) - - cur.start_replication() - cur.stop_replication() # doesn't raise now - - def consume(msg): - pass - cur.consume_replication_stream(consume) # should return at once - - @skip_before_postgres(9, 4) # slots require 9.4 - def test_create_replication_slot(self): - import psycopg2.extras - conn = self.repl_connect(connection_factory=psycopg2.extras.PhysicalReplicationConnection) - if conn is None: return - cur = conn.cursor() - - slot = "test_slot1" - try: - cur.create_replication_slot(slot) - self.assertRaises(psycopg2.ProgrammingError, cur.create_replication_slot, slot) - finally: - cur.drop_replication_slot(slot) - - def test_suite(): return unittest.TestLoader().loadTestsFromName(__name__) diff --git a/tests/test_replication.py b/tests/test_replication.py new file mode 100644 index 00000000..231bcd08 --- /dev/null +++ b/tests/test_replication.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python + +# test_replication.py - unit test for replication protocol +# +# Copyright (C) 2015 Daniele Varrazzo +# +# psycopg2 is free software: you can redistribute it and/or modify it +# under the terms of the GNU Lesser General Public License as published +# by the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# In addition, as a special exception, the copyright holders give +# permission to link this program with the OpenSSL library (or with +# modified versions of OpenSSL that use the same license as OpenSSL), +# and distribute linked combinations including the two. +# +# You must obey the GNU Lesser General Public License in all respects for +# all of the code used other than OpenSSL. +# +# psycopg2 is distributed in the hope that it will be useful, but WITHOUT +# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or +# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public +# License for more details. + +import psycopg2 +import psycopg2.extensions +from psycopg2.extras import PhysicalReplicationConnection, LogicalReplicationConnection + +from testutils import unittest +from testutils import skip_before_postgres +from testutils import ConnectingTestCase + + +class ReplicationTestCase(ConnectingTestCase): + def setUp(self): + super(ReplicationTestCase, self).setUp() + self._slots = [] + + def tearDown(self): + # first close all connections, as they might keep the slot(s) active + super(ReplicationTestCase, self).tearDown() + + if self._slots: + kill_conn = self.repl_connect(connection_factory=PhysicalReplicationConnection) + if kill_conn: + kill_cur = kill_conn.cursor() + for slot in self._slots: + kill_cur.drop_replication_slot(slot) + kill_conn.close() + + def create_replication_slot(self, cur, slot_name, **kwargs): + cur.create_replication_slot(slot_name, **kwargs) + self._slots.append(slot_name) + + def drop_replication_slot(self, cur, slot_name): + cur.drop_replication_slot(slot_name) + self._slots.remove(slot_name) + + +class ReplicationTest(ReplicationTestCase): + @skip_before_postgres(9, 0) + def test_physical_replication_connection(self): + conn = self.repl_connect(connection_factory=PhysicalReplicationConnection) + if conn is None: return + cur = conn.cursor() + cur.execute("IDENTIFY_SYSTEM") + cur.fetchall() + + @skip_before_postgres(9, 4) + def test_logical_replication_connection(self): + conn = self.repl_connect(connection_factory=LogicalReplicationConnection) + if conn is None: return + cur = conn.cursor() + cur.execute("IDENTIFY_SYSTEM") + cur.fetchall() + + @skip_before_postgres(9, 0) + def test_stop_replication_raises(self): + conn = self.repl_connect(connection_factory=PhysicalReplicationConnection) + if conn is None: return + cur = conn.cursor() + self.assertRaises(psycopg2.ProgrammingError, cur.stop_replication) + + cur.start_replication() + cur.stop_replication() # doesn't raise now + + def consume(msg): + pass + cur.consume_replication_stream(consume) # should return at once + + @skip_before_postgres(9, 4) # slots require 9.4 + def test_create_replication_slot(self): + conn = self.repl_connect(connection_factory=PhysicalReplicationConnection) + if conn is None: return + cur = conn.cursor() + + slot = "test_slot1" + + self.create_replication_slot(cur, slot) + self.assertRaises(psycopg2.ProgrammingError, self.create_replication_slot, cur, slot) + + +class AsyncReplicationTest(ReplicationTestCase): + @skip_before_postgres(9, 4) + def test_async_replication(self): + conn = self.repl_connect(connection_factory=LogicalReplicationConnection, async=1) + if conn is None: return + self.wait(conn) + cur = conn.cursor() + + slot = "test_slot1" + self.create_replication_slot(cur, slot, output_plugin='test_decoding') + self.wait(cur) + + cur.start_replication(slot) + self.wait(cur) + + +def test_suite(): + return unittest.TestLoader().loadTestsFromName(__name__) + +if __name__ == "__main__": + unittest.main() diff --git a/tests/testutils.py b/tests/testutils.py index 76671d99..5f4493f2 100644 --- a/tests/testutils.py +++ b/tests/testutils.py @@ -27,6 +27,7 @@ import os import platform import sys +import select from functools import wraps from testconfig import dsn, repl_dsn @@ -129,7 +130,8 @@ class ConnectingTestCase(unittest.TestCase): except psycopg2.OperationalError, e: return self.skipTest("replication db not configured: %s" % e) - conn.autocommit = True + if not conn.async: + conn.autocommit = True return conn def _get_conn(self): @@ -143,6 +145,23 @@ class ConnectingTestCase(unittest.TestCase): conn = property(_get_conn, _set_conn) + # for use with async connections only + def wait(self, cur_or_conn): + import psycopg2.extensions + pollable = cur_or_conn + if not hasattr(pollable, 'poll'): + pollable = cur_or_conn.connection + while True: + state = pollable.poll() + if state == psycopg2.extensions.POLL_OK: + break + elif state == psycopg2.extensions.POLL_READ: + select.select([pollable], [], [], 10) + elif state == psycopg2.extensions.POLL_WRITE: + select.select([], [pollable], [], 10) + else: + raise Exception("Unexpected result from poll: %r", state) + def decorate_all_tests(cls, *decorators): """