diff --git a/tests/test_connection.py b/tests/test_connection.py index 68bb6f05..5b296949 100755 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -32,6 +32,7 @@ from StringIO import StringIO import psycopg2 import psycopg2.errorcodes import psycopg2.extensions +ext = psycopg2.extensions from testutils import unittest, decorate_all_tests, skip_if_no_superuser from testutils import skip_before_postgres, skip_after_postgres, skip_before_libpq @@ -125,7 +126,7 @@ class ConnectionTests(ConnectingTestCase): if self.conn.server_version >= 90300: cur.execute("set client_min_messages=debug1") for i in range(0, 100, 10): - sql = " ".join(["create temp table table%d (id serial);" % j for j in range(i, i+10)]) + sql = " ".join(["create temp table table%d (id serial);" % j for j in range(i, i + 10)]) cur.execute(sql) self.assertEqual(50, len(conn.notices)) @@ -151,7 +152,7 @@ class ConnectionTests(ConnectingTestCase): # not limited, but no error for i in range(0, 100, 10): - sql = " ".join(["create temp table table2_%d (id serial);" % j for j in range(i, i+10)]) + sql = " ".join(["create temp table table2_%d (id serial);" % j for j in range(i, i + 10)]) cur.execute(sql) self.assertEqual(len([n for n in conn.notices if 'CREATE TABLE' in n]), @@ -172,7 +173,7 @@ class ConnectionTests(ConnectingTestCase): self.assert_(self.conn.server_version) def test_protocol_version(self): - self.assert_(self.conn.protocol_version in (2,3), + self.assert_(self.conn.protocol_version in (2, 3), self.conn.protocol_version) def test_tpc_unsupported(self): @@ -252,7 +253,7 @@ class ConnectionTests(ConnectingTestCase): t1.start() i = 1 for i in range(1000): - cur.execute("select %s;",(i,)) + cur.execute("select %s;", (i,)) conn.commit() while conn.notices: notices.append((1, conn.notices.pop())) @@ -313,16 +314,15 @@ class ConnectionTests(ConnectingTestCase): class ParseDsnTestCase(ConnectingTestCase): def test_parse_dsn(self): from psycopg2 import ProgrammingError - from psycopg2.extensions import parse_dsn - self.assertEqual(parse_dsn('dbname=test user=tester password=secret'), + self.assertEqual(ext.parse_dsn('dbname=test user=tester password=secret'), dict(user='tester', password='secret', dbname='test'), "simple DSN parsed") - self.assertRaises(ProgrammingError, parse_dsn, + self.assertRaises(ProgrammingError, ext.parse_dsn, "dbname=test 2 user=tester password=secret") - self.assertEqual(parse_dsn("dbname='test 2' user=tester password=secret"), + self.assertEqual(ext.parse_dsn("dbname='test 2' user=tester password=secret"), dict(user='tester', password='secret', dbname='test 2'), "DSN with quoting parsed") @@ -332,7 +332,7 @@ class ParseDsnTestCase(ConnectingTestCase): raised = False try: # unterminated quote after dbname: - parse_dsn("dbname='test 2 user=tester password=secret") + ext.parse_dsn("dbname='test 2 user=tester password=secret") except ProgrammingError, e: raised = True self.assertTrue(str(e).find('secret') < 0, @@ -343,16 +343,14 @@ class ParseDsnTestCase(ConnectingTestCase): @skip_before_libpq(9, 2) def test_parse_dsn_uri(self): - from psycopg2.extensions import parse_dsn - - self.assertEqual(parse_dsn('postgresql://tester:secret@/test'), + self.assertEqual(ext.parse_dsn('postgresql://tester:secret@/test'), dict(user='tester', password='secret', dbname='test'), "valid URI dsn parsed") raised = False try: # extra '=' after port value - parse_dsn(dsn='postgresql://tester:secret@/test?port=1111=x') + ext.parse_dsn(dsn='postgresql://tester:secret@/test?port=1111=x') except psycopg2.ProgrammingError, e: raised = True self.assertTrue(str(e).find('secret') < 0, @@ -362,24 +360,76 @@ class ParseDsnTestCase(ConnectingTestCase): self.assertTrue(raised, "ProgrammingError raised due to invalid URI") def test_unicode_value(self): - from psycopg2.extensions import parse_dsn snowman = u"\u2603" - d = parse_dsn('dbname=' + snowman) + d = ext.parse_dsn('dbname=' + snowman) if sys.version_info[0] < 3: self.assertEqual(d['dbname'], snowman.encode('utf8')) else: self.assertEqual(d['dbname'], snowman) def test_unicode_key(self): - from psycopg2.extensions import parse_dsn snowman = u"\u2603" - self.assertRaises(psycopg2.ProgrammingError, parse_dsn, + self.assertRaises(psycopg2.ProgrammingError, ext.parse_dsn, snowman + '=' + snowman) def test_bad_param(self): - from psycopg2.extensions import parse_dsn - self.assertRaises(TypeError, parse_dsn, None) - self.assertRaises(TypeError, parse_dsn, 42) + self.assertRaises(TypeError, ext.parse_dsn, None) + self.assertRaises(TypeError, ext.parse_dsn, 42) + + +class MakeDsnTestCase(ConnectingTestCase): + def assertDsnEqual(self, dsn1, dsn2): + self.assertEqual(set(dsn1.split()), set(dsn2.split())) + + def test_there_has_to_be_something(self): + self.assertRaises(TypeError, ext.make_dsn) + + def test_empty_param(self): + dsn = ext.make_dsn(database='sony', password='') + self.assertDsnEqual(dsn, "dbname=sony password=''") + + def test_escape(self): + dsn = ext.make_dsn(database='hello world') + self.assertEqual(dsn, "dbname='hello world'") + + dsn = ext.make_dsn(database=r'back\slash') + self.assertEqual(dsn, r"dbname=back\\slash") + + dsn = ext.make_dsn(database="quo'te") + self.assertEqual(dsn, r"dbname=quo\'te") + + dsn = ext.make_dsn(database="with\ttab") + self.assertEqual(dsn, "dbname='with\ttab'") + + dsn = ext.make_dsn(database=r"\every thing'") + self.assertEqual(dsn, r"dbname='\\every thing\''") + + def test_params_merging(self): + dsn = ext.make_dsn('dbname=foo', database='bar') + self.assertEqual(dsn, 'dbname=bar') + + dsn = ext.make_dsn('dbname=foo', user='postgres') + self.assertDsnEqual(dsn, 'dbname=foo user=postgres') + + def test_no_dsn_munging(self): + dsn = ext.make_dsn('nosuchparam=whatevs') + self.assertEqual(dsn, 'nosuchparam=whatevs') + + dsn = ext.make_dsn(nosuchparam='whatevs') + self.assertEqual(dsn, 'nosuchparam=whatevs') + + self.assertRaises(psycopg2.ProgrammingError, + ext.make_dsn, 'nosuchparam=whatevs', andthis='either') + + @skip_before_libpq(9, 2) + def test_url_is_cool(self): + dsn = ext.make_dsn('postgresql://tester:secret@/test') + self.assertEqual(dsn, 'postgresql://tester:secret@/test') + + dsn = ext.make_dsn('postgresql://tester:secret@/test', + application_name='woot') + self.assertDsnEqual(dsn, + 'dbname=test user=tester password=secret application_name=woot') class IsolationLevelsTestCase(ConnectingTestCase): @@ -587,7 +637,7 @@ class ConnectionTwoPhaseTests(ConnectingTestCase): cnn.close() return - gids = [ r[0] for r in cur ] + gids = [r[0] for r in cur] for gid in gids: cur.execute("rollback prepared %s;", (gid,)) cnn.close() @@ -761,13 +811,13 @@ class ConnectionTwoPhaseTests(ConnectingTestCase): def test_status_after_recover(self): cnn = self.connect() self.assertEqual(psycopg2.extensions.STATUS_READY, cnn.status) - xns = cnn.tpc_recover() + cnn.tpc_recover() self.assertEqual(psycopg2.extensions.STATUS_READY, cnn.status) cur = cnn.cursor() cur.execute("select 1") self.assertEqual(psycopg2.extensions.STATUS_BEGIN, cnn.status) - xns = cnn.tpc_recover() + cnn.tpc_recover() self.assertEqual(psycopg2.extensions.STATUS_BEGIN, cnn.status) def test_recovered_xids(self): @@ -789,12 +839,12 @@ class ConnectionTwoPhaseTests(ConnectingTestCase): cnn = self.connect() xids = cnn.tpc_recover() - xids = [ xid for xid in xids if xid.database == dbname ] + xids = [xid for xid in xids if xid.database == dbname] xids.sort(key=attrgetter('gtrid')) # check the values returned self.assertEqual(len(okvals), len(xids)) - for (xid, (gid, prepared, owner, database)) in zip (xids, okvals): + for (xid, (gid, prepared, owner, database)) in zip(xids, okvals): self.assertEqual(xid.gtrid, gid) self.assertEqual(xid.prepared, prepared) self.assertEqual(xid.owner, owner) @@ -825,8 +875,7 @@ class ConnectionTwoPhaseTests(ConnectingTestCase): cnn.close() cnn = self.connect() - xids = [ xid for xid in cnn.tpc_recover() - if xid.database == dbname ] + xids = [x for x in cnn.tpc_recover() if x.database == dbname] self.assertEqual(1, len(xids)) xid = xids[0] self.assertEqual(xid.format_id, fid) @@ -847,8 +896,7 @@ class ConnectionTwoPhaseTests(ConnectingTestCase): cnn.close() cnn = self.connect() - xids = [ xid for xid in cnn.tpc_recover() - if xid.database == dbname ] + xids = [x for x in cnn.tpc_recover() if x.database == dbname] self.assertEqual(1, len(xids)) xid = xids[0] self.assertEqual(xid.format_id, None) @@ -893,8 +941,7 @@ class ConnectionTwoPhaseTests(ConnectingTestCase): cnn.tpc_begin(x1) cnn.tpc_prepare() cnn.reset() - xid = [ xid for xid in cnn.tpc_recover() - if xid.database == dbname ][0] + xid = [x for x in cnn.tpc_recover() if x.database == dbname][0] self.assertEqual(10, xid.format_id) self.assertEqual('uni', xid.gtrid) self.assertEqual('code', xid.bqual) @@ -909,8 +956,7 @@ class ConnectionTwoPhaseTests(ConnectingTestCase): cnn.tpc_prepare() cnn.reset() - xid = [ xid for xid in cnn.tpc_recover() - if xid.database == dbname ][0] + xid = [x for x in cnn.tpc_recover() if x.database == dbname][0] self.assertEqual(None, xid.format_id) self.assertEqual('transaction-id', xid.gtrid) self.assertEqual(None, xid.bqual) @@ -929,7 +975,7 @@ class ConnectionTwoPhaseTests(ConnectingTestCase): cnn.reset() xids = cnn.tpc_recover() - xid = [ xid for xid in xids if xid.database == dbname ][0] + xid = [x for x in xids if x.database == dbname][0] self.assertEqual(None, xid.format_id) self.assertEqual('dict-connection', xid.gtrid) self.assertEqual(None, xid.bqual) @@ -1182,7 +1228,8 @@ class ReplicationTest(ConnectingTestCase): @skip_before_postgres(9, 0) def test_replication_not_supported(self): conn = self.repl_connect() - if conn is None: return + if conn is None: + return cur = conn.cursor() f = StringIO() self.assertRaises(psycopg2.NotSupportedError,