Improve async replication test.

This commit is contained in:
Oleksandr Shulgin 2015-10-19 17:02:18 +02:00
parent 4ab7cf0157
commit 7aea2cef6e
2 changed files with 50 additions and 35 deletions

View File

@ -27,6 +27,7 @@ import psycopg2.extensions
from psycopg2.extras import PhysicalReplicationConnection, LogicalReplicationConnection from psycopg2.extras import PhysicalReplicationConnection, LogicalReplicationConnection
from psycopg2.extras import StopReplication from psycopg2.extras import StopReplication
import testconfig
from testutils import unittest from testutils import unittest
from testutils import skip_before_postgres from testutils import skip_before_postgres
from testutils import ConnectingTestCase from testutils import ConnectingTestCase
@ -34,10 +35,12 @@ from testutils import ConnectingTestCase
class ReplicationTestCase(ConnectingTestCase): class ReplicationTestCase(ConnectingTestCase):
def setUp(self): def setUp(self):
from testconfig import repl_dsn if not testconfig.repl_dsn:
if not repl_dsn:
self.skipTest("replication tests disabled by default") self.skipTest("replication tests disabled by default")
super(ReplicationTestCase, self).setUp() super(ReplicationTestCase, self).setUp()
self.slot = testconfig.repl_slot
self._slots = [] self._slots = []
def tearDown(self): def tearDown(self):
@ -52,14 +55,27 @@ class ReplicationTestCase(ConnectingTestCase):
kill_cur.drop_replication_slot(slot) kill_cur.drop_replication_slot(slot)
kill_conn.close() kill_conn.close()
def create_replication_slot(self, cur, slot_name, **kwargs): def create_replication_slot(self, cur, slot_name=testconfig.repl_slot, **kwargs):
cur.create_replication_slot(slot_name, **kwargs) cur.create_replication_slot(slot_name, **kwargs)
self._slots.append(slot_name) self._slots.append(slot_name)
def drop_replication_slot(self, cur, slot_name): def drop_replication_slot(self, cur, slot_name=testconfig.repl_slot):
cur.drop_replication_slot(slot_name) cur.drop_replication_slot(slot_name)
self._slots.remove(slot_name) self._slots.remove(slot_name)
# generate some events for our replication stream
def make_replication_events(self):
conn = self.connect()
if conn is None: return
cur = conn.cursor()
try:
cur.execute("DROP TABLE dummy1")
except psycopg2.ProgrammingError:
conn.rollback()
cur.execute("CREATE TABLE dummy1 AS SELECT * FROM generate_series(1, 5) AS id")
conn.commit()
class ReplicationTest(ReplicationTestCase): class ReplicationTest(ReplicationTestCase):
@skip_before_postgres(9, 0) @skip_before_postgres(9, 0)
@ -84,10 +100,8 @@ class ReplicationTest(ReplicationTestCase):
if conn is None: return if conn is None: return
cur = conn.cursor() cur = conn.cursor()
slot = "test_slot1" self.create_replication_slot(cur)
self.assertRaises(psycopg2.ProgrammingError, self.create_replication_slot, cur)
self.create_replication_slot(cur, slot)
self.assertRaises(psycopg2.ProgrammingError, self.create_replication_slot, cur, slot)
@skip_before_postgres(9, 4) # slots require 9.4 @skip_before_postgres(9, 4) # slots require 9.4
def test_start_on_missing_replication_slot(self): def test_start_on_missing_replication_slot(self):
@ -95,12 +109,10 @@ class ReplicationTest(ReplicationTestCase):
if conn is None: return if conn is None: return
cur = conn.cursor() cur = conn.cursor()
slot = "test_slot1" self.assertRaises(psycopg2.ProgrammingError, cur.start_replication, self.slot)
self.assertRaises(psycopg2.ProgrammingError, cur.start_replication, slot) self.create_replication_slot(cur)
cur.start_replication(self.slot)
self.create_replication_slot(cur, slot)
cur.start_replication(slot)
@skip_before_postgres(9, 4) # slots require 9.4 @skip_before_postgres(9, 4) # slots require 9.4
def test_stop_replication(self): def test_stop_replication(self):
@ -108,46 +120,47 @@ class ReplicationTest(ReplicationTestCase):
if conn is None: return if conn is None: return
cur = conn.cursor() cur = conn.cursor()
slot = "test_slot1" self.create_replication_slot(cur, output_plugin='test_decoding')
self.create_replication_slot(cur, slot, output_plugin='test_decoding') self.make_replication_events()
self.make_replication_event() cur.start_replication(self.slot)
cur.start_replication(slot)
def consume(msg): def consume(msg):
raise StopReplication() raise StopReplication()
self.assertRaises(StopReplication, cur.consume_replication_stream, consume) self.assertRaises(StopReplication, cur.consume_replication_stream, consume)
# generate an event for our replication stream
def make_replication_event(self):
conn = self.connect()
if conn is None: return
cur = conn.cursor()
try:
cur.execute("DROP TABLE dummy1")
except psycopg2.ProgrammingError:
conn.rollback()
cur.execute("CREATE TABLE dummy1()")
conn.commit()
class AsyncReplicationTest(ReplicationTestCase): class AsyncReplicationTest(ReplicationTestCase):
@skip_before_postgres(9, 4) @skip_before_postgres(9, 4) # slots require 9.4
def test_async_replication(self): def test_async_replication(self):
conn = self.repl_connect(connection_factory=LogicalReplicationConnection, async=1) conn = self.repl_connect(connection_factory=LogicalReplicationConnection, async=1)
if conn is None: return if conn is None: return
self.wait(conn) self.wait(conn)
cur = conn.cursor() cur = conn.cursor()
slot = "test_slot1" self.create_replication_slot(cur, output_plugin='test_decoding')
self.create_replication_slot(cur, slot, output_plugin='test_decoding')
self.wait(cur) self.wait(cur)
cur.start_replication(slot) cur.start_replication(self.slot)
self.wait(cur) self.wait(cur)
self.make_replication_events()
self.msg_count = 0
def consume(msg):
self.msg_count += 1
if self.msg_count > 3:
raise StopReplication()
def process_stream():
from select import select
while True:
msg = cur.read_replication_message()
if msg:
consume(msg)
else:
select([cur], [], [])
self.assertRaises(StopReplication, process_stream)
def test_suite(): def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__) return unittest.TestLoader().loadTestsFromName(__name__)

View File

@ -38,3 +38,5 @@ if dbpass is not None:
repl_dsn = os.environ.get('PSYCOPG2_TEST_REPL_DSN', None) repl_dsn = os.environ.get('PSYCOPG2_TEST_REPL_DSN', None)
if repl_dsn == '': if repl_dsn == '':
repl_dsn = dsn repl_dsn = dsn
repl_slot = os.environ.get('PSYCOPG2_TEST_REPL_SLOT', 'psycopg2_test_slot')