diff --git a/tests/test_connection.py b/tests/test_connection.py index 6c6dddca..f00a4c96 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -69,6 +69,7 @@ class ConnectionTests(unittest.TestCase): class ConnectionTwoPhaseTests(unittest.TestCase): def setUp(self): + self.make_test_table() self.clear_test_xacts() def tearDown(self): @@ -87,9 +88,172 @@ class ConnectionTwoPhaseTests(unittest.TestCase): cur.execute("rollback prepared %s;", (gid,)) cnn.close() + def make_test_table(self): + cnn = self.connect() + cur = cnn.cursor() + cur.execute("DROP TABLE IF EXISTS test_tpc;") + cur.execute("CREATE TABLE test_tpc (data text);") + cnn.commit() + cnn.close() + + def count_xacts(self): + """Return the number of prepared xacts currently in the test db.""" + cnn = self.connect() + cur = cnn.cursor() + cur.execute(""" + select count(*) from pg_prepared_xacts + where database = %s;""", + (tests.dbname,)) + rv = cur.fetchone()[0] + cnn.close() + return rv + + def count_test_records(self): + """Return the number of records in the test table.""" + cnn = self.connect() + cur = cnn.cursor() + cur.execute("select count(*) from test_tpc;") + rv = cur.fetchone()[0] + cnn.close() + return rv + def connect(self): return psycopg2.connect(tests.dsn) + def test_tpc_commit(self): + cnn = self.connect() + xid = cnn.xid(1, "gtrid", "bqual") + self.assertEqual(cnn.status, psycopg2.extensions.STATUS_READY) + + cnn.tpc_begin(xid) + self.assertEqual(cnn.status, psycopg2.extensions.STATUS_BEGIN) + + cur = cnn.cursor() + cur.execute("insert into test_tpc values ('test_tpc_commit');") + self.assertEqual(0, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) + + cnn.tpc_prepare() + self.assertEqual(cnn.status, psycopg2.extensions.STATUS_PREPARED) + self.assertEqual(1, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) + + cnn.tpc_commit() + self.assertEqual(cnn.status, psycopg2.extensions.STATUS_READY) + self.assertEqual(0, self.count_xacts()) + self.assertEqual(1, self.count_test_records()) + + def test_tpc_commit_one_phase(self): + cnn = self.connect() + xid = cnn.xid(1, "gtrid", "bqual") + self.assertEqual(cnn.status, psycopg2.extensions.STATUS_READY) + + cnn.tpc_begin(xid) + self.assertEqual(cnn.status, psycopg2.extensions.STATUS_BEGIN) + + cur = cnn.cursor() + cur.execute("insert into test_tpc values ('test_tpc_commit_1p');") + self.assertEqual(0, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) + + cnn.tpc_commit() + self.assertEqual(cnn.status, psycopg2.extensions.STATUS_READY) + self.assertEqual(0, self.count_xacts()) + self.assertEqual(1, self.count_test_records()) + + def test_tpc_commit_recovered(self): + cnn = self.connect() + xid = cnn.xid(1, "gtrid", "bqual") + self.assertEqual(cnn.status, psycopg2.extensions.STATUS_READY) + + cnn.tpc_begin(xid) + self.assertEqual(cnn.status, psycopg2.extensions.STATUS_BEGIN) + + cur = cnn.cursor() + cur.execute("insert into test_tpc values ('test_tpc_commit_rec');") + self.assertEqual(0, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) + + cnn.tpc_prepare() + cnn.close() + self.assertEqual(1, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) + + cnn = self.connect() + xid = cnn.xid(1, "gtrid", "bqual") + cnn.tpc_commit(xid) + + self.assertEqual(cnn.status, psycopg2.extensions.STATUS_READY) + self.assertEqual(0, self.count_xacts()) + self.assertEqual(1, self.count_test_records()) + + def test_tpc_rollback(self): + cnn = self.connect() + xid = cnn.xid(1, "gtrid", "bqual") + self.assertEqual(cnn.status, psycopg2.extensions.STATUS_READY) + + cnn.tpc_begin(xid) + self.assertEqual(cnn.status, psycopg2.extensions.STATUS_BEGIN) + + cur = cnn.cursor() + cur.execute("insert into test_tpc values ('test_tpc_rollback');") + self.assertEqual(0, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) + + cnn.tpc_prepare() + self.assertEqual(cnn.status, psycopg2.extensions.STATUS_PREPARED) + self.assertEqual(1, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) + + cnn.tpc_rollback() + self.assertEqual(cnn.status, psycopg2.extensions.STATUS_READY) + self.assertEqual(0, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) + + def test_tpc_rollback_one_phase(self): + cnn = self.connect() + xid = cnn.xid(1, "gtrid", "bqual") + self.assertEqual(cnn.status, psycopg2.extensions.STATUS_READY) + + cnn.tpc_begin(xid) + self.assertEqual(cnn.status, psycopg2.extensions.STATUS_BEGIN) + + cur = cnn.cursor() + cur.execute("insert into test_tpc values ('test_tpc_rollback_1p');") + self.assertEqual(0, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) + + cnn.tpc_rollback() + self.assertEqual(cnn.status, psycopg2.extensions.STATUS_READY) + self.assertEqual(0, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) + + def test_tpc_rollback_recovered(self): + cnn = self.connect() + xid = cnn.xid(1, "gtrid", "bqual") + self.assertEqual(cnn.status, psycopg2.extensions.STATUS_READY) + + cnn.tpc_begin(xid) + self.assertEqual(cnn.status, psycopg2.extensions.STATUS_BEGIN) + + cur = cnn.cursor() + cur.execute("insert into test_tpc values ('test_tpc_commit_rec');") + self.assertEqual(0, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) + + cnn.tpc_prepare() + cnn.close() + self.assertEqual(1, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) + + cnn = self.connect() + xid = cnn.xid(1, "gtrid", "bqual") + cnn.tpc_rollback(xid) + + self.assertEqual(cnn.status, psycopg2.extensions.STATUS_READY) + self.assertEqual(0, self.count_xacts()) + self.assertEqual(0, self.count_test_records()) + def test_status_after_recover(self): cnn = self.connect() self.assertEqual(psycopg2.extensions.STATUS_READY, cnn.status)