mirror of
https://github.com/psycopg/psycopg2.git
synced 2024-11-10 19:16:34 +03:00
999d7a6d01
Close #1619
1955 lines
68 KiB
Python
Executable File
1955 lines
68 KiB
Python
Executable File
#!/usr/bin/env python
|
|
|
|
# test_connection.py - unit test for connection attributes
|
|
#
|
|
# Copyright (C) 2008-2019 James Henstridge <james@jamesh.id.au>
|
|
# Copyright (C) 2020-2021 The Psycopg Team
|
|
#
|
|
# psycopg2 is free software: you can redistribute it and/or modify it
|
|
# under the terms of the GNU Lesser General Public License as published
|
|
# by the Free Software Foundation, either version 3 of the License, or
|
|
# (at your option) any later version.
|
|
#
|
|
# In addition, as a special exception, the copyright holders give
|
|
# permission to link this program with the OpenSSL library (or with
|
|
# modified versions of OpenSSL that use the same license as OpenSSL),
|
|
# and distribute linked combinations including the two.
|
|
#
|
|
# You must obey the GNU Lesser General Public License in all respects for
|
|
# all of the code used other than OpenSSL.
|
|
#
|
|
# psycopg2 is distributed in the hope that it will be useful, but WITHOUT
|
|
# ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
|
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
|
|
# License for more details.
|
|
|
|
import gc
|
|
import os
|
|
import re
|
|
import sys
|
|
import time
|
|
import ctypes
|
|
import shutil
|
|
import tempfile
|
|
import threading
|
|
import subprocess as sp
|
|
from collections import deque
|
|
from operator import attrgetter
|
|
from weakref import ref
|
|
|
|
import psycopg2
|
|
import psycopg2.extras
|
|
from psycopg2 import extensions as ext
|
|
|
|
from .testutils import (
|
|
unittest, skip_if_no_superuser, skip_before_postgres,
|
|
skip_after_postgres, skip_before_libpq, skip_after_libpq,
|
|
ConnectingTestCase, skip_if_tpc_disabled, skip_if_windows, slow,
|
|
skip_if_crdb, crdb_version)
|
|
|
|
from .testconfig import dbhost, dsn, dbname
|
|
|
|
|
|
class ConnectionTests(ConnectingTestCase):
    """Tests for the basic attributes and behaviour of a connection.

    Unless a test opens its own connection with ``self.connect()``, it
    works on ``self.conn``, which is provided by ``ConnectingTestCase``.
    """

    def _enable_debug_notices(self, cur):
        """Set ``client_min_messages`` to ``debug1`` on servers >= 9.3.

        Shared by the notices tests, which need the server to emit a
        notice for every ``create temp table`` they run.
        NOTE(review): presumably newer servers no longer send the implicit
        sequence notice at the default level -- confirm.
        """
        if self.conn.info.server_version >= 90300:
            cur.execute("set client_min_messages=debug1")

    def test_closed_attribute(self):
        conn = self.conn
        self.assertEqual(conn.closed, False)
        conn.close()
        self.assertEqual(conn.closed, True)

    def test_close_idempotent(self):
        # Closing twice must not raise.
        conn = self.conn
        conn.close()
        conn.close()
        self.assertTrue(conn.closed)

    def test_cursor_closed_attribute(self):
        conn = self.conn
        curs = conn.cursor()
        self.assertEqual(curs.closed, False)
        curs.close()
        self.assertEqual(curs.closed, True)

        # Closing the connection closes the cursor:
        curs = conn.cursor()
        conn.close()
        self.assertEqual(curs.closed, True)

    @skip_if_crdb("backend pid")
    @skip_before_postgres(8, 4)
    @skip_if_no_superuser
    @skip_if_windows
    def test_cleanup_on_badconn_close(self):
        # ticket #148
        conn = self.conn
        cur = conn.cursor()
        self.assertRaises(psycopg2.OperationalError,
            cur.execute, "select pg_terminate_backend(pg_backend_pid())")

        # The connection reports 2 after the backend is gone and 1 once
        # it has been explicitly closed.
        self.assertEqual(conn.closed, 2)
        conn.close()
        self.assertEqual(conn.closed, 1)

    @skip_if_crdb("isolation level")
    def test_reset(self):
        conn = self.conn
        has_deferrable = self.conn.info.server_version >= 90100

        # switch session characteristics
        conn.autocommit = True
        conn.isolation_level = 'serializable'
        conn.readonly = True
        if has_deferrable:
            conn.deferrable = False

        self.assertTrue(conn.autocommit)
        self.assertEqual(conn.isolation_level, ext.ISOLATION_LEVEL_SERIALIZABLE)
        self.assertIs(conn.readonly, True)
        if has_deferrable:
            self.assertIs(conn.deferrable, False)

        conn.reset()
        # now the session characteristics should be reverted
        self.assertFalse(conn.autocommit)
        self.assertEqual(conn.isolation_level, ext.ISOLATION_LEVEL_DEFAULT)
        self.assertIsNone(conn.readonly)
        if has_deferrable:
            self.assertIsNone(conn.deferrable)

    @skip_if_crdb("notice")
    def test_notices(self):
        conn = self.conn
        cur = conn.cursor()
        self._enable_debug_notices(cur)
        cur.execute("create temp table chatty (id serial primary key);")
        self.assertEqual("CREATE TABLE", cur.statusmessage)
        self.assertTrue(conn.notices)

    @skip_if_crdb("notice")
    def test_notices_consistent_order(self):
        conn = self.conn
        cur = conn.cursor()
        self._enable_debug_notices(cur)
        cur.execute("""
            create temp table table1 (id serial);
            create temp table table2 (id serial);
            """)
        cur.execute("""
            create temp table table3 (id serial);
            create temp table table4 (id serial);
            """)
        self.assertEqual(4, len(conn.notices))
        self.assertIn('table1', conn.notices[0])
        self.assertIn('table2', conn.notices[1])
        self.assertIn('table3', conn.notices[2])
        self.assertIn('table4', conn.notices[3])

    @slow
    @skip_if_crdb("notice")
    def test_notices_limited(self):
        conn = self.conn
        cur = conn.cursor()
        self._enable_debug_notices(cur)
        # 100 tables are created, 10 per statement batch.
        for i in range(0, 100, 10):
            sql = " ".join(["create temp table table%d (id serial);" % j
                            for j in range(i, i + 10)])
            cur.execute(sql)

        # Only the 50 most recent notices are expected to be retained.
        self.assertEqual(50, len(conn.notices))
        self.assertIn('table99', conn.notices[-1])

    @slow
    @skip_if_crdb("notice")
    def test_notices_deque(self):
        conn = self.conn
        self.conn.notices = deque()
        cur = conn.cursor()
        self._enable_debug_notices(cur)

        cur.execute("""
            create temp table table1 (id serial);
            create temp table table2 (id serial);
            """)
        cur.execute("""
            create temp table table3 (id serial);
            create temp table table4 (id serial);""")
        self.assertEqual(len(conn.notices), 4)
        self.assertIn('table1', conn.notices.popleft())
        self.assertIn('table2', conn.notices.popleft())
        self.assertIn('table3', conn.notices.popleft())
        self.assertIn('table4', conn.notices.popleft())
        self.assertEqual(len(conn.notices), 0)

        # not limited, but no error
        for i in range(0, 100, 10):
            sql = " ".join(["create temp table table2_%d (id serial);" % j
                            for j in range(i, i + 10)])
            cur.execute(sql)

        self.assertEqual(len([n for n in conn.notices if 'CREATE TABLE' in n]),
            100)

    @skip_if_crdb("notice")
    def test_notices_noappend(self):
        conn = self.conn
        # None cannot be appended to: the resulting error must be swallowed
        self.conn.notices = None
        cur = conn.cursor()
        self._enable_debug_notices(cur)

        cur.execute("create temp table table1 (id serial);")

        self.assertEqual(self.conn.notices, None)

    def test_server_version(self):
        self.assertTrue(self.conn.server_version)

    def test_protocol_version(self):
        self.assertIn(self.conn.protocol_version, (2, 3))

    def test_tpc_unsupported(self):
        cnn = self.conn
        if cnn.info.server_version >= 80100:
            return self.skipTest("tpc is supported")

        self.assertRaises(psycopg2.NotSupportedError,
            cnn.xid, 42, "foo", "bar")

    @slow
    @skip_before_postgres(8, 2)
    def test_concurrent_execution(self):
        def worker():
            cnn = self.connect()
            cur = cnn.cursor()
            cur.execute("select pg_sleep(4)")
            cur.close()
            cnn.close()

        t1 = threading.Thread(target=worker)
        t2 = threading.Thread(target=worker)
        t0 = time.time()
        t1.start()
        t2.start()
        t1.join()
        t2.join()
        # Two 4-second sleeps finishing in under 7 seconds means they
        # overlapped instead of running one after the other.
        self.assertLess(time.time() - t0, 7,
            "something broken in concurrency")

    @skip_if_crdb("encoding")
    def test_encoding_name(self):
        self.conn.set_client_encoding("EUC_JP")
        # conn.encoding is 'EUCJP' now.
        cur = self.conn.cursor()
        ext.register_type(ext.UNICODE, cur)
        cur.execute("select 'foo'::text;")
        self.assertEqual(cur.fetchone()[0], 'foo')

    def test_connect_nonnormal_envvar(self):
        # We must perform encoding normalization at connection time
        self.conn.close()
        oldenc = os.environ.get('PGCLIENTENCODING')
        os.environ['PGCLIENTENCODING'] = 'utf-8'    # malformed spelling
        try:
            self.conn = self.connect()
        finally:
            # Restore the environment exactly as it was found.
            if oldenc is not None:
                os.environ['PGCLIENTENCODING'] = oldenc
            else:
                del os.environ['PGCLIENTENCODING']

    def test_connect_no_string(self):
        # A str subclass must be accepted as dsn.
        class MyString(str):
            pass

        conn = psycopg2.connect(MyString(dsn))
        conn.close()

    def test_weakref(self):
        conn = psycopg2.connect(dsn)
        w = ref(conn)
        conn.close()
        del conn
        gc.collect()
        self.assertIsNone(w())

    @slow
    def test_commit_concurrency(self):
        # The problem is the one reported in ticket #103. Because of bad
        # status check, we commit even when a commit is already on its way.
        # We can detect this condition by the warnings.
        conn = self.conn
        notices = []
        stop = []

        def committer():
            while not stop:
                conn.commit()
                while conn.notices:
                    notices.append((2, conn.notices.pop()))

        cur = conn.cursor()
        t1 = threading.Thread(target=committer)
        t1.start()
        try:
            for i in range(1000):
                cur.execute("select %s;", (i,))
                conn.commit()
                while conn.notices:
                    notices.append((1, conn.notices.pop()))
        finally:
            # Stop the committer thread and wait for it, even if the loop
            # above failed: otherwise the thread would keep committing on
            # the connection while the test is being torn down.
            stop.append(True)
            t1.join()

        self.assertFalse(notices, f"{len(notices)} notices raised")

    def test_connect_cursor_factory(self):
        conn = self.connect(cursor_factory=psycopg2.extras.DictCursor)
        cur = conn.cursor()
        cur.execute("select 1 as a")
        self.assertEqual(cur.fetchone()['a'], 1)

    def test_cursor_factory(self):
        # Without a factory, rows cannot be indexed by column name.
        self.assertEqual(self.conn.cursor_factory, None)
        cur = self.conn.cursor()
        cur.execute("select 1 as a")
        self.assertRaises(TypeError, (lambda r: r['a']), cur.fetchone())

        self.conn.cursor_factory = psycopg2.extras.DictCursor
        self.assertEqual(self.conn.cursor_factory, psycopg2.extras.DictCursor)
        cur = self.conn.cursor()
        cur.execute("select 1 as a")
        self.assertEqual(cur.fetchone()['a'], 1)

        # Resetting the factory to None restores the initial behaviour.
        self.conn.cursor_factory = None
        self.assertEqual(self.conn.cursor_factory, None)
        cur = self.conn.cursor()
        cur.execute("select 1 as a")
        self.assertRaises(TypeError, (lambda r: r['a']), cur.fetchone())

    def test_cursor_factory_none(self):
        # issue #210
        conn = self.connect()
        cur = conn.cursor(cursor_factory=None)
        self.assertEqual(type(cur), ext.cursor)

        conn = self.connect(cursor_factory=psycopg2.extras.DictCursor)
        cur = conn.cursor(cursor_factory=None)
        self.assertEqual(type(cur), psycopg2.extras.DictCursor)

    @skip_if_crdb("connect any db")
    def test_failed_init_status(self):
        class SubConnection(ext.connection):
            def __init__(self, dsn):
                # The failure is swallowed on purpose: the test inspects
                # the state of the object left behind by a failed init.
                try:
                    super().__init__(dsn)
                except Exception:
                    pass

        c = SubConnection("dbname=thereisnosuchdatabasemate password=foobar")
        self.assertTrue(c.closed, "connection failed so it must be closed")
        self.assertNotIn('foobar', c.dsn, "password was not obscured")

    def test_get_native_connection(self):
        conn = self.connect()
        capsule = conn.get_native_connection()
        # we can't do anything else in Python
        self.assertIsNotNone(capsule)

    def test_pgconn_ptr(self):
        conn = self.connect()
        self.assertIsNotNone(conn.pgconn_ptr)

        try:
            f = self.libpq.PQserverVersion
        except AttributeError:
            pass
        else:
            # Call libpq directly on the raw pointer and compare the
            # result with what the connection object reports.
            f.argtypes = [ctypes.c_void_p]
            f.restype = ctypes.c_int
            ver = f(conn.pgconn_ptr)
            if ver == 0 and sys.platform == 'darwin':
                return self.skipTest(
                    "I don't know why this func returns 0 on OSX")

            self.assertEqual(ver, conn.server_version)

        conn.close()
        self.assertIsNone(conn.pgconn_ptr)

    @slow
    def test_multiprocess_close(self):
        # Run, in a separate interpreter, a thread using a connection while
        # a multiprocessing.Process is started: nothing must be printed.
        tmpdir = tempfile.mkdtemp()
        try:
            with open(os.path.join(tmpdir, "mptest.py"), 'w') as f:
                f.write(f"""\
import time
import psycopg2

def thread():
    conn = psycopg2.connect({dsn!r})
    curs = conn.cursor()
    for i in range(10):
        curs.execute("select 1")
        time.sleep(0.1)

def process():
    time.sleep(0.2)
""")

            script = ("""\
import sys
sys.path.insert(0, {dir!r})
import time
import threading
import multiprocessing
import mptest

t = threading.Thread(target=mptest.thread, name='mythread')
t.start()
time.sleep(0.2)
multiprocessing.Process(target=mptest.process, name='myprocess').start()
t.join()
""".format(dir=tmpdir))

            out = sp.check_output(
                [sys.executable, '-c', script], stderr=sp.STDOUT)
            self.assertEqual(out, b'', out)
        finally:
            shutil.rmtree(tmpdir, ignore_errors=True)
