diff --git a/psycopg/typecast_binary.c b/psycopg/typecast_binary.c index 62b10829..b145b1b7 100644 --- a/psycopg/typecast_binary.c +++ b/psycopg/typecast_binary.c @@ -133,7 +133,8 @@ static char *psycopg_parse_hex( static char *psycopg_parse_escape( const char *bufin, Py_ssize_t sizein, Py_ssize_t *sizeout); -static PyObject * +/* The function is not static and not hidden as we use ctypes to test it. */ +PyObject * typecast_BINARY_cast(const char *s, Py_ssize_t l, PyObject *curs) { chunkObject *chunk = NULL; diff --git a/tests/types_basic.py b/tests/types_basic.py index 83d526ff..0eb6ac48 100755 --- a/tests/types_basic.py +++ b/tests/types_basic.py @@ -327,6 +327,92 @@ class AdaptSubclassTest(unittest.TestCase): del psycopg2.extensions.adapters[A, psycopg2.extensions.ISQLQuote] +class ByteaParserTest(unittest.TestCase): + """Unit test for our bytea format parser.""" + def setUp(self): + try: + self._cast = self._import_cast() + except Exception, e: + return self.skipTest("can't test bytea parser: %s - %s" + % (e.__class__.__name__, e)) + + def _import_cast(self): + """Use ctypes to access the C function. + + Raise any sort of error: we just support this where ctypes works as + expected. + """ + import ctypes + lib = ctypes.cdll.LoadLibrary(psycopg2._psycopg.__file__) + cast = lib.typecast_BINARY_cast + cast.argtypes = [ctypes.c_char_p, ctypes.c_size_t, ctypes.py_object] + cast.restype = ctypes.py_object + return cast + + def cast(self, buffer): + """Cast a buffer from the output format""" + l = buffer and len(buffer) or 0 + rv = self._cast(buffer, l, None) + + if rv is None: + return None + + if sys.version_info[0] < 3: + return str(rv) + else: + return rv.tobytes() + + def test_null(self): + rv = self.cast(None) + self.assertEqual(rv, None) + + def test_blank(self): + rv = self.cast(b('')) + self.assertEqual(rv, b('')) + + def test_blank_hex(self): + # Reported as problematic in ticket #48 + rv = self.cast(b('\\x')) + self.assertEqual(rv, b('')) + + def test_full_hex(self, upper=False): + buf = ''.join(("%02x" % i) for i in range(256)) + if upper: buf = buf.upper() + buf = '\\x' + buf + rv = self.cast(b(buf)) + if sys.version_info[0] < 3: + self.assertEqual(rv, ''.join(map(chr, range(256)))) + else: + self.assertEqual(rv, bytes(range(256))) + + def test_full_hex_upper(self): + return self.test_full_hex(upper=True) + + def test_full_escaped_octal(self): + buf = ''.join(("\\%03o" % i) for i in range(256)) + rv = self.cast(b(buf)) + if sys.version_info[0] < 3: + self.assertEqual(rv, ''.join(map(chr, range(256)))) + else: + self.assertEqual(rv, bytes(range(256))) + + def test_escaped_mixed(self): + import string + buf = ''.join(("\\%03o" % i) for i in range(32)) + buf += string.ascii_letters + buf += ''.join('\\' + c for c in string.ascii_letters) + buf += '\\\\' + rv = self.cast(b(buf)) + if sys.version_info[0] < 3: + tgt = ''.join(map(chr, range(32))) \ + + string.ascii_letters * 2 + '\\' + else: + tgt = bytes(range(32)) + \ + (string.ascii_letters * 2 + '\\').encode('ascii') + + self.assertEqual(rv, tgt) + + def test_suite(): return unittest.TestLoader().loadTestsFromName(__name__)