Raise an exception if the libpq fails to decode bytea in hex format

This commit is contained in:
Daniele Varrazzo 2011-02-23 14:04:27 +00:00
parent c01a7edbf4
commit 894d3f653c
6 changed files with 65 additions and 1 deletions

2
NEWS
View File

@ -13,6 +13,8 @@ What's new in psycopg 2.4
time from the backend.
- The named cursors name can be an invalid identifier.
- 'cursor.description' is provided in named tuples if available.
- Raise a clean exception instead of returning bad data when receiving bytea
in 'hex' format and the client libpq can't parse them.
- Connections and cursors are weakly referenceable.
- Added 'b' and 't' mode to large objects: write can deal with both bytes
strings and unicode; read can return either bytes strings or decoded

View File

@ -290,6 +290,9 @@ the SQL string that would be sent to the database.
`bytea_output`__ parameter to ``escape``, either in the server
configuration or in the client session using a query such as ``SET
bytea_output TO escape;`` before trying to receive binary data.
Starting from Psycopg 2.4 this condition is detected and signaled with a
`~psycopg2.InterfaceError`.
.. __: http://www.postgresql.org/docs/9.0/static/datatype-binary.html
.. __: http://www.postgresql.org/docs/9.0/static/runtime-config-client.html#GUC-BYTEA-OUTPUT

View File

@ -166,6 +166,19 @@ typecast_BINARY_cast(const char *s, Py_ssize_t l, PyObject *curs)
goto fail;
}
/* Check the escaping was successful */
if (s[0] == '\\' && s[1] == 'x' /* input encoded in hex format */
&& str[0] == 'x' /* output resulted in an 'x' */
&& s[2] != '7' && s[3] != '8') /* input wasn't really an x (0x78) */
{
PyErr_SetString(InterfaceError,
"can't receive bytea data from server >= 9.0 with the current "
"libpq client library: please update the libpq to at least 9.0 "
"or set bytea_output to 'escape' in the server config "
"or with a query");
goto fail;
}
chunk = (chunkObject *) PyObject_New(chunkObject, &chunkType);
if (chunk == NULL) goto fail;

View File

@ -83,6 +83,10 @@ class QuotingTestCase(unittest.TestCase):
else:
res = curs.fetchone()[0].tobytes()
if res[0] in (b('x'), ord(b('x'))) and self.conn.server_version >= 90000:
return self.skipTest(
"bytea broken with server >= 9.0, libpq < 9")
self.assertEqual(res, data)
self.assert_(not self.conn.notices)

View File

@ -140,6 +140,24 @@ def skip_if_no_namedtuple(f):
return skip_if_no_namedtuple_
def skip_if_broken_hex_binary(f):
"""Decorator to detect libpq < 9.0 unable to parse bytea in hex format"""
def cope_with_hex_binary_(self):
from psycopg2 import InterfaceError
try:
return f(self)
except InterfaceError, e:
if '9.0' in str(e) and self.conn.server_version >= 90000:
return self.skipTest(
# FIXME: we are only assuming the libpq is older here,
# but we don't have a reliable way to detect the libpq
# version, not pre-9 at least.
"bytea broken with server >= 9.0, libpq < 9")
else:
raise
return cope_with_hex_binary_
def skip_if_no_iobase(f):
"""Skip a test if io.TextIOBase is not available."""
def skip_if_no_iobase_(self):

View File

@ -28,7 +28,7 @@ except:
pass
import sys
import testutils
from testutils import unittest
from testutils import unittest, skip_if_broken_hex_binary
from testconfig import dsn
import psycopg2
@ -116,6 +116,7 @@ class TypesBasicTests(unittest.TestCase):
s = self.execute("SELECT %s AS foo", (float("-inf"),))
self.failUnless(str(s) == "-inf", "wrong float quoting: " + str(s))
@skip_if_broken_hex_binary
def testBinary(self):
if sys.version_info[0] < 3:
s = ''.join([chr(x) for x in range(256)])
@ -142,6 +143,7 @@ class TypesBasicTests(unittest.TestCase):
b = psycopg2.Binary(bytes([]))
self.assertEqual(str(b), "''::bytea")
@skip_if_broken_hex_binary
def testBinaryRoundTrip(self):
# test to make sure buffers returned by psycopg2 are
# understood by execute:
@ -189,6 +191,7 @@ class TypesBasicTests(unittest.TestCase):
s = self.execute("SELECT '{}'::text AS foo")
self.failUnlessEqual(s, "{}")
@skip_if_broken_hex_binary
@testutils.skip_from_python(3)
def testTypeRoundtripBuffer(self):
o1 = buffer("".join(map(chr, range(256))))
@ -199,14 +202,18 @@ class TypesBasicTests(unittest.TestCase):
o1 = buffer("")
o2 = self.execute("select %s;", (o1,))
self.assertEqual(type(o1), type(o2))
self.assertEqual(str(o1), str(o2))
@skip_if_broken_hex_binary
@testutils.skip_from_python(3)
def testTypeRoundtripBufferArray(self):
o1 = buffer("".join(map(chr, range(256))))
o1 = [o1]
o2 = self.execute("select %s;", (o1,))
self.assertEqual(type(o1[0]), type(o2[0]))
self.assertEqual(str(o1[0]), str(o2[0]))
@skip_if_broken_hex_binary
@testutils.skip_before_python(3)
def testTypeRoundtripBytes(self):
o1 = bytes(range(256))
@ -218,6 +225,7 @@ class TypesBasicTests(unittest.TestCase):
o2 = self.execute("select %s;", (o1,))
self.assertEqual(memoryview, type(o2))
@skip_if_broken_hex_binary
@testutils.skip_before_python(3)
def testTypeRoundtripBytesArray(self):
o1 = bytes(range(256))
@ -225,23 +233,32 @@ class TypesBasicTests(unittest.TestCase):
o2 = self.execute("select %s;", (o1,))
self.assertEqual(memoryview, type(o2[0]))
@skip_if_broken_hex_binary
@testutils.skip_before_python(2, 6)
def testAdaptBytearray(self):
o1 = bytearray(range(256))
o2 = self.execute("select %s;", (o1,))
if sys.version_info[0] < 3:
self.assertEqual(buffer, type(o2))
else:
self.assertEqual(memoryview, type(o2))
self.assertEqual(len(o1), len(o2))
for c1, c2 in zip(o1, o2):
self.assertEqual(c1, ord(c2))
# Test with an empty buffer
o1 = bytearray([])
o2 = self.execute("select %s;", (o1,))
self.assertEqual(len(o2), 0)
if sys.version_info[0] < 3:
self.assertEqual(buffer, type(o2))
else:
self.assertEqual(memoryview, type(o2))
@skip_if_broken_hex_binary
@testutils.skip_before_python(2, 7)
def testAdaptMemoryview(self):
o1 = memoryview(bytearray(range(256)))
@ -259,6 +276,13 @@ class TypesBasicTests(unittest.TestCase):
else:
self.assertEqual(memoryview, type(o2))
def testByteaHexCheckFalsePositive(self):
# the check \x -> x to detect bad bytea decode
# may be fooled if the first char is really an 'x'
o1 = psycopg2.Binary(b('x'))
o2 = self.execute("SELECT %s::bytea AS foo", (o1,))
self.assertEqual(b('x'), o2[0])
class AdaptSubclassTest(unittest.TestCase):
def test_adapt_subtype(self):