|
|
|
|
|
|
class ParseDsnTestCase(ConnectingTestCase):
    """Tests for ``psycopg2.extensions.parse_dsn()``."""

    def test_parse_dsn(self):
        self.assertEqual(
            ext.parse_dsn('dbname=test user=tester password=secret'),
            dict(user='tester', password='secret', dbname='test'),
            "simple DSN not parsed correctly")

        # An unquoted value containing a space is a syntax error.
        self.assertRaises(psycopg2.ProgrammingError, ext.parse_dsn,
            "dbname=test 2 user=tester password=secret")

        self.assertEqual(
            ext.parse_dsn("dbname='test 2' user=tester password=secret"),
            dict(user='tester', password='secret', dbname='test 2'),
            "DSN with quoting not parsed correctly")

        # The exception is captured, rather than matched by regexp, because
        # we need to make sure that secret is *not* exposed in its message.
        with self.assertRaises(psycopg2.ProgrammingError) as cm:
            # unterminated quote after dbname:
            ext.parse_dsn("dbname='test 2 user=tester password=secret")
        self.assertNotIn('secret', str(cm.exception),
            "password exposed in the error message")

    @skip_before_libpq(9, 2)
    def test_parse_dsn_uri(self):
        self.assertEqual(ext.parse_dsn('postgresql://tester:secret@/test'),
            dict(user='tester', password='secret', dbname='test'),
            "valid URI dsn not parsed correctly")

        with self.assertRaises(psycopg2.ProgrammingError) as cm:
            # extra '=' after port value
            ext.parse_dsn(dsn='postgresql://tester:secret@/test?port=1111=x')
        self.assertNotIn('secret', str(cm.exception),
            "password exposed in the error message")

    def test_unicode_value(self):
        snowman = "\u2603"
        d = ext.parse_dsn('dbname=' + snowman)
        self.assertEqual(d['dbname'], snowman)

    def test_unicode_key(self):
        snowman = "\u2603"
        self.assertRaises(psycopg2.ProgrammingError, ext.parse_dsn,
            snowman + '=' + snowman)

    def test_bad_param(self):
        # Only strings are accepted as dsn.
        self.assertRaises(TypeError, ext.parse_dsn, None)
        self.assertRaises(TypeError, ext.parse_dsn, 42)

    def test_str_subclass(self):
        class MyString(str):
            pass

        res = ext.parse_dsn(MyString("dbname=test"))
        self.assertEqual(res, {'dbname': 'test'})
