mirror of
https://github.com/psycopg/psycopg2.git
synced 2024-11-22 00:46:33 +03:00
Fix async replication and test.
This commit is contained in:
parent
d14fea31a3
commit
cf4f2411bf
|
@ -449,7 +449,7 @@ class ReplicationConnectionBase(_connection):
|
|||
classes. Uses `ReplicationCursor` automatically.
|
||||
"""
|
||||
|
||||
def __init__(self, dsn, **kwargs):
|
||||
def __init__(self, *args, **kwargs):
|
||||
"""
|
||||
Initializes a replication connection by adding appropriate
|
||||
parameters to the provided DSN and tweaking the connection
|
||||
|
@ -466,7 +466,7 @@ class ReplicationConnectionBase(_connection):
|
|||
else:
|
||||
raise psycopg2.ProgrammingError("unrecognized replication type: %s" % self.replication_type)
|
||||
|
||||
items = _ext.parse_dsn(dsn)
|
||||
items = _ext.parse_dsn(args[0])
|
||||
|
||||
# we add an appropriate replication keyword parameter, unless
|
||||
# user has specified one explicitly in the DSN
|
||||
|
@ -475,7 +475,8 @@ class ReplicationConnectionBase(_connection):
|
|||
dsn = " ".join(["%s=%s" % (k, psycopg2._param_escape(str(v)))
|
||||
for (k, v) in items.iteritems()])
|
||||
|
||||
super(ReplicationConnectionBase, self).__init__(dsn, **kwargs)
|
||||
args = [dsn] + list(args[1:]) # async is the possible 2nd arg
|
||||
super(ReplicationConnectionBase, self).__init__(*args, **kwargs)
|
||||
|
||||
# prevent auto-issued BEGIN statements
|
||||
if not self.async:
|
||||
|
|
|
@ -29,7 +29,6 @@ import psycopg2
|
|||
from psycopg2 import extensions
|
||||
|
||||
import time
|
||||
import select
|
||||
import StringIO
|
||||
|
||||
from testutils import ConnectingTestCase
|
||||
|
@ -66,21 +65,6 @@ class AsyncTests(ConnectingTestCase):
|
|||
)''')
|
||||
self.wait(curs)
|
||||
|
||||
def wait(self, cur_or_conn):
|
||||
pollable = cur_or_conn
|
||||
if not hasattr(pollable, 'poll'):
|
||||
pollable = cur_or_conn.connection
|
||||
while True:
|
||||
state = pollable.poll()
|
||||
if state == psycopg2.extensions.POLL_OK:
|
||||
break
|
||||
elif state == psycopg2.extensions.POLL_READ:
|
||||
select.select([pollable], [], [], 10)
|
||||
elif state == psycopg2.extensions.POLL_WRITE:
|
||||
select.select([], [pollable], [], 10)
|
||||
else:
|
||||
raise Exception("Unexpected result from poll: %r", state)
|
||||
|
||||
def test_connection_setup(self):
|
||||
cur = self.conn.cursor()
|
||||
sync_cur = self.sync_conn.cursor()
|
||||
|
|
|
@ -1178,55 +1178,6 @@ class AutocommitTests(ConnectingTestCase):
|
|||
self.assertEqual(cur.fetchone()[0], 'on')
|
||||
|
||||
|
||||
class ReplicationTest(ConnectingTestCase):
|
||||
@skip_before_postgres(9, 0)
|
||||
def test_physical_replication_connection(self):
|
||||
import psycopg2.extras
|
||||
conn = self.repl_connect(connection_factory=psycopg2.extras.PhysicalReplicationConnection)
|
||||
if conn is None: return
|
||||
cur = conn.cursor()
|
||||
cur.execute("IDENTIFY_SYSTEM")
|
||||
cur.fetchall()
|
||||
|
||||
@skip_before_postgres(9, 4)
|
||||
def test_logical_replication_connection(self):
|
||||
import psycopg2.extras
|
||||
conn = self.repl_connect(connection_factory=psycopg2.extras.LogicalReplicationConnection)
|
||||
if conn is None: return
|
||||
cur = conn.cursor()
|
||||
cur.execute("IDENTIFY_SYSTEM")
|
||||
cur.fetchall()
|
||||
|
||||
@skip_before_postgres(9, 0)
|
||||
def test_stop_replication_raises(self):
|
||||
import psycopg2.extras
|
||||
conn = self.repl_connect(connection_factory=psycopg2.extras.PhysicalReplicationConnection)
|
||||
if conn is None: return
|
||||
cur = conn.cursor()
|
||||
self.assertRaises(psycopg2.ProgrammingError, cur.stop_replication)
|
||||
|
||||
cur.start_replication()
|
||||
cur.stop_replication() # doesn't raise now
|
||||
|
||||
def consume(msg):
|
||||
pass
|
||||
cur.consume_replication_stream(consume) # should return at once
|
||||
|
||||
@skip_before_postgres(9, 4) # slots require 9.4
|
||||
def test_create_replication_slot(self):
|
||||
import psycopg2.extras
|
||||
conn = self.repl_connect(connection_factory=psycopg2.extras.PhysicalReplicationConnection)
|
||||
if conn is None: return
|
||||
cur = conn.cursor()
|
||||
|
||||
slot = "test_slot1"
|
||||
try:
|
||||
cur.create_replication_slot(slot)
|
||||
self.assertRaises(psycopg2.ProgrammingError, cur.create_replication_slot, slot)
|
||||
finally:
|
||||
cur.drop_replication_slot(slot)
|
||||
|
||||
|
||||
def test_suite():
|
||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||
|
||||
|
|
123
tests/test_replication.py
Normal file
123
tests/test_replication.py
Normal file
|
@ -0,0 +1,123 @@
|
|||
#!/usr/bin/env python
|
||||
|
||||
# test_replication.py - unit test for replication protocol
|
||||
#
|
||||
# Copyright (C) 2015 Daniele Varrazzo <daniele.varrazzo@gmail.com>
|
||||
#
|
||||
# psycopg2 is free software: you can redistribute it and/or modify it
|
||||
# under the terms of the GNU Lesser General Public License as published
|
||||
# by the Free Software Foundation, either version 3 of the License, or
|
||||
# (at your option) any later version.
|
||||
#
|
||||
# In addition, as a special exception, the copyright holders give
|
||||
# permission to link this program with the OpenSSL library (or with
|
||||
# modified versions of OpenSSL that use the same license as OpenSSL),
|
||||
# and distribute linked combinations including the two.
|
||||
#
|
||||
# You must obey the GNU Lesser General Public License in all respects for
|
||||
# all of the code used other than OpenSSL.
|
||||
#
|
||||
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
|
||||
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
|
||||
# License for more details.
|
||||
|
||||
import psycopg2
|
||||
import psycopg2.extensions
|
||||
from psycopg2.extras import PhysicalReplicationConnection, LogicalReplicationConnection
|
||||
|
||||
from testutils import unittest
|
||||
from testutils import skip_before_postgres
|
||||
from testutils import ConnectingTestCase
|
||||
|
||||
|
||||
class ReplicationTestCase(ConnectingTestCase):
|
||||
def setUp(self):
|
||||
super(ReplicationTestCase, self).setUp()
|
||||
self._slots = []
|
||||
|
||||
def tearDown(self):
|
||||
# first close all connections, as they might keep the slot(s) active
|
||||
super(ReplicationTestCase, self).tearDown()
|
||||
|
||||
if self._slots:
|
||||
kill_conn = self.repl_connect(connection_factory=PhysicalReplicationConnection)
|
||||
if kill_conn:
|
||||
kill_cur = kill_conn.cursor()
|
||||
for slot in self._slots:
|
||||
kill_cur.drop_replication_slot(slot)
|
||||
kill_conn.close()
|
||||
|
||||
def create_replication_slot(self, cur, slot_name, **kwargs):
|
||||
cur.create_replication_slot(slot_name, **kwargs)
|
||||
self._slots.append(slot_name)
|
||||
|
||||
def drop_replication_slot(self, cur, slot_name):
|
||||
cur.drop_replication_slot(slot_name)
|
||||
self._slots.remove(slot_name)
|
||||
|
||||
|
||||
class ReplicationTest(ReplicationTestCase):
|
||||
@skip_before_postgres(9, 0)
|
||||
def test_physical_replication_connection(self):
|
||||
conn = self.repl_connect(connection_factory=PhysicalReplicationConnection)
|
||||
if conn is None: return
|
||||
cur = conn.cursor()
|
||||
cur.execute("IDENTIFY_SYSTEM")
|
||||
cur.fetchall()
|
||||
|
||||
@skip_before_postgres(9, 4)
|
||||
def test_logical_replication_connection(self):
|
||||
conn = self.repl_connect(connection_factory=LogicalReplicationConnection)
|
||||
if conn is None: return
|
||||
cur = conn.cursor()
|
||||
cur.execute("IDENTIFY_SYSTEM")
|
||||
cur.fetchall()
|
||||
|
||||
@skip_before_postgres(9, 0)
|
||||
def test_stop_replication_raises(self):
|
||||
conn = self.repl_connect(connection_factory=PhysicalReplicationConnection)
|
||||
if conn is None: return
|
||||
cur = conn.cursor()
|
||||
self.assertRaises(psycopg2.ProgrammingError, cur.stop_replication)
|
||||
|
||||
cur.start_replication()
|
||||
cur.stop_replication() # doesn't raise now
|
||||
|
||||
def consume(msg):
|
||||
pass
|
||||
cur.consume_replication_stream(consume) # should return at once
|
||||
|
||||
@skip_before_postgres(9, 4) # slots require 9.4
|
||||
def test_create_replication_slot(self):
|
||||
conn = self.repl_connect(connection_factory=PhysicalReplicationConnection)
|
||||
if conn is None: return
|
||||
cur = conn.cursor()
|
||||
|
||||
slot = "test_slot1"
|
||||
|
||||
self.create_replication_slot(cur, slot)
|
||||
self.assertRaises(psycopg2.ProgrammingError, self.create_replication_slot, cur, slot)
|
||||
|
||||
|
||||
class AsyncReplicationTest(ReplicationTestCase):
|
||||
@skip_before_postgres(9, 4)
|
||||
def test_async_replication(self):
|
||||
conn = self.repl_connect(connection_factory=LogicalReplicationConnection, async=1)
|
||||
if conn is None: return
|
||||
self.wait(conn)
|
||||
cur = conn.cursor()
|
||||
|
||||
slot = "test_slot1"
|
||||
self.create_replication_slot(cur, slot, output_plugin='test_decoding')
|
||||
self.wait(cur)
|
||||
|
||||
cur.start_replication(slot)
|
||||
self.wait(cur)
|
||||
|
||||
|
||||
def test_suite():
|
||||
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
|
@ -27,6 +27,7 @@
|
|||
import os
|
||||
import platform
|
||||
import sys
|
||||
import select
|
||||
from functools import wraps
|
||||
from testconfig import dsn, repl_dsn
|
||||
|
||||
|
@ -129,7 +130,8 @@ class ConnectingTestCase(unittest.TestCase):
|
|||
except psycopg2.OperationalError, e:
|
||||
return self.skipTest("replication db not configured: %s" % e)
|
||||
|
||||
conn.autocommit = True
|
||||
if not conn.async:
|
||||
conn.autocommit = True
|
||||
return conn
|
||||
|
||||
def _get_conn(self):
|
||||
|
@ -143,6 +145,23 @@ class ConnectingTestCase(unittest.TestCase):
|
|||
|
||||
conn = property(_get_conn, _set_conn)
|
||||
|
||||
# for use with async connections only
|
||||
def wait(self, cur_or_conn):
|
||||
import psycopg2.extensions
|
||||
pollable = cur_or_conn
|
||||
if not hasattr(pollable, 'poll'):
|
||||
pollable = cur_or_conn.connection
|
||||
while True:
|
||||
state = pollable.poll()
|
||||
if state == psycopg2.extensions.POLL_OK:
|
||||
break
|
||||
elif state == psycopg2.extensions.POLL_READ:
|
||||
select.select([pollable], [], [], 10)
|
||||
elif state == psycopg2.extensions.POLL_WRITE:
|
||||
select.select([], [pollable], [], 10)
|
||||
else:
|
||||
raise Exception("Unexpected result from poll: %r", state)
|
||||
|
||||
|
||||
def decorate_all_tests(cls, *decorators):
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue
Block a user