XA transaction ids can be decoded from PostgreSQL transaction ids.

This commit is contained in:
Daniele Varrazzo 2010-10-11 18:27:07 +01:00
parent 6309e038e5
commit 978cac3a1b
2 changed files with 191 additions and 10 deletions

View File

@ -278,16 +278,21 @@ PyTypeObject XidType = {
*/ */
XidObject *xid_ensure(PyObject *oxid) XidObject *xid_ensure(PyObject *oxid)
{ {
/* TODO: string roundtrip. */ XidObject *rv = NULL;
if (PyObject_TypeCheck(oxid, &XidType)) { if (PyObject_TypeCheck(oxid, &XidType)) {
Py_INCREF(oxid); Py_INCREF(oxid);
return (XidObject *)oxid; rv = (XidObject *)oxid;
}
else if (PyString_Check(oxid)) {
rv = xid_from_string(oxid);
} }
else { else {
PyErr_SetString(PyExc_TypeError, PyErr_SetString(PyExc_TypeError,
"not a valid transaction id"); "not a valid transaction id");
return NULL;
} }
return rv;
} }
@ -323,6 +328,26 @@ exit:
return rv; return rv;
} }
/* decode a base64-encoded string */
static PyObject *
_xid_decode64(PyObject *s)
{
PyObject *base64 = NULL;
PyObject *decode = NULL;
PyObject *rv = NULL;
if (!(base64 = PyImport_ImportModule("base64"))) { goto exit; }
if (!(decode = PyObject_GetAttrString(base64, "b64decode"))) { goto exit; }
if (!(rv = PyObject_CallFunctionObjArgs(decode, s, NULL))) { goto exit; }
exit:
Py_XDECREF(decode);
Py_XDECREF(base64);
return rv;
}
/* Return the PostgreSQL transaction_id for this XA xid. /* Return the PostgreSQL transaction_id for this XA xid.
* *
* PostgreSQL wants just a string, while the DBAPI supports the XA standard * PostgreSQL wants just a string, while the DBAPI supports the XA standard
@ -381,14 +406,103 @@ exit:
return buf; return buf;
} }
/* Build a Xid from a string representation.
/* Return the regex object to parse a Xid string.
* *
* If the xid is in the format generated by Psycopg, unpack the tuple into * Return a borrowed reference. */
* the struct members. Otherwise generate an "unparsed" xid.
*/ static PyObject *
XidObject * _xid_get_parse_regex(void) {
xid_from_string(PyObject *str) { static PyObject *rv;
/* TODO: currently always generates an unparsed xid. */
if (!rv) {
PyObject *re_mod = NULL;
PyObject *comp = NULL;
PyObject *regex = NULL;
Dprintf("compiling regexp to parse transaction id");
if (!(re_mod = PyImport_ImportModule("re"))) { goto exit; }
if (!(comp = PyObject_GetAttrString(re_mod, "compile"))) { goto exit; }
if (!(regex = PyObject_CallFunction(comp, "s",
"^(\\d+)_([^_]*)_([^_]*)$"))) {
goto exit;
}
/* Good, compiled. */
rv = regex;
regex = NULL;
exit:
Py_XDECREF(regex);
Py_XDECREF(comp);
Py_XDECREF(re_mod);
}
return rv;
}
/* Try to parse a Xid string representation in a Xid object.
*
*
* Return NULL + exception if parsing failed. Else a new Xid object. */
static XidObject *
_xid_parse_string(PyObject *str) {
PyObject *regex;
PyObject *m = NULL;
PyObject *group = NULL;
PyObject *item = NULL;
PyObject *format_id = NULL;
PyObject *egtrid = NULL;
PyObject *ebqual = NULL;
PyObject *gtrid = NULL;
PyObject *bqual = NULL;
XidObject *rv = NULL;
/* check if the string is a possible XA triple with a regexp */
if (!(regex = _xid_get_parse_regex())) { goto exit; }
if (!(m = PyObject_CallMethod(regex, "match", "O", str))) { goto exit; }
if (m == Py_None) {
PyErr_SetString(PyExc_ValueError, "bad xid format");
goto exit;
}
/* Extract the components from the regexp */
if (!(group = PyObject_GetAttrString(m, "group"))) { goto exit; }
if (!(item = PyObject_CallFunction(group, "i", 1))) { goto exit; }
if (!(format_id = PyObject_CallFunctionObjArgs(
(PyObject *)&PyInt_Type, item, NULL))) {
goto exit;
}
if (!(egtrid = PyObject_CallFunction(group, "i", 2))) { goto exit; }
if (!(gtrid = _xid_decode64(egtrid))) { goto exit; }
if (!(ebqual = PyObject_CallFunction(group, "i", 3))) { goto exit; }
if (!(bqual = _xid_decode64(ebqual))) { goto exit; }
/* Try to build the xid with the parsed material */
rv = (XidObject *)PyObject_CallFunctionObjArgs((PyObject *)&XidType,
format_id, gtrid, bqual, NULL);
exit:
Py_XDECREF(bqual);
Py_XDECREF(ebqual);
Py_XDECREF(gtrid);
Py_XDECREF(egtrid);
Py_XDECREF(format_id);
Py_XDECREF(item);
Py_XDECREF(group);
Py_XDECREF(m);
return rv;
}
/* Return a new Xid object representing a transaction ID not conform to
* the XA specifications. */
static XidObject *
_xid_unparsed_from_string(PyObject *str) {
XidObject *xid = NULL; XidObject *xid = NULL;
XidObject *rv = NULL; XidObject *rv = NULL;
PyObject *format_id = NULL; PyObject *format_id = NULL;
@ -430,6 +544,28 @@ exit:
return rv; return rv;
} }
/* Build a Xid from a string representation.
*
* If the xid is in the format generated by Psycopg, unpack the tuple into
* the struct members. Otherwise generate an "unparsed" xid.
*/
XidObject *
xid_from_string(PyObject *str) {
XidObject *rv;
/* Try to parse an XA triple from the string. This may fail for several
* reasons, such as the rules stated in Xid.__init__. */
rv = _xid_parse_string(str);
if (!rv) {
/* If parsing failed, treat the string as an unparsed id */
PyErr_Clear();
rv = _xid_unparsed_from_string(str);
}
return rv;
}
/* conn_tpc_recover -- return a list of pending TPC Xid */ /* conn_tpc_recover -- return a list of pending TPC Xid */
PyObject * PyObject *