|
|
|
|
|
|
class MakeDsnTestCase(ConnectingTestCase):
    """Tests for ``psycopg2.extensions.make_dsn()``.

    Local results are named ``result`` so they do not shadow the ``dsn``
    test setting imported at module level.
    """

    def test_empty_arguments(self):
        self.assertEqual(ext.make_dsn(), '')

    def test_empty_string(self):
        result = ext.make_dsn('')
        self.assertEqual(result, '')

    def test_params_validation(self):
        # Unknown parameters are rejected both in the dsn string and as
        # keyword arguments.
        self.assertRaises(psycopg2.ProgrammingError,
            ext.make_dsn, 'dbnamo=a')
        self.assertRaises(psycopg2.ProgrammingError,
            ext.make_dsn, dbnamo='a')
        self.assertRaises(psycopg2.ProgrammingError,
            ext.make_dsn, 'dbname=a', nosuchparam='b')

    def test_empty_param(self):
        result = ext.make_dsn(dbname='sony', password='')
        self.assertDsnEqual(result, "dbname=sony password=''")

    def test_escape(self):
        result = ext.make_dsn(dbname='hello world')
        self.assertEqual(result, "dbname='hello world'")

        result = ext.make_dsn(dbname=r'back\slash')
        self.assertEqual(result, r"dbname=back\\slash")

        result = ext.make_dsn(dbname="quo'te")
        self.assertEqual(result, r"dbname=quo\'te")

        result = ext.make_dsn(dbname="with\ttab")
        self.assertEqual(result, "dbname='with\ttab'")

        result = ext.make_dsn(dbname=r"\every thing'")
        self.assertEqual(result, r"dbname='\\every thing\''")

    def test_database_is_a_keyword(self):
        # 'database' is accepted as an alias for 'dbname'.
        self.assertEqual(ext.make_dsn(database='sigh'), "dbname=sigh")

    def test_params_merging(self):
        # A keyword argument overrides the same parameter in the dsn.
        result = ext.make_dsn('dbname=foo host=bar', host='baz')
        self.assertDsnEqual(result, 'dbname=foo host=baz')

        result = ext.make_dsn('dbname=foo', user='postgres')
        self.assertDsnEqual(result, 'dbname=foo user=postgres')

    def test_no_dsn_munging(self):
        # With no keyword argument the dsn is returned untouched.
        dsnin = 'dbname=a host=b user=c password=d'
        result = ext.make_dsn(dsnin)
        self.assertEqual(result, dsnin)

    def test_null_args(self):
        # Keyword arguments set to None are dropped.
        result = ext.make_dsn("dbname=foo", user="bar", password=None)
        self.assertDsnEqual(result, "dbname=foo user=bar")

    @skip_before_libpq(9, 2)
    def test_url_is_cool(self):
        url = 'postgresql://tester:secret@/test?application_name=wat'
        result = ext.make_dsn(url)
        self.assertEqual(result, url)

        # Merging a keyword into a URL yields a key=value dsn.
        result = ext.make_dsn(url, application_name='woot')
        self.assertDsnEqual(result,
            'dbname=test user=tester password=secret application_name=woot')

        self.assertRaises(psycopg2.ProgrammingError,
            ext.make_dsn, 'postgresql://tester:secret@/test?nosuch=param')
        self.assertRaises(psycopg2.ProgrammingError,
            ext.make_dsn, url, nosuch="param")

    @skip_before_libpq(9, 3)
    def test_get_dsn_parameters(self):
        conn = self.connect()
        d = conn.get_dsn_parameters()
        self.assertEqual(d['dbname'], dbname)  # the only param we can check reliably
        self.assertNotIn('password', d)
|
|
|
|
|
|
class IsolationLevelsTestCase(ConnectingTestCase):
    """Tests for getting and setting the connection isolation level.

    Every test starts with an empty ``isolevel`` table, see `setUp()`.
    """

    def setUp(self):
        """Make sure an empty ``isolevel`` table exists.

        The table is prepared on a dedicated connection, which is closed
        on every exit path.
        """
        ConnectingTestCase.setUp(self)

        conn = self.connect()
        try:
            cur = conn.cursor()
            if crdb_version(conn) is not None:
                cur.execute("create table if not exists isolevel (id integer)")
                cur.execute("truncate isolevel")
                conn.commit()
                return

            try:
                cur.execute("drop table isolevel;")
            except psycopg2.ProgrammingError:
                # Nothing to drop; leave the failed transaction.
                conn.rollback()
            cur.execute("create table isolevel (id integer);")
            conn.commit()
        finally:
            conn.close()

    def _check_setattr_isolation_level(
            self, serializable, repeatable_read, read_committed,
            read_uncommitted, default):
        """Assign each value to ``conn.isolation_level`` and verify it.

        The arguments are the values to assign to request the level they
        are named after. After each assignment both the attribute and the
        server ``transaction_isolation`` setting are checked.
        """
        cur = self.conn.cursor()
        # On older servers "repeatable read" and "read uncommitted" are
        # expected to be replaced by the next stricter level.
        modern = self.conn.info.server_version > 80000
        expected = [
            (serializable,
                ext.ISOLATION_LEVEL_SERIALIZABLE, 'serializable'),
            (repeatable_read,
                ext.ISOLATION_LEVEL_REPEATABLE_READ if modern
                else ext.ISOLATION_LEVEL_SERIALIZABLE,
                'repeatable read' if modern else 'serializable'),
            (read_committed,
                ext.ISOLATION_LEVEL_READ_COMMITTED, 'read committed'),
            (read_uncommitted,
                ext.ISOLATION_LEVEL_READ_UNCOMMITTED if modern
                else ext.ISOLATION_LEVEL_READ_COMMITTED,
                'read uncommitted' if modern else 'read committed'),
        ]
        for value, level, name in expected:
            self.conn.isolation_level = value
            self.assertEqual(self.conn.isolation_level, level)
            cur.execute("SHOW transaction_isolation;")
            self.assertEqual(cur.fetchone()[0], name)
            self.conn.rollback()

        # The default level is reported as None and leaves the choice to
        # the server default.
        self.conn.isolation_level = default
        self.assertIsNone(self.conn.isolation_level)
        cur.execute("SHOW transaction_isolation;")
        isol = cur.fetchone()[0]
        cur.execute("SHOW default_transaction_isolation;")
        self.assertEqual(cur.fetchone()[0], isol)

    def test_isolation_level(self):
        conn = self.connect()
        self.assertEqual(
            conn.isolation_level,
            ext.ISOLATION_LEVEL_DEFAULT)

    def test_encoding(self):
        conn = self.connect()
        self.assertIn(conn.encoding, ext.encodings)

    @skip_if_crdb("isolation level")
    def test_set_isolation_level(self):
        conn = self.connect()
        curs = conn.cursor()

        levels = [
            ('read uncommitted',
                ext.ISOLATION_LEVEL_READ_UNCOMMITTED),
            ('read committed', ext.ISOLATION_LEVEL_READ_COMMITTED),
            ('repeatable read', ext.ISOLATION_LEVEL_REPEATABLE_READ),
            ('serializable', ext.ISOLATION_LEVEL_SERIALIZABLE),
        ]
        for name, level in levels:
            conn.set_isolation_level(level)

            # the only values available on prehistoric PG versions
            if conn.info.server_version < 80000:
                if level in (
                        ext.ISOLATION_LEVEL_READ_UNCOMMITTED,
                        ext.ISOLATION_LEVEL_REPEATABLE_READ):
                    # expect the next, stricter, entry of the list
                    name, level = levels[levels.index((name, level)) + 1]

            self.assertEqual(conn.isolation_level, level)

            curs.execute('show transaction_isolation;')
            got_name = curs.fetchone()[0]

            self.assertEqual(name, got_name)
            conn.commit()

        self.assertRaises(ValueError, conn.set_isolation_level, -1)
        self.assertRaises(ValueError, conn.set_isolation_level, 5)

    def test_set_isolation_level_autocommit(self):
        conn = self.connect()
        curs = conn.cursor()

        # The autocommit "level" only switches autocommit on.
        conn.set_isolation_level(ext.ISOLATION_LEVEL_AUTOCOMMIT)
        self.assertEqual(conn.isolation_level, ext.ISOLATION_LEVEL_DEFAULT)
        self.assertTrue(conn.autocommit)

        # Changing the isolation level leaves autocommit untouched.
        conn.isolation_level = 'serializable'
        self.assertEqual(conn.isolation_level, ext.ISOLATION_LEVEL_SERIALIZABLE)
        self.assertTrue(conn.autocommit)

        curs.execute('show transaction_isolation;')
        self.assertEqual(curs.fetchone()[0], 'serializable')

    @skip_if_crdb("isolation level")
    def test_set_isolation_level_default(self):
        conn = self.connect()
        curs = conn.cursor()

        conn.autocommit = True
        curs.execute("set default_transaction_isolation to 'read committed'")

        conn.autocommit = False
        conn.set_isolation_level(ext.ISOLATION_LEVEL_SERIALIZABLE)
        self.assertEqual(conn.isolation_level,
            ext.ISOLATION_LEVEL_SERIALIZABLE)
        curs.execute("show transaction_isolation")
        self.assertEqual(curs.fetchone()[0], "serializable")

        # Back to the default level: the session default applies again.
        conn.rollback()
        conn.set_isolation_level(ext.ISOLATION_LEVEL_DEFAULT)
        curs.execute("show transaction_isolation")
        self.assertEqual(curs.fetchone()[0], "read committed")

    def test_set_isolation_level_abort(self):
        conn = self.connect()
        cur = conn.cursor()

        self.assertEqual(ext.TRANSACTION_STATUS_IDLE,
            conn.info.transaction_status)
        cur.execute("insert into isolevel values (10);")
        self.assertEqual(ext.TRANSACTION_STATUS_INTRANS,
            conn.info.transaction_status)

        # Changing the level ends the transaction and the insert is lost.
        conn.set_isolation_level(ext.ISOLATION_LEVEL_SERIALIZABLE)
        self.assertEqual(ext.TRANSACTION_STATUS_IDLE,
            conn.info.transaction_status)
        cur.execute("select count(*) from isolevel;")
        self.assertEqual(0, cur.fetchone()[0])

        cur.execute("insert into isolevel values (10);")
        self.assertEqual(ext.TRANSACTION_STATUS_INTRANS,
            conn.info.transaction_status)
        conn.set_isolation_level(ext.ISOLATION_LEVEL_AUTOCOMMIT)
        self.assertEqual(ext.TRANSACTION_STATUS_IDLE,
            conn.info.transaction_status)
        cur.execute("select count(*) from isolevel;")
        self.assertEqual(0, cur.fetchone()[0])

        # In autocommit no transaction is open, so the insert is kept
        # when the level changes again.
        cur.execute("insert into isolevel values (10);")
        self.assertEqual(ext.TRANSACTION_STATUS_IDLE,
            conn.info.transaction_status)
        conn.set_isolation_level(ext.ISOLATION_LEVEL_READ_COMMITTED)
        self.assertEqual(ext.TRANSACTION_STATUS_IDLE,
            conn.info.transaction_status)
        cur.execute("select count(*) from isolevel;")
        self.assertEqual(1, cur.fetchone()[0])
        self.assertEqual(conn.isolation_level,
            ext.ISOLATION_LEVEL_READ_COMMITTED)

    def test_isolation_level_autocommit(self):
        cnn1 = self.connect()
        cnn2 = self.connect()
        cnn2.set_isolation_level(ext.ISOLATION_LEVEL_AUTOCOMMIT)

        cur1 = cnn1.cursor()
        cur1.execute("select count(*) from isolevel;")
        self.assertEqual(0, cur1.fetchone()[0])
        cnn1.commit()

        cur2 = cnn2.cursor()
        cur2.execute("insert into isolevel values (10);")

        # The row inserted in autocommit is visible with no commit.
        cur1.execute("select count(*) from isolevel;")
        self.assertEqual(1, cur1.fetchone()[0])

    @skip_if_crdb("isolation level")
    def test_isolation_level_read_committed(self):
        cnn1 = self.connect()
        cnn2 = self.connect()
        cnn2.set_isolation_level(ext.ISOLATION_LEVEL_READ_COMMITTED)

        cur1 = cnn1.cursor()
        cur1.execute("select count(*) from isolevel;")
        self.assertEqual(0, cur1.fetchone()[0])
        cnn1.commit()

        cur2 = cnn2.cursor()
        cur2.execute("insert into isolevel values (10);")
        cur1.execute("insert into isolevel values (20);")

        # cnn2 sees the row from cnn1 as soon as cnn1 commits.
        cur2.execute("select count(*) from isolevel;")
        self.assertEqual(1, cur2.fetchone()[0])
        cnn1.commit()
        cur2.execute("select count(*) from isolevel;")
        self.assertEqual(2, cur2.fetchone()[0])

        cur1.execute("select count(*) from isolevel;")
        self.assertEqual(1, cur1.fetchone()[0])
        cnn2.commit()
        cur1.execute("select count(*) from isolevel;")
        self.assertEqual(2, cur1.fetchone()[0])

    @skip_if_crdb("isolation level")
    def test_isolation_level_serializable(self):
        cnn1 = self.connect()
        cnn2 = self.connect()
        cnn2.set_isolation_level(ext.ISOLATION_LEVEL_SERIALIZABLE)

        cur1 = cnn1.cursor()
        cur1.execute("select count(*) from isolevel;")
        self.assertEqual(0, cur1.fetchone()[0])
        cnn1.commit()

        cur2 = cnn2.cursor()
        cur2.execute("insert into isolevel values (10);")
        cur1.execute("insert into isolevel values (20);")

        # cnn2 does not see the row from cnn1, not even after cnn1
        # commits, until its own transaction is over.
        cur2.execute("select count(*) from isolevel;")
        self.assertEqual(1, cur2.fetchone()[0])
        cnn1.commit()
        cur2.execute("select count(*) from isolevel;")
        self.assertEqual(1, cur2.fetchone()[0])

        cur1.execute("select count(*) from isolevel;")
        self.assertEqual(1, cur1.fetchone()[0])
        cnn2.commit()
        cur1.execute("select count(*) from isolevel;")
        self.assertEqual(2, cur1.fetchone()[0])

        cur2.execute("select count(*) from isolevel;")
        self.assertEqual(2, cur2.fetchone()[0])

    def test_isolation_level_closed(self):
        cnn = self.connect()
        cnn.close()
        self.assertRaises(psycopg2.InterfaceError,
            cnn.set_isolation_level, 0)
        self.assertRaises(psycopg2.InterfaceError,
            cnn.set_isolation_level, 1)

    @skip_if_crdb("isolation level")
    def test_setattr_isolation_level_int(self):
        self.assertIsNone(ext.ISOLATION_LEVEL_DEFAULT)
        self._check_setattr_isolation_level(
            serializable=ext.ISOLATION_LEVEL_SERIALIZABLE,
            repeatable_read=ext.ISOLATION_LEVEL_REPEATABLE_READ,
            read_committed=ext.ISOLATION_LEVEL_READ_COMMITTED,
            read_uncommitted=ext.ISOLATION_LEVEL_READ_UNCOMMITTED,
            default=ext.ISOLATION_LEVEL_DEFAULT)

    @skip_if_crdb("isolation level")
    def test_setattr_isolation_level_str(self):
        self._check_setattr_isolation_level(
            serializable="serializable",
            repeatable_read="repeatable read",
            read_committed="read committed",
            read_uncommitted="read uncommitted",
            default="default")

    def test_setattr_isolation_level_invalid(self):
        self.assertRaises(ValueError, setattr, self.conn, 'isolation_level', 0)
        self.assertRaises(ValueError, setattr, self.conn, 'isolation_level', -1)
        self.assertRaises(ValueError, setattr, self.conn, 'isolation_level', 5)
        self.assertRaises(ValueError, setattr, self.conn, 'isolation_level', 'bah')

    def test_attribs_segfault(self):
        # bug #790
        # The attributes are only read, many times over: the values are
        # not checked, the interpreter just has to survive.
        for i in range(10000):
            self.conn.autocommit
            self.conn.readonly
            self.conn.deferrable
            self.conn.isolation_level
