diff --git a/tests/test_replication.py b/tests/test_replication.py index cd1321ae..5c029c88 100644 --- a/tests/test_replication.py +++ b/tests/test_replication.py @@ -27,6 +27,7 @@ import psycopg2.extensions from psycopg2.extras import PhysicalReplicationConnection, LogicalReplicationConnection from psycopg2.extras import StopReplication +import testconfig from testutils import unittest from testutils import skip_before_postgres from testutils import ConnectingTestCase @@ -34,10 +35,12 @@ from testutils import ConnectingTestCase class ReplicationTestCase(ConnectingTestCase): def setUp(self): - from testconfig import repl_dsn - if not repl_dsn: + if not testconfig.repl_dsn: self.skipTest("replication tests disabled by default") + super(ReplicationTestCase, self).setUp() + + self.slot = testconfig.repl_slot self._slots = [] def tearDown(self): @@ -52,14 +55,27 @@ class ReplicationTestCase(ConnectingTestCase): kill_cur.drop_replication_slot(slot) kill_conn.close() - def create_replication_slot(self, cur, slot_name, **kwargs): + def create_replication_slot(self, cur, slot_name=testconfig.repl_slot, **kwargs): cur.create_replication_slot(slot_name, **kwargs) self._slots.append(slot_name) - def drop_replication_slot(self, cur, slot_name): + def drop_replication_slot(self, cur, slot_name=testconfig.repl_slot): cur.drop_replication_slot(slot_name) self._slots.remove(slot_name) + # generate some events for our replication stream + def make_replication_events(self): + conn = self.connect() + if conn is None: return + cur = conn.cursor() + + try: + cur.execute("DROP TABLE dummy1") + except psycopg2.ProgrammingError: + conn.rollback() + cur.execute("CREATE TABLE dummy1 AS SELECT * FROM generate_series(1, 5) AS id") + conn.commit() + class ReplicationTest(ReplicationTestCase): @skip_before_postgres(9, 0) @@ -84,10 +100,8 @@ class ReplicationTest(ReplicationTestCase): if conn is None: return cur = conn.cursor() - slot = "test_slot1" - - self.create_replication_slot(cur, slot) - self.assertRaises(psycopg2.ProgrammingError, self.create_replication_slot, cur, slot) + self.create_replication_slot(cur) + self.assertRaises(psycopg2.ProgrammingError, self.create_replication_slot, cur) @skip_before_postgres(9, 4) # slots require 9.4 def test_start_on_missing_replication_slot(self): @@ -95,12 +109,10 @@ class ReplicationTest(ReplicationTestCase): if conn is None: return cur = conn.cursor() - slot = "test_slot1" + self.assertRaises(psycopg2.ProgrammingError, cur.start_replication, self.slot) - self.assertRaises(psycopg2.ProgrammingError, cur.start_replication, slot) - - self.create_replication_slot(cur, slot) - cur.start_replication(slot) + self.create_replication_slot(cur) + cur.start_replication(self.slot) @skip_before_postgres(9, 4) # slots require 9.4 def test_stop_replication(self): @@ -108,46 +120,47 @@ class ReplicationTest(ReplicationTestCase): if conn is None: return cur = conn.cursor() - slot = "test_slot1" + self.create_replication_slot(cur, output_plugin='test_decoding') - self.create_replication_slot(cur, slot, output_plugin='test_decoding') + self.make_replication_events() - self.make_replication_event() - - cur.start_replication(slot) + cur.start_replication(self.slot) def consume(msg): raise StopReplication() self.assertRaises(StopReplication, cur.consume_replication_stream, consume) - # generate an event for our replication stream - def make_replication_event(self): - conn = self.connect() - if conn is None: return - cur = conn.cursor() - - try: - cur.execute("DROP TABLE dummy1") - except psycopg2.ProgrammingError: - conn.rollback() - cur.execute("CREATE TABLE dummy1()") - conn.commit() - class AsyncReplicationTest(ReplicationTestCase): - @skip_before_postgres(9, 4) + @skip_before_postgres(9, 4) # slots require 9.4 def test_async_replication(self): conn = self.repl_connect(connection_factory=LogicalReplicationConnection, async=1) if conn is None: return self.wait(conn) cur = conn.cursor() - slot = "test_slot1" - self.create_replication_slot(cur, slot, output_plugin='test_decoding') + self.create_replication_slot(cur, output_plugin='test_decoding') self.wait(cur) - cur.start_replication(slot) + cur.start_replication(self.slot) self.wait(cur) + self.make_replication_events() + + self.msg_count = 0 + def consume(msg): + self.msg_count += 1 + if self.msg_count > 3: + raise StopReplication() + + def process_stream(): + from select import select + while True: + msg = cur.read_replication_message() + if msg: + consume(msg) + else: + select([cur], [], []) + self.assertRaises(StopReplication, process_stream) def test_suite(): return unittest.TestLoader().loadTestsFromName(__name__) diff --git a/tests/testconfig.py b/tests/testconfig.py index 841eaf1c..82b48a39 100644 --- a/tests/testconfig.py +++ b/tests/testconfig.py @@ -38,3 +38,5 @@ if dbpass is not None: repl_dsn = os.environ.get('PSYCOPG2_TEST_REPL_DSN', None) if repl_dsn == '': repl_dsn = dsn + +repl_slot = os.environ.get('PSYCOPG2_TEST_REPL_SLOT', 'psycopg2_test_slot')