View File

@ -308,6 +308,51 @@ class ConnectionTwoPhaseTests(unittest.TestCase):
(tests.dbname,)) (tests.dbname,))
self.assertEqual('42_Z3RyaWQ=_YnF1YWw=', cur.fetchone()[0]) self.assertEqual('42_Z3RyaWQ=_YnF1YWw=', cur.fetchone()[0])
def test_xid_roundtrip(self):
for fid, gtrid, bqual in [
(0, "", ""),
(42, "gtrid", "bqual"),
(0x7fffffff, "x" * 64, "y" * 64),
]:
cnn = self.connect()
xid = cnn.xid(fid, gtrid, bqual)
cnn.tpc_begin(xid)
cnn.tpc_prepare()
cnn.close()
cnn = self.connect()
xids = [ xid for xid in cnn.tpc_recover()
if xid.database == tests.dbname ]
self.assertEqual(1, len(xids))
xid = xids[0]
self.assertEqual(xid.format_id, fid)
self.assertEqual(xid.gtrid, gtrid)
self.assertEqual(xid.bqual, bqual)
cnn.tpc_rollback(xid)
def test_unparsed_roundtrip(self):
for tid in [
'',
'hello, world!',
'x' * 199, # PostgreSQL's limit in transaction id length
]:
cnn = self.connect()
cnn.tpc_begin(tid)
cnn.tpc_prepare()
cnn.close()
cnn = self.connect()
xids = [ xid for xid in cnn.tpc_recover()
if xid.database == tests.dbname ]
self.assertEqual(1, len(xids))
xid = xids[0]
self.assertEqual(xid.format_id, -2)
self.assertEqual(xid.gtrid, tid)
self.assertEqual(xid.bqual, None)
cnn.tpc_rollback(xid)
def test_suite(): def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__) return unittest.TestLoader().loadTestsFromName(__name__)