|
|
|
|
|
|
@skip_if_tpc_disabled
|
|
class ConnectionTwoPhaseTests(ConnectingTestCase):
|
|
def setUp(self):
    """Prepare the database for a two-phase commit test."""
    ConnectingTestCase.setUp(self)

    # Recreate the test table first, then get rid of any prepared
    # transaction already present in the test database.
    self.make_test_table()
    self.clear_test_xacts()
|
|
|
|
def tearDown(self):
    """Roll back leftover prepared transactions, then run base cleanup."""
    # The prepared transactions are cleared before the base class
    # tearDown() runs.
    self.clear_test_xacts()
    ConnectingTestCase.tearDown(self)
|
|
|
|
def clear_test_xacts(self):
    """Rollback all the prepared transaction in the testing db.

    The work is done on a dedicated connection, which is closed on every
    exit path. If ``pg_prepared_xacts`` cannot be queried the method
    returns without doing anything else.
    """
    cnn = self.connect()
    try:
        # ROLLBACK PREPARED is issued with no transaction open.
        cnn.set_isolation_level(ext.ISOLATION_LEVEL_AUTOCOMMIT)
        cur = cnn.cursor()
        try:
            cur.execute(
                "select gid from pg_prepared_xacts where database = %s",
                (dbname,))
        except psycopg2.ProgrammingError:
            # NOTE(review): presumably the view is not available on this
            # server -- confirm.
            cnn.rollback()
            return

        # Read every gid before the cursor is reused by the loop below.
        gids = [r[0] for r in cur]
        for gid in gids:
            cur.execute("rollback prepared %s;", (gid,))
    finally:
        cnn.close()
|
|
|
|
def make_test_table(self):
    """Make sure an empty ``test_tpc`` table exists.

    On CockroachDB the table is created if missing and truncated;
    otherwise it is dropped (if it exists) and created again. The
    connection used is closed on every exit path.
    """
    cnn = self.connect()
    try:
        cur = cnn.cursor()
        if crdb_version(cnn) is not None:
            cur.execute("CREATE TABLE IF NOT EXISTS test_tpc (data text)")
            cur.execute("TRUNCATE test_tpc")
            cnn.commit()
            return

        try:
            cur.execute("DROP TABLE test_tpc;")
        except psycopg2.ProgrammingError:
            # Nothing to drop; leave the failed transaction.
            cnn.rollback()
        cur.execute("CREATE TABLE test_tpc (data text);")
        cnn.commit()
    finally:
        cnn.close()
|
|
|
|
def count_xacts(self):
    """Return the number of prepared xacts currently in the test db.

    The count is read on a new connection, closed before returning.
    """
    conn = self.connect()
    cursor = conn.cursor()
    query = """
        select count(*) from pg_prepared_xacts
        where database = %s;"""
    cursor.execute(query, (dbname,))
    count = cursor.fetchone()[0]
    conn.close()
    return count
|
|
|
|
def count_test_records(self):
    """Return the number of records in the test table.

    The count is read on a new connection, closed before returning.
    """
    conn = self.connect()
    cursor = conn.cursor()
    cursor.execute("select count(*) from test_tpc;")
    count = cursor.fetchone()[0]
    conn.close()
    return count
