Parse bytea output format ourselves instead of using the libpq

PG 9.0 uses the hex format by default, and clients < 9.0 can't parse that
format, requiring client update and great care in what is linked at runtime,
and generally giving headache to users and transitively us.
This commit is contained in:
Daniele Varrazzo 2011-03-26 10:59:27 +00:00
parent f34e44b3f4
commit 66c543b16c
3 changed files with 160 additions and 97 deletions

View File

@ -40,7 +40,7 @@ chunk_dealloc(chunkObject *self)
FORMAT_CODE_PY_SSIZE_T, FORMAT_CODE_PY_SSIZE_T,
self->base, self->len self->base, self->len
); );
PQfreemem(self->base); PyMem_Free(self->base);
Py_TYPE(self)->tp_free((PyObject *)self); Py_TYPE(self)->tp_free((PyObject *)self);
} }
@ -127,95 +127,184 @@ PyTypeObject chunkType = {
chunk_doc /* tp_doc */ chunk_doc /* tp_doc */
}; };
static char *psycopg_parse_hex(
const char *bufin, Py_ssize_t sizein, Py_ssize_t *sizeout);
static char *psycopg_parse_escape(
const char *bufin, Py_ssize_t sizein, Py_ssize_t *sizeout);
static PyObject * static PyObject *
typecast_BINARY_cast(const char *s, Py_ssize_t l, PyObject *curs) typecast_BINARY_cast(const char *s, Py_ssize_t l, PyObject *curs)
{ {
chunkObject *chunk = NULL; chunkObject *chunk = NULL;
PyObject *res = NULL; PyObject *res = NULL;
char *str = NULL, *buffer = NULL; char *buffer = NULL;
size_t len; Py_ssize_t len;
if (s == NULL) {Py_INCREF(Py_None); return Py_None;} if (s == NULL) {Py_INCREF(Py_None); return Py_None;}
/* PQunescapeBytea absolutely wants a 0-terminated string and we don't if (s[0] == '\\' && s[1] == 'x') {
want to copy the whole buffer, right? Wrong, but there isn't any other /* This is a buffer escaped in hex format: libpq before 9.0 can't
way <g> */ * parse it and we can't detect reliably the libpq version at runtime.
if (s[l] != '\0') { * So the only robust option is to parse it ourselves - luckily it's
if ((buffer = PyMem_Malloc(l+1)) == NULL) { * an easy format.
PyErr_NoMemory(); */
goto fail; if (NULL == (buffer = psycopg_parse_hex(s, l, &len))) {
goto exit;
} }
/* Py_ssize_t->size_t cast is safe, as long as the Py_ssize_t is
* >= 0: */
assert (l >= 0);
strncpy(buffer, s, (size_t) l);
buffer[l] = '\0';
s = buffer;
} }
str = (char*)PQunescapeBytea((unsigned char*)s, &len); else {
Dprintf("typecast_BINARY_cast: unescaped " FORMAT_CODE_SIZE_T " bytes", /* This is a buffer in the classic bytea format. So we can handle it
len); * to the PQunescapeBytea to have it parsed, rignt? ...Wrong. We
* could, but then we'd have to record whether buffer was allocated by
/* The type of the second parameter to PQunescapeBytea is size_t *, so it's * Python or by the libpq to dispose it properly. Furthermore the
* possible (especially with Python < 2.5) to get a return value too large * PQunescapeBytea interface is not the most brilliant as it wants a
* to fit into a Python container. */ * null-terminated string even if we have known its length thus
if (len > (size_t) PY_SSIZE_T_MAX) { * requiring a useless memcpy and strlen.
PyErr_SetString(PyExc_IndexError, "PG buffer too large to fit in Python" * So we'll just have our better integrated parser, let's finish this
" buffer."); * story.
goto fail; */
if (NULL == (buffer = psycopg_parse_escape(s, l, &len))) {
goto exit;
} }
/* Check the escaping was successful */
if (s[0] == '\\' && s[1] == 'x' /* input encoded in hex format */
&& str[0] == 'x' /* output resulted in an 'x' */
&& s[2] != '7' && s[3] != '8') /* input wasn't really an x (0x78) */
{
PyErr_SetString(InterfaceError,
"can't receive bytea data from server >= 9.0 with the current "
"libpq client library: please update the libpq to at least 9.0 "
"or set bytea_output to 'escape' in the server config "
"or with a query");
goto fail;
} }
chunk = (chunkObject *) PyObject_New(chunkObject, &chunkType); chunk = (chunkObject *) PyObject_New(chunkObject, &chunkType);
if (chunk == NULL) goto fail; if (chunk == NULL) goto exit;
/* **Transfer** ownership of str's memory to the chunkObject: */ /* **Transfer** ownership of buffer's memory to the chunkObject: */
chunk->base = str; chunk->base = buffer;
str = NULL; buffer = NULL;
chunk->len = (Py_ssize_t)len;
/* size_t->Py_ssize_t cast was validated above: */
chunk->len = (Py_ssize_t) len;
#if PY_MAJOR_VERSION < 3 #if PY_MAJOR_VERSION < 3
if ((res = PyBuffer_FromObject((PyObject *)chunk, 0, chunk->len)) == NULL) if ((res = PyBuffer_FromObject((PyObject *)chunk, 0, chunk->len)) == NULL)
goto fail; goto exit;
#else #else
if ((res = PyMemoryView_FromObject((PyObject*)chunk)) == NULL) if ((res = PyMemoryView_FromObject((PyObject*)chunk)) == NULL)
goto fail; goto exit;
#endif #endif
/* PyBuffer_FromObject() created a new reference. We'll release our
* reference held in 'chunk' in the 'cleanup' clause. */
goto cleanup; exit:
fail: Py_XDECREF((PyObject *)chunk);
assert (PyErr_Occurred());
if (res != NULL) {
Py_DECREF(res);
res = NULL;
}
/* Fall through to cleanup: */
cleanup:
if (chunk != NULL) {
Py_DECREF((PyObject *) chunk);
}
if (str != NULL) {
/* str's mem was allocated by PQunescapeBytea; must use PQfreemem: */
PQfreemem(str);
}
/* We allocated buffer with PyMem_Malloc; must use PyMem_Free: */
PyMem_Free(buffer); PyMem_Free(buffer);
return res; return res;
} }
static const char hex_lut[128] = {
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, -1, -1, -1, -1, -1, -1,
-1, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, 10, 11, 12, 13, 14, 15, -1, -1, -1, -1, -1, -1, -1, -1, -1,
-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1,
};
/* Parse a bytea output buffer encoded in 'hex' format.
*
* the format is described in
* http://www.postgresql.org/docs/9.0/static/datatype-binary.html
*
* Parse the buffer in 'bufin', whose length is 'sizein'.
* Return a new buffer allocated by PyMem_Malloc and set 'sizeout' to its size.
* In case of error set an exception and return NULL.
*/
static char *
psycopg_parse_hex(const char *bufin, Py_ssize_t sizein, Py_ssize_t *sizeout)
{
char *ret = NULL;
const char *bufend = bufin + sizein;
const char *pi = bufin + 2; /* past the \x */
char *bufout;
char *po;
po = bufout = PyMem_Malloc((sizein - 2) >> 1); /* output size upper bound */
if (NULL == bufout) {
PyErr_NoMemory();
goto exit;
}
/* Implementation note: we call this function upon database response, not
* user input (because we are parsing the output format of a buffer) so we
* don't expect errors. On bad input we reserve the right to return a bad
* output, not an error.
*/
while (pi < bufend) {
char c;
while (-1 == (c = hex_lut[*pi++ & '\x7f'])) {
if (pi >= bufend) { goto endloop; }
}
*po = c << 4;
while (-1 == (c = hex_lut[*pi++ & '\x7f'])) {
if (pi >= bufend) { goto endloop; }
}
*po++ |= c;
}
endloop:
ret = bufout;
*sizeout = po - bufout;
exit:
return ret;
}
/* Parse a bytea output buffer encoded in 'escape' format.
*
* the format is described in
* http://www.postgresql.org/docs/9.0/static/datatype-binary.html
*
* Parse the buffer in 'bufin', whose length is 'sizein'.
* Return a new buffer allocated by PyMem_Malloc and set 'sizeout' to its size.
* In case of error set an exception and return NULL.
*/
static char *
psycopg_parse_escape(const char *bufin, Py_ssize_t sizein, Py_ssize_t *sizeout)
{
char *ret = NULL;
const char *bufend = bufin + sizein;
const char *pi = bufin;
char *bufout;
char *po;
po = bufout = PyMem_Malloc(sizein); /* output size upper bound */
if (NULL == bufout) {
PyErr_NoMemory();
goto exit;
}
while (pi < bufend) {
if (*pi != '\\') {
/* Unescaped char */
*po++ = *pi++;
continue;
}
if ((pi[1] >= '0' && pi[1] <= '3') &&
(pi[2] >= '0' && pi[2] <= '7') &&
(pi[3] >= '0' && pi[3] <= '7'))
{
/* Escaped octal value */
*po++ = ((pi[1] - '0') << 6) |
((pi[2] - '0') << 3) |
((pi[3] - '0'));
pi += 4;
}
else {
/* Escaped char */
*po++ = pi[1];
pi += 2;
}
}
ret = bufout;
*sizeout = po - bufout;
exit:
return ret;
}

