Fix async replication and test.

This commit is contained in:
Oleksandr Shulgin 2015-10-15 18:01:43 +02:00
parent d14fea31a3
commit cf4f2411bf
5 changed files with 147 additions and 69 deletions

View File

@ -449,7 +449,7 @@ class ReplicationConnectionBase(_connection):
classes. Uses `ReplicationCursor` automatically.
"""
def __init__(self, dsn, **kwargs):
def __init__(self, *args, **kwargs):
"""
Initializes a replication connection by adding appropriate
parameters to the provided DSN and tweaking the connection
@ -466,7 +466,7 @@ class ReplicationConnectionBase(_connection):
else:
raise psycopg2.ProgrammingError("unrecognized replication type: %s" % self.replication_type)
items = _ext.parse_dsn(dsn)
items = _ext.parse_dsn(args[0])
# we add an appropriate replication keyword parameter, unless
# user has specified one explicitly in the DSN
@ -475,7 +475,8 @@ class ReplicationConnectionBase(_connection):
dsn = " ".join(["%s=%s" % (k, psycopg2._param_escape(str(v)))
for (k, v) in items.iteritems()])
super(ReplicationConnectionBase, self).__init__(dsn, **kwargs)
args = [dsn] + list(args[1:]) # async is the possible 2nd arg
super(ReplicationConnectionBase, self).__init__(*args, **kwargs)
# prevent auto-issued BEGIN statements
if not self.async:

View File