|
|
|
|
def test_tpc_commit(self):
|
|
cnn = self.connect()
|
|
xid = cnn.xid(1, "gtrid", "bqual")
|
|
self.assertEqual(cnn.status, ext.STATUS_READY)
|
|
|
|
cnn.tpc_begin(xid)
|
|
self.assertEqual(cnn.status, ext.STATUS_BEGIN)
|
|
|
|
cur = cnn.cursor()
|
|
cur.execute("insert into test_tpc values ('test_tpc_commit');")
|
|
self.assertEqual(0, self.count_xacts())
|
|
self.assertEqual(0, self.count_test_records())
|
|
|
|
cnn.tpc_prepare()
|
|
self.assertEqual(cnn.status, ext.STATUS_PREPARED)
|
|
self.assertEqual(1, self.count_xacts())
|
|
self.assertEqual(0, self.count_test_records())
|
|
|
|
cnn.tpc_commit()
|
|
self.assertEqual(cnn.status, ext.STATUS_READY)
|
|
self.assertEqual(0, self.count_xacts())
|
|
self.assertEqual(1, self.count_test_records())
|
|
|
|
def test_tpc_commit_one_phase(self):
|
|
cnn = self.connect()
|
|
xid = cnn.xid(1, "gtrid", "bqual")
|
|
self.assertEqual(cnn.status, ext.STATUS_READY)
|
|
|
|
cnn.tpc_begin(xid)
|
|
self.assertEqual(cnn.status, ext.STATUS_BEGIN)
|
|
|
|
cur = cnn.cursor()
|
|
cur.execute("insert into test_tpc values ('test_tpc_commit_1p');")
|
|
self.assertEqual(0, self.count_xacts())
|
|
self.assertEqual(0, self.count_test_records())
|
|
|
|
cnn.tpc_commit()
|
|
self.assertEqual(cnn.status, ext.STATUS_READY)
|
|
self.assertEqual(0, self.count_xacts())
|
|
self.assertEqual(1, self.count_test_records())
|
|
|
|
def test_tpc_commit_recovered(self):
|
|
cnn = self.connect()
|
|
xid = cnn.xid(1, "gtrid", "bqual")
|
|
self.assertEqual(cnn.status, ext.STATUS_READY)
|
|
|
|
cnn.tpc_begin(xid)
|
|
self.assertEqual(cnn.status, ext.STATUS_BEGIN)
|
|
|
|
cur = cnn.cursor()
|
|
cur.execute("insert into test_tpc values ('test_tpc_commit_rec');")
|
|
self.assertEqual(0, self.count_xacts())
|
|
self.assertEqual(0, self.count_test_records())
|
|
|
|
cnn.tpc_prepare()
|
|
cnn.close()
|
|
self.assertEqual(1, self.count_xacts())
|
|
self.assertEqual(0, self.count_test_records())
|
|
|
|
cnn = self.connect()
|
|
xid = cnn.xid(1, "gtrid", "bqual")
|
|
cnn.tpc_commit(xid)
|
|
|
|
self.assertEqual(cnn.status, ext.STATUS_READY)
|
|
self.assertEqual(0, self.count_xacts())
|
|
self.assertEqual(1, self.count_test_records())
|
|
|
|
def test_tpc_rollback(self):
|
|
cnn = self.connect()
|
|
xid = cnn.xid(1, "gtrid", "bqual")
|
|
self.assertEqual(cnn.status, ext.STATUS_READY)
|
|
|
|
cnn.tpc_begin(xid)
|
|
self.assertEqual(cnn.status, ext.STATUS_BEGIN)
|
|
|
|
cur = cnn.cursor()
|
|
cur.execute("insert into test_tpc values ('test_tpc_rollback');")
|
|
self.assertEqual(0, self.count_xacts())
|
|
self.assertEqual(0, self.count_test_records())
|
|
|
|
cnn.tpc_prepare()
|
|
self.assertEqual(cnn.status, ext.STATUS_PREPARED)
|
|
self.assertEqual(1, self.count_xacts())
|
|
self.assertEqual(0, self.count_test_records())
|
|
|
|
cnn.tpc_rollback()
|
|
self.assertEqual(cnn.status, ext.STATUS_READY)
|
|
self.assertEqual(0, self.count_xacts())
|
|
self.assertEqual(0, self.count_test_records())
|
|
|
|
def test_tpc_rollback_one_phase(self):
|
|
cnn = self.connect()
|
|
xid = cnn.xid(1, "gtrid", "bqual")
|
|
self.assertEqual(cnn.status, ext.STATUS_READY)
|
|
|
|
cnn.tpc_begin(xid)
|
|
self.assertEqual(cnn.status, ext.STATUS_BEGIN)
|
|
|
|
cur = cnn.cursor()
|
|
cur.execute("insert into test_tpc values ('test_tpc_rollback_1p');")
|
|
self.assertEqual(0, self.count_xacts())
|
|
self.assertEqual(0, self.count_test_records())
|
|
|
|
cnn.tpc_rollback()
|
|
self.assertEqual(cnn.status, ext.STATUS_READY)
|
|
self.assertEqual(0, self.count_xacts())
|
|
self.assertEqual(0, self.count_test_records())
|
|
|
|
def test_tpc_rollback_recovered(self):
|
|
cnn = self.connect()
|
|
xid = cnn.xid(1, "gtrid", "bqual")
|
|
self.assertEqual(cnn.status, ext.STATUS_READY)
|
|
|
|
cnn.tpc_begin(xid)
|
|
self.assertEqual(cnn.status, ext.STATUS_BEGIN)
|
|
|
|
cur = cnn.cursor()
|
|
cur.execute("insert into test_tpc values ('test_tpc_commit_rec');")
|
|
self.assertEqual(0, self.count_xacts())
|
|
self.assertEqual(0, self.count_test_records())
|
|
|
|
cnn.tpc_prepare()
|
|
cnn.close()
|
|
self.assertEqual(1, self.count_xacts())
|
|
self.assertEqual(0, self.count_test_records())
|
|
|
|
cnn = self.connect()
|
|
xid = cnn.xid(1, "gtrid", "bqual")
|
|
cnn.tpc_rollback(xid)
|
|
|
|
self.assertEqual(cnn.status, ext.STATUS_READY)
|
|
self.assertEqual(0, self.count_xacts())
|
|
self.assertEqual(0, self.count_test_records())
|
|
|
|
def test_status_after_recover(self):
|
|
cnn = self.connect()
|
|
self.assertEqual(ext.STATUS_READY, cnn.status)
|
|
cnn.tpc_recover()
|
|
self.assertEqual(ext.STATUS_READY, cnn.status)
|
|
|
|
cur = cnn.cursor()
|
|
cur.execute("select 1")
|
|
self.assertEqual(ext.STATUS_BEGIN, cnn.status)
|
|
cnn.tpc_recover()
|
|
self.assertEqual(ext.STATUS_BEGIN, cnn.status)
|
|
|
|
def test_recovered_xids(self):
|
|
# insert a few test xns
|
|
cnn = self.connect()
|
|
cnn.set_isolation_level(0)
|
|
cur = cnn.cursor()
|
|
cur.execute("begin; prepare transaction '1-foo';")
|
|
cur.execute("begin; prepare transaction '2-bar';")
|
|
|
|
# read the values to return
|
|
cur.execute("""
|
|
select gid, prepared, owner, database
|
|
from pg_prepared_xacts
|
|
where database = %s;""",
|
|
(dbname,))
|
|
okvals = cur.fetchall()
|
|
okvals.sort()
|
|
|
|
cnn = self.connect()
|
|
xids = cnn.tpc_recover()
|
|
xids = [xid for xid in xids if xid.database == dbname]
|
|
xids.sort(key=attrgetter('gtrid'))
|
|
|
|
# check the values returned
|
|
self.assertEqual(len(okvals), len(xids))
|
|
for (xid, (gid, prepared, owner, database)) in zip(xids, okvals):
|
|
self.assertEqual(xid.gtrid, gid)
|
|
self.assertEqual(xid.prepared, prepared)
|
|
self.assertEqual(xid.owner, owner)
|
|
self.assertEqual(xid.database, database)
|
|
|
|
def test_xid_encoding(self):
|
|
cnn = self.connect()
|
|
xid = cnn.xid(42, "gtrid", "bqual")
|
|
cnn.tpc_begin(xid)
|
|
cnn.tpc_prepare()
|
|
|
|
cnn = self.connect()
|
|
cur = cnn.cursor()
|
|
cur.execute("select gid from pg_prepared_xacts where database = %s;",
|
|
(dbname,))
|
|
self.assertEqual('42_Z3RyaWQ=_YnF1YWw=', cur.fetchone()[0])
|
|
|
|
@slow
|
|
def test_xid_roundtrip(self):
|
|
for fid, gtrid, bqual in [
|
|
(0, "", ""),
|
|
(42, "gtrid", "bqual"),
|
|
(0x7fffffff, "x" * 64, "y" * 64),
|
|
]:
|
|
cnn = self.connect()
|
|
xid = cnn.xid(fid, gtrid, bqual)
|
|
cnn.tpc_begin(xid)
|
|
cnn.tpc_prepare()
|
|
cnn.close()
|
|
|
|
cnn = self.connect()
|
|
xids = [x for x in cnn.tpc_recover() if x.database == dbname]
|
|
self.assertEqual(1, len(xids))
|
|
xid = xids[0]
|
|
self.assertEqual(xid.format_id, fid)
|
|
self.assertEqual(xid.gtrid, gtrid)
|
|
self.assertEqual(xid.bqual, bqual)
|
|
|
|
cnn.tpc_rollback(xid)
|
|
|
|
@slow
|
|
def test_unparsed_roundtrip(self):
|
|
for tid in [
|
|
'',
|
|
'hello, world!',
|
|
'x' * 199, # PostgreSQL's limit in transaction id length
|
|
]:
|
|
cnn = self.connect()
|
|
cnn.tpc_begin(tid)
|
|
cnn.tpc_prepare()
|
|
cnn.close()
|
|
|
|
cnn = self.connect()
|
|
xids = [x for x in cnn.tpc_recover() if x.database == dbname]
|
|
self.assertEqual(1, len(xids))
|
|
xid = xids[0]
|
|
self.assertEqual(xid.format_id, None)
|
|
self.assertEqual(xid.gtrid, tid)
|
|
self.assertEqual(xid.bqual, None)
|
|
|
|
cnn.tpc_rollback(xid)
|
|
|
|
def test_xid_construction(self):
|
|
x1 = ext.Xid(74, 'foo', 'bar')
|
|
self.assertEqual(74, x1.format_id)
|
|
self.assertEqual('foo', x1.gtrid)
|
|
self.assertEqual('bar', x1.bqual)
|
|
|
|
def test_xid_from_string(self):
|
|
x2 = ext.Xid.from_string('42_Z3RyaWQ=_YnF1YWw=')
|
|
self.assertEqual(42, x2.format_id)
|
|
self.assertEqual('gtrid', x2.gtrid)
|
|
self.assertEqual('bqual', x2.bqual)
|
|
|
|
x3 = ext.Xid.from_string('99_xxx_yyy')
|
|
self.assertEqual(None, x3.format_id)
|
|
self.assertEqual('99_xxx_yyy', x3.gtrid)
|
|
self.assertEqual(None, x3.bqual)
|
|
|
|
def test_xid_to_string(self):
|
|
x1 = ext.Xid.from_string('42_Z3RyaWQ=_YnF1YWw=')
|
|
self.assertEqual(str(x1), '42_Z3RyaWQ=_YnF1YWw=')
|
|
|
|
x2 = ext.Xid.from_string('99_xxx_yyy')
|
|
self.assertEqual(str(x2), '99_xxx_yyy')
|
|
|
|
def test_xid_unicode(self):
|
|
cnn = self.connect()
|
|
x1 = cnn.xid(10, 'uni', 'code')
|
|
cnn.tpc_begin(x1)
|
|
cnn.tpc_prepare()
|
|
cnn.reset()
|
|
xid = [x for x in cnn.tpc_recover() if x.database == dbname][0]
|
|
self.assertEqual(10, xid.format_id)
|
|
self.assertEqual('uni', xid.gtrid)
|
|
self.assertEqual('code', xid.bqual)
|
|
|
|
def test_xid_unicode_unparsed(self):
|
|
# We don't expect people shooting snowmen as transaction ids,
|
|
# so if something explodes in an encode error I don't mind.
|
|
# Let's just check unicode is accepted as type.
|
|
cnn = self.connect()
|
|
cnn.set_client_encoding('utf8')
|
|
cnn.tpc_begin("transaction-id")
|
|
cnn.tpc_prepare()
|
|
cnn.reset()
|
|
|
|
xid = [x for x in cnn.tpc_recover() if x.database == dbname][0]
|
|
self.assertEqual(None, xid.format_id)
|
|
self.assertEqual('transaction-id', xid.gtrid)
|
|
self.assertEqual(None, xid.bqual)
|
|
|
|
def test_cancel_fails_prepared(self):
|
|
cnn = self.connect()
|
|
cnn.tpc_begin('cancel')
|
|
cnn.tpc_prepare()
|
|
self.assertRaises(psycopg2.ProgrammingError, cnn.cancel)
|
|
|
|
def test_tpc_recover_non_dbapi_connection(self):
|
|
cnn = self.connect(connection_factory=psycopg2.extras.RealDictConnection)
|
|
cnn.tpc_begin('dict-connection')
|
|
cnn.tpc_prepare()
|
|
cnn.reset()
|
|
|
|
xids = cnn.tpc_recover()
|
|
xid = [x for x in xids if x.database == dbname][0]
|
|
self.assertEqual(None, xid.format_id)
|
|
self.assertEqual('dict-connection', xid.gtrid)
|
|
self.assertEqual(None, xid.bqual)
|
|
|
|
|
|
@skip_if_crdb("isolation level")
class TransactionControlTests(ConnectingTestCase):
    """Tests for set_session() and the session attributes of a connection."""

    def _show(self, cur, setting):
        """Return the value the server reports for *setting* through *cur*."""
        cur.execute(f"SHOW {setting};")
        return cur.fetchone()[0]

    def test_closed(self):
        self.conn.close()
        self.assertRaises(psycopg2.InterfaceError,
            self.conn.set_session,
            ext.ISOLATION_LEVEL_SERIALIZABLE)

    def test_not_in_transaction(self):
        cur = self.conn.cursor()
        cur.execute("select 1")
        self.assertRaises(psycopg2.ProgrammingError,
            self.conn.set_session,
            ext.ISOLATION_LEVEL_SERIALIZABLE)

    def test_set_isolation_level(self):
        cur = self.conn.cursor()
        # For server versions up to 80000 the two extra levels are expected
        # to be reported as the stricter level instead.
        recent = self.conn.info.server_version > 80000

        self.conn.set_session(
            ext.ISOLATION_LEVEL_SERIALIZABLE)
        self.assertEqual(
            self._show(cur, 'transaction_isolation'), 'serializable')
        self.conn.rollback()

        self.conn.set_session(
            ext.ISOLATION_LEVEL_REPEATABLE_READ)
        self.assertEqual(
            self._show(cur, 'transaction_isolation'),
            'repeatable read' if recent else 'serializable')
        self.conn.rollback()

        self.conn.set_session(
            isolation_level=ext.ISOLATION_LEVEL_READ_COMMITTED)
        self.assertEqual(
            self._show(cur, 'transaction_isolation'), 'read committed')
        self.conn.rollback()

        self.conn.set_session(
            isolation_level=ext.ISOLATION_LEVEL_READ_UNCOMMITTED)
        self.assertEqual(
            self._show(cur, 'transaction_isolation'),
            'read uncommitted' if recent else 'read committed')
        self.conn.rollback()

    def test_set_isolation_level_str(self):
        cur = self.conn.cursor()
        # See test_set_isolation_level for the version-dependent results.
        recent = self.conn.info.server_version > 80000

        self.conn.set_session("serializable")
        self.assertEqual(
            self._show(cur, 'transaction_isolation'), 'serializable')
        self.conn.rollback()

        self.conn.set_session("repeatable read")
        self.assertEqual(
            self._show(cur, 'transaction_isolation'),
            'repeatable read' if recent else 'serializable')
        self.conn.rollback()

        self.conn.set_session("read committed")
        self.assertEqual(
            self._show(cur, 'transaction_isolation'), 'read committed')
        self.conn.rollback()

        self.conn.set_session("read uncommitted")
        self.assertEqual(
            self._show(cur, 'transaction_isolation'),
            'read uncommitted' if recent else 'read committed')
        self.conn.rollback()

    def test_bad_isolation_level(self):
        self.assertRaises(ValueError, self.conn.set_session, 0)
        self.assertRaises(ValueError, self.conn.set_session, 5)
        self.assertRaises(ValueError, self.conn.set_session, 'whatever')

    def test_set_read_only(self):
        self.assertIsNone(self.conn.readonly)

        cur = self.conn.cursor()
        self.conn.set_session(readonly=True)
        self.assertIs(self.conn.readonly, True)
        self.assertEqual(self._show(cur, 'transaction_read_only'), 'on')
        self.conn.rollback()
        # The setting must still be in effect in the next transaction.
        self.assertEqual(self._show(cur, 'transaction_read_only'), 'on')
        self.conn.rollback()

        self.conn.set_session(readonly=False)
        self.assertIs(self.conn.readonly, False)
        self.assertEqual(self._show(cur, 'transaction_read_only'), 'off')
        self.conn.rollback()

    def test_setattr_read_only(self):
        cur = self.conn.cursor()
        self.conn.readonly = True
        self.assertIs(self.conn.readonly, True)
        self.assertEqual(self._show(cur, 'transaction_read_only'), 'on')
        # The attribute cannot be changed while a transaction is open.
        self.assertRaises(self.conn.ProgrammingError,
            setattr, self.conn, 'readonly', False)
        self.assertIs(self.conn.readonly, True)
        self.conn.rollback()
        self.assertEqual(self._show(cur, 'transaction_read_only'), 'on')
        self.conn.rollback()

        cur = self.conn.cursor()
        self.conn.readonly = None
        self.assertIsNone(self.conn.readonly)
        # assume defined by server
        self.assertEqual(self._show(cur, 'transaction_read_only'), 'off')
        self.conn.rollback()

        self.conn.readonly = False
        self.assertIs(self.conn.readonly, False)
        self.assertEqual(self._show(cur, 'transaction_read_only'), 'off')
        self.conn.rollback()

    def test_set_default(self):
        cur = self.conn.cursor()
        isolevel = self._show(cur, 'transaction_isolation')
        readonly = self._show(cur, 'transaction_read_only')
        self.conn.rollback()

        self.conn.set_session(isolation_level='serializable', readonly=True)
        self.conn.set_session(isolation_level='default', readonly='default')

        self.assertEqual(self._show(cur, 'transaction_isolation'), isolevel)
        self.assertEqual(self._show(cur, 'transaction_read_only'), readonly)

    @skip_before_postgres(9, 1)
    def test_set_deferrable(self):
        self.assertIsNone(self.conn.deferrable)
        cur = self.conn.cursor()
        self.conn.set_session(readonly=True, deferrable=True)
        self.assertIs(self.conn.deferrable, True)
        self.assertEqual(self._show(cur, 'transaction_read_only'), 'on')
        self.assertEqual(self._show(cur, 'transaction_deferrable'), 'on')
        self.conn.rollback()
        # The setting must still be in effect in the next transaction.
        self.assertEqual(self._show(cur, 'transaction_deferrable'), 'on')
        self.conn.rollback()

        self.conn.set_session(deferrable=False)
        self.assertIs(self.conn.deferrable, False)
        self.assertEqual(self._show(cur, 'transaction_read_only'), 'on')
        self.assertEqual(self._show(cur, 'transaction_deferrable'), 'off')
        self.conn.rollback()

    @skip_after_postgres(9, 1)
    def test_set_deferrable_error(self):
        self.assertRaises(psycopg2.ProgrammingError,
            self.conn.set_session, readonly=True, deferrable=True)
        self.assertRaises(psycopg2.ProgrammingError,
            setattr, self.conn, 'deferrable', True)

    @skip_before_postgres(9, 1)
    def test_setattr_deferrable(self):
        cur = self.conn.cursor()
        self.conn.deferrable = True
        self.assertIs(self.conn.deferrable, True)
        self.assertEqual(self._show(cur, 'transaction_deferrable'), 'on')
        # The attribute cannot be changed while a transaction is open.
        self.assertRaises(self.conn.ProgrammingError,
            setattr, self.conn, 'deferrable', False)
        self.assertIs(self.conn.deferrable, True)
        self.conn.rollback()
        self.assertEqual(self._show(cur, 'transaction_deferrable'), 'on')
        self.conn.rollback()

        cur = self.conn.cursor()
        self.conn.deferrable = None
        self.assertIsNone(self.conn.deferrable)
        # assume defined by server
        self.assertEqual(self._show(cur, 'transaction_deferrable'), 'off')
        self.conn.rollback()

        self.conn.deferrable = False
        self.assertIs(self.conn.deferrable, False)
        self.assertEqual(self._show(cur, 'transaction_deferrable'), 'off')
        self.conn.rollback()

    def test_mixing_session_attribs(self):
        cur = self.conn.cursor()
        self.conn.autocommit = True
        self.conn.readonly = True

        self.assertEqual(self._show(cur, 'transaction_read_only'), 'on')
        self.assertEqual(
            self._show(cur, 'default_transaction_read_only'), 'on')

        self.conn.autocommit = False
        self.assertEqual(self._show(cur, 'transaction_read_only'), 'on')
        self.assertEqual(
            self._show(cur, 'default_transaction_read_only'), 'off')

    def test_idempotence_check(self):
        self.conn.autocommit = False
        self.conn.readonly = True
        self.conn.autocommit = True
        self.conn.readonly = True

        cur = self.conn.cursor()
        cur.execute("SHOW transaction_read_only")
        self.assertEqual(cur.fetchone()[0], 'on')
