diff --git a/psycopg/xid_type.c b/psycopg/xid_type.c index dec8b80c..2a21edbf 100644 --- a/psycopg/xid_type.c +++ b/psycopg/xid_type.c @@ -231,6 +231,20 @@ exit: return rv; } + +static const char xid_from_string_doc[] = + "Create a Xid object from a string representation."; + +static PyObject * +xid_from_string_method(PyObject *cls, PyObject *args) +{ + PyObject *s = NULL; + + if (!PyArg_ParseTuple(args, "O", &s)) { return NULL; } + + return (PyObject *)xid_from_string(s); +} + static PySequenceMethods xid_sequence = { (lenfunc)xid_len, /* sq_length */ 0, /* sq_concat */ @@ -244,6 +258,14 @@ static PySequenceMethods xid_sequence = { 0, /* sq_inplace_repeat */ }; + +static struct PyMethodDef xid_methods[] = { + {"from_string", (PyCFunction)xid_from_string_method, + METH_VARARGS|METH_STATIC, xid_from_string_doc}, + {NULL} +}; + + static const char xid_doc[] = "A transaction identifier used for two phase commit."; @@ -288,7 +310,7 @@ PyTypeObject XidType = { /* Attribute descriptor and subclassing stuff */ - 0, /*tp_methods*/ + xid_methods, /*tp_methods*/ xid_members, /*tp_members*/ 0, /*tp_getset*/ 0, /*tp_base*/ @@ -327,12 +349,8 @@ XidObject *xid_ensure(PyObject *oxid) Py_INCREF(oxid); rv = (XidObject *)oxid; } - else if (PyString_Check(oxid)) { - rv = xid_from_string(oxid); - } else { - PyErr_SetString(PyExc_TypeError, - "not a valid transaction id"); + rv = xid_from_string(oxid); } return rv; @@ -572,6 +590,11 @@ XidObject * xid_from_string(PyObject *str) { XidObject *rv; + if (!PyString_Check(str)) { + PyErr_SetString(PyExc_TypeError, "not a valid transaction id"); + return NULL; + } + /* Try to parse an XA triple from the string. This may fail for several * reasons, such as the rules stated in Xid.__init__. */ rv = _xid_parse_string(str); diff --git a/tests/test_connection.py b/tests/test_connection.py index 57bdd8ac..6d35b2c4 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -353,6 +353,36 @@ class ConnectionTwoPhaseTests(unittest.TestCase): cnn.tpc_rollback(xid) + def test_xid_construction(self): + from psycopg2.extensions import Xid + + x1 = Xid(74, 'foo', 'bar') + self.assertEqual(74, x1.format_id) + self.assertEqual('foo', x1.gtrid) + self.assertEqual('bar', x1.bqual) + + def test_xid_from_string(self): + from psycopg2.extensions import Xid + + x2 = Xid.from_string('42_Z3RyaWQ=_YnF1YWw=') + self.assertEqual(42, x2.format_id) + self.assertEqual('gtrid', x2.gtrid) + self.assertEqual('bqual', x2.bqual) + + x3 = Xid.from_string('99_xxx_yyy') + self.assertEqual(None, x3.format_id) + self.assertEqual('99_xxx_yyy', x3.gtrid) + self.assertEqual(None, x3.bqual) + + def test_xid_to_string(self): + from psycopg2.extensions import Xid + + x1 = Xid.from_string('42_Z3RyaWQ=_YnF1YWw=') + self.assertEqual(str(x1), '42_Z3RyaWQ=_YnF1YWw=') + + x2 = Xid.from_string('99_xxx_yyy') + self.assertEqual(str(x2), '99_xxx_yyy') + def test_suite(): return unittest.TestLoader().loadTestsFromName(__name__)