From 3e658c33b5ee8a75217b9843351e44bf1856e574 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Fri, 15 Oct 2010 08:27:07 +0100 Subject: [PATCH] Ensure unicode is accepted as type for transaction ids. We don't do somersaults to ensure people can use snowmen as transaction ids anyway: it would require passing the connection to xid_ensure and down below to use the correct encoding. --- psycopg/xid_type.c | 2 +- tests/test_connection.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/psycopg/xid_type.c b/psycopg/xid_type.c index 2a21edbf..626440c3 100644 --- a/psycopg/xid_type.c +++ b/psycopg/xid_type.c @@ -590,7 +590,7 @@ XidObject * xid_from_string(PyObject *str) { XidObject *rv; - if (!PyString_Check(str)) { + if (!(PyString_Check(str) || PyUnicode_Check(str))) { PyErr_SetString(PyExc_TypeError, "not a valid transaction id"); return NULL; } diff --git a/tests/test_connection.py b/tests/test_connection.py index 6d35b2c4..b83a90f3 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -383,6 +383,34 @@ class ConnectionTwoPhaseTests(unittest.TestCase): x2 = Xid.from_string('99_xxx_yyy') self.assertEqual(str(x2), '99_xxx_yyy') + def test_xid_unicode(self): + cnn = self.connect() + x1 = cnn.xid(10, u'uni', u'code') + cnn.tpc_begin(x1) + cnn.tpc_prepare() + cnn.reset() + xid = [ xid for xid in cnn.tpc_recover() + if xid.database == tests.dbname ][0] + self.assertEqual(10, xid.format_id) + self.assertEqual('uni', xid.gtrid) + self.assertEqual('code', xid.bqual) + + def test_xid_unicode_unparsed(self): + # We don't expect people shooting snowmen as transaction ids, + # so if something explodes in an encode error I don't mind. + # Let's just check uniconde is accepted as type. + cnn = self.connect() + cnn.set_client_encoding('utf8') + cnn.tpc_begin(u"transaction-id") + cnn.tpc_prepare() + cnn.reset() + + xid = [ xid for xid in cnn.tpc_recover() + if xid.database == tests.dbname ][0] + self.assertEqual(None, xid.format_id) + self.assertEqual('transaction-id', xid.gtrid) + self.assertEqual(None, xid.bqual) + def test_suite(): return unittest.TestLoader().loadTestsFromName(__name__)