|
|
|
|
|
|
class TestEncryptPassword(ConnectingTestCase):
    """Tests for the ext.encrypt_password() function."""

    # Value the tests expect from the md5 algorithm for the password
    # 'psycopg2' and the user 'ashesh'.
    _MD5_HASH = 'md594839d658c28a357126f105b9cb14cfc'

    @skip_before_postgres(10)
    def test_encrypt_password_post_9_6(self):
        # MD5 algorithm
        self.assertEqual(
            ext.encrypt_password('psycopg2', 'ashesh', self.conn, 'md5'),
            self._MD5_HASH)

        # keywords
        self.assertEqual(
            ext.encrypt_password(
                password='psycopg2', user='ashesh',
                scope=self.conn, algorithm='md5'),
            self._MD5_HASH)

    @skip_if_crdb("password_encryption")
    @skip_before_libpq(10)
    @skip_before_postgres(10)
    def test_encrypt_server(self):
        cur = self.conn.cursor()
        cur.execute("SHOW password_encryption;")
        server_encryption_algorithm = cur.fetchone()[0]

        # No algorithm given: the result is checked against the one
        # configured on the server.
        enc_password = ext.encrypt_password(
            'psycopg2', 'ashesh', self.conn)

        if server_encryption_algorithm == 'md5':
            self.assertEqual(enc_password, self._MD5_HASH)
        elif server_encryption_algorithm == 'scram-sha-256':
            self.assertEqual(enc_password[:14], 'SCRAM-SHA-256$')

        self.assertEqual(
            ext.encrypt_password(
                'psycopg2', 'ashesh', self.conn, 'scram-sha-256'
            )[:14], 'SCRAM-SHA-256$')

        self.assertRaises(psycopg2.ProgrammingError,
            ext.encrypt_password, 'psycopg2', 'ashesh', self.conn, 'abc')

    def test_encrypt_md5(self):
        self.assertEqual(
            ext.encrypt_password('psycopg2', 'ashesh', algorithm='md5'),
            self._MD5_HASH)

    @skip_before_libpq(10)
    def test_encrypt_bad_libpq_10(self):
        self.assertRaises(psycopg2.ProgrammingError,
            ext.encrypt_password, 'psycopg2', 'ashesh', self.conn, 'abc')

    @skip_after_libpq(10)
    def test_encrypt_bad_before_libpq_10(self):
        self.assertRaises(psycopg2.NotSupportedError,
            ext.encrypt_password, 'psycopg2', 'ashesh', self.conn, 'abc')

    @skip_before_libpq(10)
    def test_encrypt_scram(self):
        encrypted = ext.encrypt_password(
            'psycopg2', 'ashesh', self.conn, 'scram-sha-256')
        self.assertTrue(encrypted.startswith('SCRAM-SHA-256$'))

    @skip_after_libpq(10)
    def test_encrypt_scram_pre_10(self):
        self.assertRaises(psycopg2.NotSupportedError,
            ext.encrypt_password,
            password='psycopg2', user='ashesh',
            scope=self.conn, algorithm='scram-sha-256')

    def test_bad_types(self):
        self.assertRaises(TypeError, ext.encrypt_password)
        self.assertRaises(TypeError, ext.encrypt_password,
            'password', 42, self.conn, 'md5')
        self.assertRaises(TypeError, ext.encrypt_password,
            42, 'user', self.conn, 'md5')
        self.assertRaises(TypeError, ext.encrypt_password,
            42, 'user', 'wat', 'abc')
        self.assertRaises(TypeError, ext.encrypt_password,
            'password', 'user', 'wat', 42)
