From d14fea31a33488a1f62a45a8a87109d5be678a72 Mon Sep 17 00:00:00 2001 From: Oleksandr Shulgin Date: Thu, 15 Oct 2015 12:56:21 +0200 Subject: [PATCH] Use quote_ident from psycopg2.extensions --- lib/extras.py | 18 +++++++----------- tests/test_connection.py | 14 ++++++++++++++ 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/lib/extras.py b/lib/extras.py index e0fd8ef1..f411a4d0 100644 --- a/lib/extras.py +++ b/lib/extras.py @@ -40,7 +40,7 @@ from psycopg2 import extensions as _ext from psycopg2.extensions import cursor as _cursor from psycopg2.extensions import connection as _connection from psycopg2.extensions import replicationMessage as ReplicationMessage -from psycopg2.extensions import adapt as _A +from psycopg2.extensions import adapt as _A, quote_ident from psycopg2.extensions import b @@ -484,10 +484,6 @@ class ReplicationConnectionBase(_connection): if self.cursor_factory is None: self.cursor_factory = ReplicationCursor - def quote_ident(self, ident): - # FIXME: use PQescapeIdentifier or psycopg_escape_identifier_easy, somehow - return '"%s"' % ident.replace('"', '""') - class LogicalReplicationConnection(ReplicationConnectionBase): @@ -509,7 +505,7 @@ class ReplicationCursor(_cursor): def create_replication_slot(self, slot_name, slot_type=None, output_plugin=None): """Create streaming replication slot.""" - command = "CREATE_REPLICATION_SLOT %s " % self.connection.quote_ident(slot_name) + command = "CREATE_REPLICATION_SLOT %s " % quote_ident(slot_name, self) if slot_type is None: slot_type = self.connection.replication_type @@ -518,7 +514,7 @@ class ReplicationCursor(_cursor): if output_plugin is None: raise psycopg2.ProgrammingError("output plugin name is required to create logical replication slot") - command += "%s %s" % (slot_type, self.connection.quote_ident(output_plugin)) + command += "%s %s" % (slot_type, quote_ident(output_plugin, self)) elif slot_type == REPLICATION_PHYSICAL: if output_plugin is not None: @@ -534,7 +530,7 @@ class ReplicationCursor(_cursor): def drop_replication_slot(self, slot_name): """Drop streaming replication slot.""" - command = "DROP_REPLICATION_SLOT %s" % self.connection.quote_ident(slot_name) + command = "DROP_REPLICATION_SLOT %s" % quote_ident(slot_name, self) self.execute(command) def start_replication(self, slot_name=None, slot_type=None, start_lsn=0, @@ -548,7 +544,7 @@ class ReplicationCursor(_cursor): if slot_type == REPLICATION_LOGICAL: if slot_name: - command += "SLOT %s " % self.connection.quote_ident(slot_name) + command += "SLOT %s " % quote_ident(slot_name, self) else: raise psycopg2.ProgrammingError("slot name is required for logical replication") @@ -556,7 +552,7 @@ class ReplicationCursor(_cursor): elif slot_type == REPLICATION_PHYSICAL: if slot_name: - command += "SLOT %s " % self.connection.quote_ident(slot_name) + command += "SLOT %s " % quote_ident(slot_name, self) # don't add "PHYSICAL", before 9.4 it was just START_REPLICATION XXX/XXX else: @@ -584,7 +580,7 @@ class ReplicationCursor(_cursor): for k,v in options.iteritems(): if not command.endswith('('): command += ", " - command += "%s %s" % (self.connection.quote_ident(k), _A(str(v))) + command += "%s %s" % (quote_ident(k, self), _A(str(v))) command += ")" self.start_replication_expert(command) diff --git a/tests/test_connection.py b/tests/test_connection.py index e2b0da30..eeeaa845 100755 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1212,6 +1212,20 @@ class ReplicationTest(ConnectingTestCase): pass cur.consume_replication_stream(consume) # should return at once + @skip_before_postgres(9, 4) # slots require 9.4 + def test_create_replication_slot(self): + import psycopg2.extras + conn = self.repl_connect(connection_factory=psycopg2.extras.PhysicalReplicationConnection) + if conn is None: return + cur = conn.cursor() + + slot = "test_slot1" + try: + cur.create_replication_slot(slot) + self.assertRaises(psycopg2.ProgrammingError, cur.create_replication_slot, slot) + finally: + cur.drop_replication_slot(slot) + def test_suite(): return unittest.TestLoader().loadTestsFromName(__name__)