Use quote_ident from psycopg2.extensions

This commit is contained in:
Oleksandr Shulgin 2015-10-15 12:56:21 +02:00
parent 8e518d4954
commit d14fea31a3
2 changed files with 21 additions and 11 deletions

View File

@ -40,7 +40,7 @@ from psycopg2 import extensions as _ext
from psycopg2.extensions import cursor as _cursor from psycopg2.extensions import cursor as _cursor
from psycopg2.extensions import connection as _connection from psycopg2.extensions import connection as _connection
from psycopg2.extensions import replicationMessage as ReplicationMessage from psycopg2.extensions import replicationMessage as ReplicationMessage
from psycopg2.extensions import adapt as _A from psycopg2.extensions import adapt as _A, quote_ident
from psycopg2.extensions import b from psycopg2.extensions import b
@ -484,10 +484,6 @@ class ReplicationConnectionBase(_connection):
if self.cursor_factory is None: if self.cursor_factory is None:
self.cursor_factory = ReplicationCursor self.cursor_factory = ReplicationCursor
def quote_ident(self, ident):
# FIXME: use PQescapeIdentifier or psycopg_escape_identifier_easy, somehow
return '"%s"' % ident.replace('"', '""')
class LogicalReplicationConnection(ReplicationConnectionBase): class LogicalReplicationConnection(ReplicationConnectionBase):
@ -509,7 +505,7 @@ class ReplicationCursor(_cursor):
def create_replication_slot(self, slot_name, slot_type=None, output_plugin=None): def create_replication_slot(self, slot_name, slot_type=None, output_plugin=None):
"""Create streaming replication slot.""" """Create streaming replication slot."""
command = "CREATE_REPLICATION_SLOT %s " % self.connection.quote_ident(slot_name) command = "CREATE_REPLICATION_SLOT %s " % quote_ident(slot_name, self)
if slot_type is None: if slot_type is None:
slot_type = self.connection.replication_type slot_type = self.connection.replication_type
@ -518,7 +514,7 @@ class ReplicationCursor(_cursor):
if output_plugin is None: if output_plugin is None:
raise psycopg2.ProgrammingError("output plugin name is required to create logical replication slot") raise psycopg2.ProgrammingError("output plugin name is required to create logical replication slot")
command += "%s %s" % (slot_type, self.connection.quote_ident(output_plugin)) command += "%s %s" % (slot_type, quote_ident(output_plugin, self))
elif slot_type == REPLICATION_PHYSICAL: elif slot_type == REPLICATION_PHYSICAL:
if output_plugin is not None: if output_plugin is not None:
@ -534,7 +530,7 @@ class ReplicationCursor(_cursor):
def drop_replication_slot(self, slot_name): def drop_replication_slot(self, slot_name):
"""Drop streaming replication slot.""" """Drop streaming replication slot."""
command = "DROP_REPLICATION_SLOT %s" % self.connection.quote_ident(slot_name) command = "DROP_REPLICATION_SLOT %s" % quote_ident(slot_name, self)
self.execute(command) self.execute(command)
def start_replication(self, slot_name=None, slot_type=None, start_lsn=0, def start_replication(self, slot_name=None, slot_type=None, start_lsn=0,
@ -548,7 +544,7 @@ class ReplicationCursor(_cursor):
if slot_type == REPLICATION_LOGICAL: if slot_type == REPLICATION_LOGICAL:
if slot_name: if slot_name:
command += "SLOT %s " % self.connection.quote_ident(slot_name) command += "SLOT %s " % quote_ident(slot_name, self)
else: else:
raise psycopg2.ProgrammingError("slot name is required for logical replication") raise psycopg2.ProgrammingError("slot name is required for logical replication")
@ -556,7 +552,7 @@ class ReplicationCursor(_cursor):
elif slot_type == REPLICATION_PHYSICAL: elif slot_type == REPLICATION_PHYSICAL:
if slot_name: if slot_name:
command += "SLOT %s " % self.connection.quote_ident(slot_name) command += "SLOT %s " % quote_ident(slot_name, self)
# don't add "PHYSICAL", before 9.4 it was just START_REPLICATION XXX/XXX # don't add "PHYSICAL", before 9.4 it was just START_REPLICATION XXX/XXX
else: else:
@ -584,7 +580,7 @@ class ReplicationCursor(_cursor):
for k,v in options.iteritems(): for k,v in options.iteritems():
if not command.endswith('('): if not command.endswith('('):
command += ", " command += ", "
command += "%s %s" % (self.connection.quote_ident(k), _A(str(v))) command += "%s %s" % (quote_ident(k, self), _A(str(v)))
command += ")" command += ")"
self.start_replication_expert(command) self.start_replication_expert(command)

View File

@ -1212,6 +1212,20 @@ class ReplicationTest(ConnectingTestCase):
pass pass
cur.consume_replication_stream(consume) # should return at once cur.consume_replication_stream(consume) # should return at once
@skip_before_postgres(9, 4) # slots require 9.4
def test_create_replication_slot(self):
import psycopg2.extras
conn = self.repl_connect(connection_factory=psycopg2.extras.PhysicalReplicationConnection)
if conn is None: return
cur = conn.cursor()
slot = "test_slot1"
try:
cur.create_replication_slot(slot)
self.assertRaises(psycopg2.ProgrammingError, cur.create_replication_slot, slot)
finally:
cur.drop_replication_slot(slot)
def test_suite(): def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__) return unittest.TestLoader().loadTestsFromName(__name__)