|
|
|
|
|
|
class AutocommitTests(ConnectingTestCase):
    """Tests for the autocommit attribute of the connection."""

    def _assert_idle(self):
        """Check that the connection is ready and outside a transaction."""
        self.assertEqual(self.conn.status, ext.STATUS_READY)
        self.assertEqual(self.conn.info.transaction_status,
            ext.TRANSACTION_STATUS_IDLE)

    def _assert_in_transaction(self):
        """Check that the connection has a transaction in progress."""
        self.assertEqual(self.conn.status, ext.STATUS_BEGIN)
        self.assertEqual(self.conn.info.transaction_status,
            ext.TRANSACTION_STATUS_INTRANS)

    def test_closed(self):
        self.conn.close()
        self.assertRaises(psycopg2.InterfaceError,
            setattr, self.conn, 'autocommit', True)

        # The getter doesn't have a guard. We may change this in future
        # to make it consistent with other methods; meanwhile let's just check
        # it doesn't explode.
        try:
            self.assertIn(self.conn.autocommit, (True, False))
        except psycopg2.InterfaceError:
            pass

    def test_default_no_autocommit(self):
        self.assertFalse(self.conn.autocommit)
        self._assert_idle()

        cur = self.conn.cursor()
        cur.execute('select 1;')
        self._assert_in_transaction()

        self.conn.rollback()
        self._assert_idle()

    def test_set_autocommit(self):
        self.conn.autocommit = True
        self.assertTrue(self.conn.autocommit)
        self._assert_idle()

        # In autocommit a statement doesn't leave a transaction open.
        cur = self.conn.cursor()
        cur.execute('select 1;')
        self._assert_idle()

        self.conn.autocommit = False
        self.assertFalse(self.conn.autocommit)
        self._assert_idle()

        cur.execute('select 1;')
        self._assert_in_transaction()

    def test_set_intrans_error(self):
        cur = self.conn.cursor()
        cur.execute('select 1;')
        self.assertRaises(psycopg2.ProgrammingError,
            setattr, self.conn, 'autocommit', True)

    def test_set_session_autocommit(self):
        self.conn.set_session(autocommit=True)
        self.assertTrue(self.conn.autocommit)
        self._assert_idle()

        # In autocommit a statement doesn't leave a transaction open.
        cur = self.conn.cursor()
        cur.execute('select 1;')
        self._assert_idle()

        self.conn.set_session(autocommit=False)
        self.assertFalse(self.conn.autocommit)
        self._assert_idle()

        cur.execute('select 1;')
        self._assert_in_transaction()
        self.conn.rollback()

        self.conn.set_session('serializable', readonly=True, autocommit=True)
        self.assertTrue(self.conn.autocommit)
        cur.execute('select 1;')
        self._assert_idle()
        cur.execute("SHOW transaction_isolation;")
        self.assertEqual(cur.fetchone()[0], 'serializable')
        cur.execute("SHOW transaction_read_only;")
        self.assertEqual(cur.fetchone()[0], 'on')