View File

@ -140,24 +140,6 @@ def skip_if_no_namedtuple(f):
return skip_if_no_namedtuple_ return skip_if_no_namedtuple_
def skip_if_broken_hex_binary(f):
"""Decorator to detect libpq < 9.0 unable to parse bytea in hex format"""
def cope_with_hex_binary_(self):
from psycopg2 import InterfaceError
try:
return f(self)
except InterfaceError, e:
if '9.0' in str(e) and self.conn.server_version >= 90000:
return self.skipTest(
# FIXME: we are only assuming the libpq is older here,
# but we don't have a reliable way to detect the libpq
# version, not pre-9 at least.
"bytea broken with server >= 9.0, libpq < 9")
else:
raise
return cope_with_hex_binary_
def skip_if_no_iobase(f): def skip_if_no_iobase(f):
"""Skip a test if io.TextIOBase is not available.""" """Skip a test if io.TextIOBase is not available."""
def skip_if_no_iobase_(self): def skip_if_no_iobase_(self):

View File

@ -28,7 +28,7 @@ except:
pass pass
import sys import sys
import testutils import testutils
from testutils import unittest, skip_if_broken_hex_binary from testutils import unittest
from testconfig import dsn from testconfig import dsn
import psycopg2 import psycopg2
@ -116,7 +116,6 @@ class TypesBasicTests(unittest.TestCase):
s = self.execute("SELECT %s AS foo", (float("-inf"),)) s = self.execute("SELECT %s AS foo", (float("-inf"),))
self.failUnless(str(s) == "-inf", "wrong float quoting: " + str(s)) self.failUnless(str(s) == "-inf", "wrong float quoting: " + str(s))
@skip_if_broken_hex_binary
def testBinary(self): def testBinary(self):
if sys.version_info[0] < 3: if sys.version_info[0] < 3:
s = ''.join([chr(x) for x in range(256)]) s = ''.join([chr(x) for x in range(256)])
@ -143,7 +142,6 @@ class TypesBasicTests(unittest.TestCase):
b = psycopg2.Binary(bytes([])) b = psycopg2.Binary(bytes([]))
self.assertEqual(str(b), "''::bytea") self.assertEqual(str(b), "''::bytea")
@skip_if_broken_hex_binary
def testBinaryRoundTrip(self): def testBinaryRoundTrip(self):
# test to make sure buffers returned by psycopg2 are # test to make sure buffers returned by psycopg2 are
# understood by execute: # understood by execute:
@ -191,7 +189,6 @@ class TypesBasicTests(unittest.TestCase):
s = self.execute("SELECT '{}'::text AS foo") s = self.execute("SELECT '{}'::text AS foo")
self.failUnlessEqual(s, "{}") self.failUnlessEqual(s, "{}")
@skip_if_broken_hex_binary
@testutils.skip_from_python(3) @testutils.skip_from_python(3)
def testTypeRoundtripBuffer(self): def testTypeRoundtripBuffer(self):
o1 = buffer("".join(map(chr, range(256)))) o1 = buffer("".join(map(chr, range(256))))
@ -204,7 +201,6 @@ class TypesBasicTests(unittest.TestCase):
self.assertEqual(type(o1), type(o2)) self.assertEqual(type(o1), type(o2))
self.assertEqual(str(o1), str(o2)) self.assertEqual(str(o1), str(o2))
@skip_if_broken_hex_binary
@testutils.skip_from_python(3) @testutils.skip_from_python(3)
def testTypeRoundtripBufferArray(self): def testTypeRoundtripBufferArray(self):
o1 = buffer("".join(map(chr, range(256)))) o1 = buffer("".join(map(chr, range(256))))
@ -213,7 +209,6 @@ class TypesBasicTests(unittest.TestCase):
self.assertEqual(type(o1[0]), type(o2[0])) self.assertEqual(type(o1[0]), type(o2[0]))
self.assertEqual(str(o1[0]), str(o2[0])) self.assertEqual(str(o1[0]), str(o2[0]))
@skip_if_broken_hex_binary
@testutils.skip_before_python(3) @testutils.skip_before_python(3)
def testTypeRoundtripBytes(self): def testTypeRoundtripBytes(self):
o1 = bytes(range(256)) o1 = bytes(range(256))
@ -225,7 +220,6 @@ class TypesBasicTests(unittest.TestCase):
o2 = self.execute("select %s;", (o1,)) o2 = self.execute("select %s;", (o1,))
self.assertEqual(memoryview, type(o2)) self.assertEqual(memoryview, type(o2))
@skip_if_broken_hex_binary
@testutils.skip_before_python(3) @testutils.skip_before_python(3)
def testTypeRoundtripBytesArray(self): def testTypeRoundtripBytesArray(self):
o1 = bytes(range(256)) o1 = bytes(range(256))
@ -233,7 +227,6 @@ class TypesBasicTests(unittest.TestCase):
o2 = self.execute("select %s;", (o1,)) o2 = self.execute("select %s;", (o1,))
self.assertEqual(memoryview, type(o2[0])) self.assertEqual(memoryview, type(o2[0]))
@skip_if_broken_hex_binary
@testutils.skip_before_python(2, 6) @testutils.skip_before_python(2, 6)
def testAdaptBytearray(self): def testAdaptBytearray(self):
o1 = bytearray(range(256)) o1 = bytearray(range(256))
@ -258,7 +251,6 @@ class TypesBasicTests(unittest.TestCase):
else: else:
self.assertEqual(memoryview, type(o2)) self.assertEqual(memoryview, type(o2))
@skip_if_broken_hex_binary
@testutils.skip_before_python(2, 7) @testutils.skip_before_python(2, 7)
def testAdaptMemoryview(self): def testAdaptMemoryview(self):
o1 = memoryview(bytearray(range(256))) o1 = memoryview(bytearray(range(256)))