diff --git a/psycopg/xid.h b/psycopg/xid.h index 598d6a53..0c0d3bd3 100644 --- a/psycopg/xid.h +++ b/psycopg/xid.h @@ -31,9 +31,6 @@ #include "psycopg/config.h" -/* value for the format_id when the xid doesn't follow the XA standard. */ -#define XID_UNPARSED (-2) - extern HIDDEN PyTypeObject XidType; typedef struct { diff --git a/psycopg/xid_type.c b/psycopg/xid_type.c index e4d826ea..cfe0bc46 100644 --- a/psycopg/xid_type.c +++ b/psycopg/xid_type.c @@ -303,25 +303,13 @@ _xid_encode64(PyObject *s) { PyObject *base64 = NULL; PyObject *encode = NULL; - PyObject *out = NULL; PyObject *rv = NULL; if (!(base64 = PyImport_ImportModule("base64"))) { goto exit; } if (!(encode = PyObject_GetAttrString(base64, "b64encode"))) { goto exit; } - if (!(out = PyObject_CallFunctionObjArgs(encode, s, NULL))) { goto exit; } - - /* we are going to use PyString_AS_STRING on this so let's ensure it. */ - if (!PyString_Check(out)) { - PyErr_SetString(PyExc_TypeError, - "base64.b64encode didn't return a string"); - goto exit; - } - - rv = out; - out = NULL; + if (!(rv = PyObject_CallFunctionObjArgs(encode, s, NULL))) { goto exit; } exit: - Py_XDECREF(out); Py_XDECREF(encode); Py_XDECREF(base64); @@ -364,16 +352,15 @@ char * xid_get_tid(XidObject *self) { char *buf = NULL; - long format_id; Py_ssize_t bufsize = 0; PyObject *egtrid = NULL; PyObject *ebqual = NULL; + PyObject *format = NULL; + PyObject *args = NULL; PyObject *tid = NULL; - format_id = PyInt_AsLong(self->format_id); - if (-1 == format_id && PyErr_Occurred()) { goto exit; } - - if (XID_UNPARSED == format_id) { + if (Py_None == self->format_id) { + /* Unparsed xid: return the gtrid. */ bufsize = 1 + PyString_Size(self->gtrid); if (!(buf = (char *)PyMem_Malloc(bufsize))) { PyErr_NoMemory(); @@ -382,23 +369,32 @@ xid_get_tid(XidObject *self) strncpy(buf, PyString_AsString(self->gtrid), bufsize); } else { + /* XA xid: mash together the components. */ if (!(egtrid = _xid_encode64(self->gtrid))) { goto exit; } if (!(ebqual = _xid_encode64(self->bqual))) { goto exit; } - if (!(tid = PyString_FromFormat("%ld_%s_%s", - format_id, - PyString_AS_STRING(egtrid), - PyString_AS_STRING(ebqual)))) { - goto exit; - } + + /* tid = "%d_%s_%s" % (format_id, egtrid, ebqual) */ + if (!(format = PyString_FromString("%d_%s_%s"))) { goto exit; } + + if (!(args = PyTuple_New(3))) { goto exit; } + Py_INCREF(self->format_id); + PyTuple_SET_ITEM(args, 0, self->format_id); + PyTuple_SET_ITEM(args, 1, egtrid); egtrid = NULL; + PyTuple_SET_ITEM(args, 2, ebqual); ebqual = NULL; + + if (!(tid = PyString_Format(format, args))) { goto exit; } + bufsize = 1 + PyString_Size(tid); if (!(buf = (char *)PyMem_Malloc(bufsize))) { PyErr_NoMemory(); goto exit; } - strncpy(buf, PyString_AS_STRING(tid), bufsize); + strncpy(buf, PyString_AsString(tid), bufsize); } exit: + Py_XDECREF(args); + Py_XDECREF(format); Py_XDECREF(egtrid); Py_XDECREF(ebqual); Py_XDECREF(tid); @@ -505,29 +501,27 @@ static XidObject * _xid_unparsed_from_string(PyObject *str) { XidObject *xid = NULL; XidObject *rv = NULL; - PyObject *format_id = NULL; PyObject *tmp; /* fake args to work around the checks performed by the xid init */ if (!(xid = (XidObject *)PyObject_CallFunction((PyObject *)&XidType, - "iss", 0, "tmp", "tmp"))) { + "iss", 0, "", ""))) { goto exit; } - /* set xid.gtrid */ + /* set xid.gtrid = str */ tmp = xid->gtrid; Py_INCREF(str); xid->gtrid = str; Py_DECREF(tmp); - /* set xid.format_id */ - if (!(format_id = PyInt_FromLong(XID_UNPARSED))) { goto exit; } + /* set xid.format_id = None */ tmp = xid->format_id; - xid->format_id = format_id; - format_id = NULL; + Py_INCREF(Py_None); + xid->format_id = Py_None; Py_DECREF(tmp); - /* set xid.bqual */ + /* set xid.bqual = None */ tmp = xid->bqual; Py_INCREF(Py_None); xid->bqual = Py_None; @@ -538,7 +532,6 @@ _xid_unparsed_from_string(PyObject *str) { xid = NULL; exit: - Py_XDECREF(format_id); Py_XDECREF(xid); return rv; diff --git a/tests/test_connection.py b/tests/test_connection.py index 9cefddeb..57bdd8ac 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -347,7 +347,7 @@ class ConnectionTwoPhaseTests(unittest.TestCase): if xid.database == tests.dbname ] self.assertEqual(1, len(xids)) xid = xids[0] - self.assertEqual(xid.format_id, -2) + self.assertEqual(xid.format_id, None) self.assertEqual(xid.gtrid, tid) self.assertEqual(xid.bqual, None)