|
|
|
|
|
|
class PasswordLeakTestCase(ConnectingTestCase):
    """Check that the password is masked in the dsn of a failed connection."""

    def setUp(self):
        super().setUp()
        # Start every test with no dsn captured.
        PasswordLeakTestCase.dsn = None

    class GrassingConnection(ext.connection):
        """A connection reporting its dsn to the test case class.

        The dsn is stored on PasswordLeakTestCase even when the
        initialisation fails (e.g. on a connection error), so that the
        tests can verify that it was mangled correctly in that case too.
        """

        def __init__(self, *args, **kwargs):
            try:
                super().__init__(*args, **kwargs)
            finally:
                # The connection is not initialized entirely, however the C
                # code should have set the dsn, and it should have scrubbed
                # the password away
                PasswordLeakTestCase.dsn = self.dsn

    @skip_if_crdb("connect any db")
    def test_leak(self):
        with self.assertRaises(psycopg2.DatabaseError):
            self.GrassingConnection("dbname=nosuch password=whateva")

        self.assertDsnEqual(self.dsn, "dbname=nosuch password=xxx")

    @skip_before_libpq(9, 2)
    def test_url_leak(self):
        with self.assertRaises(psycopg2.DatabaseError):
            self.GrassingConnection(
                "postgres://someone:whateva@localhost/nosuch")

        expected = "user=someone password=xxx host=localhost dbname=nosuch"
        self.assertDsnEqual(self.dsn, expected)
|
|
|
|
|
|
class SignalTestCase(ConnectingTestCase):
    """Regression tests for bug 551, run in a separate interpreter."""

    @slow
    @skip_before_postgres(8, 2)
    def test_bug_551_returning(self):
        # Raise an exception trying to decode 'id'
        self._test_bug_551(query="""
            INSERT INTO test551 (num) VALUES (%s) RETURNING id
            """)

    @slow
    def test_bug_551_no_returning(self):
        # Raise an exception trying to decode 'INSERT 0 1'
        self._test_bug_551(query="""
            INSERT INTO test551 (num) VALUES (%s)
            """)

    def _test_bug_551(self, query):
        """Run *query* in a loop in a subprocess interrupted by a signal.

        The child script installs a SIGABRT handler calling ``sys.exit(1)``
        and starts a thread sending SIGABRT to its own process after half
        a second, while the main thread keeps executing *query*. The child
        must exit with a non-zero code and leave nothing on stderr once
        the bracketed chunks are stripped from it.
        """
        script = f"""import os
import sys
import time
import signal
import warnings
import threading

# ignore wheel deprecation warning
with warnings.catch_warnings():
    warnings.simplefilter('ignore')
    import psycopg2

def handle_sigabort(sig, frame):
    sys.exit(1)

def killer():
    time.sleep(0.5)
    os.kill(os.getpid(), signal.SIGABRT)

signal.signal(signal.SIGABRT, handle_sigabort)

conn = psycopg2.connect({dsn!r})

cur = conn.cursor()

cur.execute("create table test551 (id serial, num varchar(50))")

t = threading.Thread(target=killer)
t.daemon = True
t.start()

while True:
    cur.execute({query!r}, ("Hello, world!",))
"""

        proc = sp.Popen([sys.executable, '-c', script],
            stdout=sp.PIPE, stderr=sp.PIPE)
        (out, err) = proc.communicate()
        self.assertNotEqual(proc.returncode, 0)
        # Strip [NNN refs] from output
        err = re.sub(br'\[[^\]]+\]', b'', err).strip()
        self.assertFalse(err, err)
|
|
|
|
|
|
class TestConnectionInfo(ConnectingTestCase):
    """Tests for the attributes of the connection ``info`` object.

    Most values are checked both on the regular test connection and on a
    "broken" one, whose class never initialises its superclass.
    """

    def setUp(self):
        ConnectingTestCase.setUp(self)

        class BrokenConn(psycopg2.extensions.connection):
            def __init__(self, *args, **kwargs):
                # don't call superclass
                pass

        # A "broken" connection
        self.bconn = self.connect(connection_factory=BrokenConn)

    def test_dbname(self):
        self.assertIsInstance(self.conn.info.dbname, str)
        self.assertIsNone(self.bconn.info.dbname)

    def test_user(self):
        cur = self.conn.cursor()
        cur.execute("select user")
        self.assertEqual(self.conn.info.user, cur.fetchone()[0])
        self.assertIsNone(self.bconn.info.user)

    def test_password(self):
        self.assertIsInstance(self.conn.info.password, str)
        self.assertIsNone(self.bconn.info.password)

    def test_host(self):
        # With no host configured for the tests, look for a "/" in the
        # reported host instead.
        expected = dbhost or "/"
        self.assertIn(expected, self.conn.info.host)
        self.assertIsNone(self.bconn.info.host)

    def test_host_readonly(self):
        with self.assertRaises(AttributeError):
            self.conn.info.host = 'override'

    def test_port(self):
        self.assertIsInstance(self.conn.info.port, int)
        self.assertIsNone(self.bconn.info.port)

    def test_options(self):
        self.assertIsInstance(self.conn.info.options, str)
        self.assertIsNone(self.bconn.info.options)

    @skip_before_libpq(9, 3)
    def test_dsn_parameters(self):
        d = self.conn.info.dsn_parameters
        self.assertIsInstance(d, dict)
        # the only param we can check reliably
        self.assertEqual(d['dbname'], dbname)
        self.assertNotIn('password', d, d)

    def test_status(self):
        self.assertEqual(self.conn.info.status, 0)
        self.assertEqual(self.bconn.info.status, 1)

    def test_transaction_status(self):
        self.assertEqual(self.conn.info.transaction_status, 0)
        cur = self.conn.cursor()
        cur.execute("select 1")
        self.assertEqual(self.conn.info.transaction_status, 2)
        self.assertEqual(self.bconn.info.transaction_status, 4)

    def test_parameter_status(self):
        cur = self.conn.cursor()
        try:
            cur.execute("show server_version")
        except psycopg2.DatabaseError:
            # The query is not available: only check the type.
            self.assertIsInstance(
                self.conn.info.parameter_status('server_version'), str)
        else:
            self.assertEqual(
                self.conn.info.parameter_status('server_version'),
                cur.fetchone()[0])

        self.assertIsNone(self.conn.info.parameter_status('wat'))
        self.assertIsNone(self.bconn.info.parameter_status('server_version'))

    def test_protocol_version(self):
        self.assertEqual(self.conn.info.protocol_version, 3)
        self.assertEqual(self.bconn.info.protocol_version, 0)

    def test_server_version(self):
        cur = self.conn.cursor()
        try:
            cur.execute("show server_version_num")
        except psycopg2.DatabaseError:
            # The query is not available: only check the type.
            self.assertIsInstance(self.conn.info.server_version, int)
        else:
            self.assertEqual(
                self.conn.info.server_version, int(cur.fetchone()[0]))

        self.assertEqual(self.bconn.info.server_version, 0)

    def test_error_message(self):
        self.assertIsNone(self.conn.info.error_message)
        self.assertIsNotNone(self.bconn.info.error_message)

        cur = self.conn.cursor()
        try:
            cur.execute("select 1 from nosuchtable")
        except psycopg2.DatabaseError:
            # The failure is what the test needs: the message it leaves
            # on the connection is checked below.
            pass

        self.assertIn('nosuchtable', self.conn.info.error_message)

    def test_socket(self):
        self.assertGreaterEqual(self.conn.info.socket, 0)
        self.assertLess(self.bconn.info.socket, 0)

    @skip_if_crdb("backend pid")
    def test_backend_pid(self):
        cur = self.conn.cursor()
        try:
            cur.execute("select pg_backend_pid()")
        except psycopg2.DatabaseError:
            # The function is not available: only check the value is set.
            self.assertGreater(self.conn.info.backend_pid, 0)
        else:
            self.assertEqual(
                self.conn.info.backend_pid, int(cur.fetchone()[0]))

        self.assertEqual(self.bconn.info.backend_pid, 0)

    def test_needs_password(self):
        self.assertIs(self.conn.info.needs_password, False)
        self.assertIs(self.bconn.info.needs_password, False)

    def test_used_password(self):
        self.assertIsInstance(self.conn.info.used_password, bool)
        self.assertIs(self.bconn.info.used_password, False)

    @skip_before_libpq(9, 5)
    def test_ssl_in_use(self):
        self.assertIsInstance(self.conn.info.ssl_in_use, bool)
        self.assertIs(self.bconn.info.ssl_in_use, False)

    @skip_after_libpq(9, 5)
    def test_ssl_not_supported(self):
        with self.assertRaises(psycopg2.NotSupportedError):
            self.conn.info.ssl_in_use
        with self.assertRaises(psycopg2.NotSupportedError):
            self.conn.info.ssl_attribute_names
        with self.assertRaises(psycopg2.NotSupportedError):
            self.conn.info.ssl_attribute('wat')

    @skip_before_libpq(9, 5)
    @skip_after_libpq(16)
    def test_ssl_attribute(self):
        # Skip this test even if libpq built == 15, runtime == 16 (see #1619)
        if ext.libpq_version() >= 160000:
            # skipTest() raises, so nothing below runs in this case.
            self.skipTest("libpq runtime version == %s" % ext.libpq_version())

        attribs = self.conn.info.ssl_attribute_names
        self.assertTrue(attribs)
        if self.conn.info.ssl_in_use:
            for attrib in attribs:
                self.assertIsInstance(
                    self.conn.info.ssl_attribute(attrib), str)
        else:
            for attrib in attribs:
                # Behaviour changed in PostgreSQL 15
                if attrib == "library":
                    continue
                self.assertIsNone(self.conn.info.ssl_attribute(attrib))

        self.assertIsNone(self.conn.info.ssl_attribute('wat'))

        for attrib in attribs:
            if attrib == "library":
                continue
            self.assertIsNone(self.bconn.info.ssl_attribute(attrib))
|
|
|
|
|
|
def test_suite():
    """Return a suite with all the tests defined in this module."""
    loader = unittest.TestLoader()
    suite = loader.loadTestsFromName(__name__)
    return suite
|
|
|
|
|
|
# Run the tests of this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
|