mirror of
				https://github.com/psycopg/psycopg2.git
				synced 2025-11-04 09:47:30 +03:00 
			
		
		
		
	Added tests for our own bytea parser
Because the parse function is not supposed to be exposed in Python, use ctypes to directly inspect the C function.
This commit is contained in:
		
							parent
							
								
									66c543b16c
								
							
						
					
					
						commit
						e0cd6f0f00
					
				| 
						 | 
				
			
			@ -133,7 +133,8 @@ static char *psycopg_parse_hex(
 | 
			
		|||
static char *psycopg_parse_escape(
 | 
			
		||||
        const char *bufin, Py_ssize_t sizein, Py_ssize_t *sizeout);
 | 
			
		||||
 | 
			
		||||
static PyObject *
 | 
			
		||||
/* The function is not static and not hidden as we use ctypes to test it. */
 | 
			
		||||
PyObject *
 | 
			
		||||
typecast_BINARY_cast(const char *s, Py_ssize_t l, PyObject *curs)
 | 
			
		||||
{
 | 
			
		||||
    chunkObject *chunk = NULL;
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -327,6 +327,92 @@ class AdaptSubclassTest(unittest.TestCase):
 | 
			
		|||
           del psycopg2.extensions.adapters[A, psycopg2.extensions.ISQLQuote]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class ByteaParserTest(unittest.TestCase):
 | 
			
		||||
    """Unit test for our bytea format parser."""
 | 
			
		||||
    def setUp(self):
 | 
			
		||||
        try:
 | 
			
		||||
            self._cast = self._import_cast()
 | 
			
		||||
        except Exception, e:
 | 
			
		||||
            return self.skipTest("can't test bytea parser: %s - %s"
 | 
			
		||||
                % (e.__class__.__name__, e))
 | 
			
		||||
 | 
			
		||||
    def _import_cast(self):
 | 
			
		||||
        """Use ctypes to access the C function.
 | 
			
		||||
 | 
			
		||||
        Raise any sort of error: we just support this where ctypes works as
 | 
			
		||||
        expected.
 | 
			
		||||
        """
 | 
			
		||||
        import ctypes
 | 
			
		||||
        lib = ctypes.cdll.LoadLibrary(psycopg2._psycopg.__file__)
 | 
			
		||||
        cast = lib.typecast_BINARY_cast
 | 
			
		||||
        cast.argtypes = [ctypes.c_char_p, ctypes.c_size_t, ctypes.py_object]
 | 
			
		||||
        cast.restype = ctypes.py_object
 | 
			
		||||
        return cast
 | 
			
		||||
 | 
			
		||||
    def cast(self, buffer):
 | 
			
		||||
        """Cast a buffer from the output format"""
 | 
			
		||||
        l = buffer and len(buffer) or 0
 | 
			
		||||
        rv = self._cast(buffer, l, None)
 | 
			
		||||
 | 
			
		||||
        if rv is None:
 | 
			
		||||
            return None
 | 
			
		||||
 | 
			
		||||
        if sys.version_info[0] < 3:
 | 
			
		||||
            return str(rv)
 | 
			
		||||
        else:
 | 
			
		||||
            return rv.tobytes()
 | 
			
		||||
 | 
			
		||||
    def test_null(self):
 | 
			
		||||
        rv = self.cast(None)
 | 
			
		||||
        self.assertEqual(rv, None)
 | 
			
		||||
 | 
			
		||||
    def test_blank(self):
 | 
			
		||||
        rv = self.cast(b(''))
 | 
			
		||||
        self.assertEqual(rv, b(''))
 | 
			
		||||
 | 
			
		||||
    def test_blank_hex(self):
 | 
			
		||||
        # Reported as problematic in ticket #48
 | 
			
		||||
        rv = self.cast(b('\\x'))
 | 
			
		||||
        self.assertEqual(rv, b(''))
 | 
			
		||||
 | 
			
		||||
    def test_full_hex(self, upper=False):
 | 
			
		||||
        buf = ''.join(("%02x" % i) for i in range(256))
 | 
			
		||||
        if upper: buf = buf.upper()
 | 
			
		||||
        buf = '\\x' + buf
 | 
			
		||||
        rv = self.cast(b(buf))
 | 
			
		||||
        if sys.version_info[0] < 3:
 | 
			
		||||
            self.assertEqual(rv, ''.join(map(chr, range(256))))
 | 
			
		||||
        else:
 | 
			
		||||
            self.assertEqual(rv, bytes(range(256)))
 | 
			
		||||
 | 
			
		||||
    def test_full_hex_upper(self):
 | 
			
		||||
        return self.test_full_hex(upper=True)
 | 
			
		||||
 | 
			
		||||
    def test_full_escaped_octal(self):
 | 
			
		||||
        buf = ''.join(("\\%03o" % i) for i in range(256))
 | 
			
		||||
        rv = self.cast(b(buf))
 | 
			
		||||
        if sys.version_info[0] < 3:
 | 
			
		||||
            self.assertEqual(rv, ''.join(map(chr, range(256))))
 | 
			
		||||
        else:
 | 
			
		||||
            self.assertEqual(rv, bytes(range(256)))
 | 
			
		||||
 | 
			
		||||
    def test_escaped_mixed(self):
 | 
			
		||||
        import string
 | 
			
		||||
        buf = ''.join(("\\%03o" % i) for i in range(32))
 | 
			
		||||
        buf += string.ascii_letters
 | 
			
		||||
        buf += ''.join('\\' + c for c in string.ascii_letters)
 | 
			
		||||
        buf += '\\\\'
 | 
			
		||||
        rv = self.cast(b(buf))
 | 
			
		||||
        if sys.version_info[0] < 3:
 | 
			
		||||
            tgt = ''.join(map(chr, range(32))) \
 | 
			
		||||
                + string.ascii_letters * 2 + '\\'
 | 
			
		||||
        else:
 | 
			
		||||
            tgt = bytes(range(32)) + \
 | 
			
		||||
                (string.ascii_letters * 2 + '\\').encode('ascii')
 | 
			
		||||
 | 
			
		||||
        self.assertEqual(rv, tgt)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_suite():
 | 
			
		||||
    return unittest.TestLoader().loadTestsFromName(__name__)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user