@ -29,7 +29,6 @@ import psycopg2
from psycopg2 import extensions
import time
import select
import StringIO
from testutils import ConnectingTestCase
@ -66,21 +65,6 @@ class AsyncTests(ConnectingTestCase):
)''')
self.wait(curs)
def wait(self, cur_or_conn):
pollable = cur_or_conn
if not hasattr(pollable, 'poll'):
pollable = cur_or_conn.connection
while True:
state = pollable.poll()
if state == psycopg2.extensions.POLL_OK:
break
elif state == psycopg2.extensions.POLL_READ:
select.select([pollable], [], [], 10)
elif state == psycopg2.extensions.POLL_WRITE:
select.select([], [pollable], [], 10)
else:
raise Exception("Unexpected result from poll: %r", state)
def test_connection_setup(self):
cur = self.conn.cursor()
sync_cur = self.sync_conn.cursor()

View File

@ -1178,55 +1178,6 @@ class AutocommitTests(ConnectingTestCase):
self.assertEqual(cur.fetchone()[0], 'on')
class ReplicationTest(ConnectingTestCase):
@skip_before_postgres(9, 0)
def test_physical_replication_connection(self):
import psycopg2.extras
conn = self.repl_connect(connection_factory=psycopg2.extras.PhysicalReplicationConnection)
if conn is None: return
cur = conn.cursor()
cur.execute("IDENTIFY_SYSTEM")
cur.fetchall()
@skip_before_postgres(9, 4)
def test_logical_replication_connection(self):
import psycopg2.extras
conn = self.repl_connect(connection_factory=psycopg2.extras.LogicalReplicationConnection)
if conn is None: return
cur = conn.cursor()
cur.execute("IDENTIFY_SYSTEM")
cur.fetchall()
@skip_before_postgres(9, 0)
def test_stop_replication_raises(self):
import psycopg2.extras
conn = self.repl_connect(connection_factory=psycopg2.extras.PhysicalReplicationConnection)
if conn is None: return
cur = conn.cursor()
self.assertRaises(psycopg2.ProgrammingError, cur.stop_replication)
cur.start_replication()
cur.stop_replication() # doesn't raise now
def consume(msg):
pass
cur.consume_replication_stream(consume) # should return at once
@skip_before_postgres(9, 4) # slots require 9.4
def test_create_replication_slot(self):
import psycopg2.extras
conn = self.repl_connect(connection_factory=psycopg2.extras.PhysicalReplicationConnection)
if conn is None: return
cur = conn.cursor()
slot = "test_slot1"
try:
cur.create_replication_slot(slot)
self.assertRaises(psycopg2.ProgrammingError, cur.create_replication_slot, slot)
finally:
cur.drop_replication_slot(slot)
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)

123
tests/test_replication.py Normal file
View File

@ -0,0 +1,123 @@
#!/usr/bin/env python
# test_replication.py - unit test for replication protocol
#
# Copyright (C) 2015 Daniele Varrazzo <daniele.varrazzo@gmail.com>
#
# psycopg2 is free software: you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# In addition, as a special exception, the copyright holders give
# permission to link this program with the OpenSSL library (or with
# modified versions of OpenSSL that use the same license as OpenSSL),
# and distribute linked combinations including the two.
#
# You must obey the GNU Lesser General Public License in all respects for
# all of the code used other than OpenSSL.
#
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
import psycopg2
import psycopg2.extensions
from psycopg2.extras import PhysicalReplicationConnection, LogicalReplicationConnection
from testutils import unittest
from testutils import skip_before_postgres
from testutils import ConnectingTestCase
class ReplicationTestCase(ConnectingTestCase):
def setUp(self):
super(ReplicationTestCase, self).setUp()
self._slots = []
def tearDown(self):
# first close all connections, as they might keep the slot(s) active
super(ReplicationTestCase, self).tearDown()
if self._slots:
kill_conn = self.repl_connect(connection_factory=PhysicalReplicationConnection)
if kill_conn:
kill_cur = kill_conn.cursor()
for slot in self._slots:
kill_cur.drop_replication_slot(slot)
kill_conn.close()
def create_replication_slot(self, cur, slot_name, **kwargs):
cur.create_replication_slot(slot_name, **kwargs)
self._slots.append(slot_name)
def drop_replication_slot(self, cur, slot_name):
cur.drop_replication_slot(slot_name)
self._slots.remove(slot_name)
class ReplicationTest(ReplicationTestCase):
@skip_before_postgres(9, 0)
def test_physical_replication_connection(self):
conn = self.repl_connect(connection_factory=PhysicalReplicationConnection)
if conn is None: return
cur = conn.cursor()
cur.execute("IDENTIFY_SYSTEM")
cur.fetchall()
@skip_before_postgres(9, 4)
def test_logical_replication_connection(self):
conn = self.repl_connect(connection_factory=LogicalReplicationConnection)
if conn is None: return
cur = conn.cursor()
cur.execute("IDENTIFY_SYSTEM")
cur.fetchall()
@skip_before_postgres(9, 0)
def test_stop_replication_raises(self):
conn = self.repl_connect(connection_factory=PhysicalReplicationConnection)
if conn is None: return
cur = conn.cursor()
self.assertRaises(psycopg2.ProgrammingError, cur.stop_replication)
cur.start_replication()
cur.stop_replication() # doesn't raise now
def consume(msg):
pass
cur.consume_replication_stream(consume) # should return at once
@skip_before_postgres(9, 4) # slots require 9.4
def test_create_replication_slot(self):
conn = self.repl_connect(connection_factory=PhysicalReplicationConnection)
if conn is None: return
cur = conn.cursor()
slot = "test_slot1"
self.create_replication_slot(cur, slot)
self.assertRaises(psycopg2.ProgrammingError, self.create_replication_slot, cur, slot)
class AsyncReplicationTest(ReplicationTestCase):
@skip_before_postgres(9, 4)
def test_async_replication(self):
conn = self.repl_connect(connection_factory=LogicalReplicationConnection, async=1)
if conn is None: return
self.wait(conn)
cur = conn.cursor()
slot = "test_slot1"
self.create_replication_slot(cur, slot, output_plugin='test_decoding')
self.wait(cur)
cur.start_replication(slot)
self.wait(cur)
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()

View File

@ -27,6 +27,7 @@
import os
import platform
import sys
import select
from functools import wraps
from testconfig import dsn, repl_dsn
@ -129,7 +130,8 @@ class ConnectingTestCase(unittest.TestCase):
except psycopg2.OperationalError, e:
return self.skipTest("replication db not configured: %s" % e)
conn.autocommit = True
if not conn.async:
conn.autocommit = True
return conn
def _get_conn(self):
@ -143,6 +145,23 @@ class ConnectingTestCase(unittest.TestCase):
conn = property(_get_conn, _set_conn)
# for use with async connections only
def wait(self, cur_or_conn):
import psycopg2.extensions
pollable = cur_or_conn
if not hasattr(pollable, 'poll'):
pollable = cur_or_conn.connection
while True:
state = pollable.poll()
if state == psycopg2.extensions.POLL_OK:
break
elif state == psycopg2.extensions.POLL_READ:
select.select([pollable], [], [], 10)
elif state == psycopg2.extensions.POLL_WRITE:
select.select([], [pollable], [], 10)
else:
raise Exception("Unexpected result from poll: %r", state)
def decorate_all_tests(cls, *decorators):
"""