Tests fail gracefully if tpc is supported but disabled by the server.

This commit is contained in:
Daniele Varrazzo 2010-11-06 15:06:52 +00:00
parent 645ab521f3
commit 7276c4a6b1

View File

@ -67,6 +67,32 @@ class ConnectionTests(unittest.TestCase):
self.assert_(conn.encoding in psycopg2.extensions.encodings)
def skip_if_tpc_disabled(f):
"""Skip a test if the server has tpc support disabled."""
def skip_if_tpc_disabled_(self):
cnn = self.connect()
cur = cnn.cursor()
try:
cur.execute("SHOW max_prepared_transactions;")
except psycopg2.ProgrammingError:
# Server version too old: let's die a different death
mtp = 1
else:
mtp = int(cur.fetchone()[0])
finally:
cnn.close()
if not mtp:
import warnings
warnings.warn(
"server not configured for two phase transactions. "
"set max_prepared_transactions to > 0 to run the test")
return
return f(self)
skip_if_tpc_disabled_.__name__ = f.__name__
return skip_if_tpc_disabled_
class ConnectionTwoPhaseTests(unittest.TestCase):
def setUp(self):
self.make_test_table()
@ -120,6 +146,7 @@ class ConnectionTwoPhaseTests(unittest.TestCase):
def connect(self):
return psycopg2.connect(tests.dsn)
@skip_if_tpc_disabled
def test_tpc_commit(self):
cnn = self.connect()
xid = cnn.xid(1, "gtrid", "bqual")
@ -161,6 +188,7 @@ class ConnectionTwoPhaseTests(unittest.TestCase):
self.assertEqual(0, self.count_xacts())
self.assertEqual(1, self.count_test_records())
@skip_if_tpc_disabled
def test_tpc_commit_recovered(self):
cnn = self.connect()
xid = cnn.xid(1, "gtrid", "bqual")
@ -187,6 +215,7 @@ class ConnectionTwoPhaseTests(unittest.TestCase):
self.assertEqual(0, self.count_xacts())
self.assertEqual(1, self.count_test_records())
@skip_if_tpc_disabled
def test_tpc_rollback(self):
cnn = self.connect()
xid = cnn.xid(1, "gtrid", "bqual")
@ -228,6 +257,7 @@ class ConnectionTwoPhaseTests(unittest.TestCase):
self.assertEqual(0, self.count_xacts())
self.assertEqual(0, self.count_test_records())
@skip_if_tpc_disabled
def test_tpc_rollback_recovered(self):
cnn = self.connect()
xid = cnn.xid(1, "gtrid", "bqual")
@ -266,6 +296,7 @@ class ConnectionTwoPhaseTests(unittest.TestCase):
xns = cnn.tpc_recover()
self.assertEqual(psycopg2.extensions.STATUS_BEGIN, cnn.status)
@skip_if_tpc_disabled
def test_recovered_xids(self):
# insert a few test xns
cnn = self.connect()
@ -296,6 +327,7 @@ class ConnectionTwoPhaseTests(unittest.TestCase):
self.assertEqual(xid.owner, owner)
self.assertEqual(xid.database, database)
@skip_if_tpc_disabled
def test_xid_encoding(self):
cnn = self.connect()
xid = cnn.xid(42, "gtrid", "bqual")
@ -308,6 +340,7 @@ class ConnectionTwoPhaseTests(unittest.TestCase):
(tests.dbname,))
self.assertEqual('42_Z3RyaWQ=_YnF1YWw=', cur.fetchone()[0])
@skip_if_tpc_disabled
def test_xid_roundtrip(self):
for fid, gtrid, bqual in [
(0, "", ""),
@ -331,6 +364,7 @@ class ConnectionTwoPhaseTests(unittest.TestCase):
cnn.tpc_rollback(xid)
@skip_if_tpc_disabled
def test_unparsed_roundtrip(self):
for tid in [
'',
@ -383,6 +417,7 @@ class ConnectionTwoPhaseTests(unittest.TestCase):
x2 = Xid.from_string('99_xxx_yyy')
self.assertEqual(str(x2), '99_xxx_yyy')
@skip_if_tpc_disabled
def test_xid_unicode(self):
cnn = self.connect()
x1 = cnn.xid(10, u'uni', u'code')
@ -395,6 +430,7 @@ class ConnectionTwoPhaseTests(unittest.TestCase):
self.assertEqual('uni', xid.gtrid)
self.assertEqual('code', xid.bqual)
@skip_if_tpc_disabled
def test_xid_unicode_unparsed(self):
# We don't expect people shooting snowmen as transaction ids,
# so if something explodes in an encode error I don't mind.