mirror of
				https://github.com/psycopg/psycopg2.git
				synced 2025-11-04 01:37:31 +03:00 
			
		
		
		
	Added tests for our own bytea parser
Because the parse function is not supposed to be exposed in Python, use ctypes to directly inspect the C function.
This commit is contained in:
		
							parent
							
								
									66c543b16c
								
							
						
					
					
						commit
						e0cd6f0f00
					
				| 
						 | 
					@ -133,7 +133,8 @@ static char *psycopg_parse_hex(
 | 
				
			||||||
static char *psycopg_parse_escape(
 | 
					static char *psycopg_parse_escape(
 | 
				
			||||||
        const char *bufin, Py_ssize_t sizein, Py_ssize_t *sizeout);
 | 
					        const char *bufin, Py_ssize_t sizein, Py_ssize_t *sizeout);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
static PyObject *
 | 
					/* The function is not static and not hidden as we use ctypes to test it. */
 | 
				
			||||||
 | 
					PyObject *
 | 
				
			||||||
typecast_BINARY_cast(const char *s, Py_ssize_t l, PyObject *curs)
 | 
					typecast_BINARY_cast(const char *s, Py_ssize_t l, PyObject *curs)
 | 
				
			||||||
{
 | 
					{
 | 
				
			||||||
    chunkObject *chunk = NULL;
 | 
					    chunkObject *chunk = NULL;
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -327,6 +327,92 @@ class AdaptSubclassTest(unittest.TestCase):
 | 
				
			||||||
           del psycopg2.extensions.adapters[A, psycopg2.extensions.ISQLQuote]
 | 
					           del psycopg2.extensions.adapters[A, psycopg2.extensions.ISQLQuote]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class ByteaParserTest(unittest.TestCase):
 | 
				
			||||||
 | 
					    """Unit test for our bytea format parser."""
 | 
				
			||||||
 | 
					    def setUp(self):
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            self._cast = self._import_cast()
 | 
				
			||||||
 | 
					        except Exception, e:
 | 
				
			||||||
 | 
					            return self.skipTest("can't test bytea parser: %s - %s"
 | 
				
			||||||
 | 
					                % (e.__class__.__name__, e))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _import_cast(self):
 | 
				
			||||||
 | 
					        """Use ctypes to access the C function.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Raise any sort of error: we just support this where ctypes works as
 | 
				
			||||||
 | 
					        expected.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        import ctypes
 | 
				
			||||||
 | 
					        lib = ctypes.cdll.LoadLibrary(psycopg2._psycopg.__file__)
 | 
				
			||||||
 | 
					        cast = lib.typecast_BINARY_cast
 | 
				
			||||||
 | 
					        cast.argtypes = [ctypes.c_char_p, ctypes.c_size_t, ctypes.py_object]
 | 
				
			||||||
 | 
					        cast.restype = ctypes.py_object
 | 
				
			||||||
 | 
					        return cast
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def cast(self, buffer):
 | 
				
			||||||
 | 
					        """Cast a buffer from the output format"""
 | 
				
			||||||
 | 
					        l = buffer and len(buffer) or 0
 | 
				
			||||||
 | 
					        rv = self._cast(buffer, l, None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if rv is None:
 | 
				
			||||||
 | 
					            return None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if sys.version_info[0] < 3:
 | 
				
			||||||
 | 
					            return str(rv)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return rv.tobytes()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_null(self):
 | 
				
			||||||
 | 
					        rv = self.cast(None)
 | 
				
			||||||
 | 
					        self.assertEqual(rv, None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_blank(self):
 | 
				
			||||||
 | 
					        rv = self.cast(b(''))
 | 
				
			||||||
 | 
					        self.assertEqual(rv, b(''))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_blank_hex(self):
 | 
				
			||||||
 | 
					        # Reported as problematic in ticket #48
 | 
				
			||||||
 | 
					        rv = self.cast(b('\\x'))
 | 
				
			||||||
 | 
					        self.assertEqual(rv, b(''))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_full_hex(self, upper=False):
 | 
				
			||||||
 | 
					        buf = ''.join(("%02x" % i) for i in range(256))
 | 
				
			||||||
 | 
					        if upper: buf = buf.upper()
 | 
				
			||||||
 | 
					        buf = '\\x' + buf
 | 
				
			||||||
 | 
					        rv = self.cast(b(buf))
 | 
				
			||||||
 | 
					        if sys.version_info[0] < 3:
 | 
				
			||||||
 | 
					            self.assertEqual(rv, ''.join(map(chr, range(256))))
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.assertEqual(rv, bytes(range(256)))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_full_hex_upper(self):
 | 
				
			||||||
 | 
					        return self.test_full_hex(upper=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_full_escaped_octal(self):
 | 
				
			||||||
 | 
					        buf = ''.join(("\\%03o" % i) for i in range(256))
 | 
				
			||||||
 | 
					        rv = self.cast(b(buf))
 | 
				
			||||||
 | 
					        if sys.version_info[0] < 3:
 | 
				
			||||||
 | 
					            self.assertEqual(rv, ''.join(map(chr, range(256))))
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self.assertEqual(rv, bytes(range(256)))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_escaped_mixed(self):
 | 
				
			||||||
 | 
					        import string
 | 
				
			||||||
 | 
					        buf = ''.join(("\\%03o" % i) for i in range(32))
 | 
				
			||||||
 | 
					        buf += string.ascii_letters
 | 
				
			||||||
 | 
					        buf += ''.join('\\' + c for c in string.ascii_letters)
 | 
				
			||||||
 | 
					        buf += '\\\\'
 | 
				
			||||||
 | 
					        rv = self.cast(b(buf))
 | 
				
			||||||
 | 
					        if sys.version_info[0] < 3:
 | 
				
			||||||
 | 
					            tgt = ''.join(map(chr, range(32))) \
 | 
				
			||||||
 | 
					                + string.ascii_letters * 2 + '\\'
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            tgt = bytes(range(32)) + \
 | 
				
			||||||
 | 
					                (string.ascii_letters * 2 + '\\').encode('ascii')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self.assertEqual(rv, tgt)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def test_suite():
 | 
					def test_suite():
 | 
				
			||||||
    return unittest.TestLoader().loadTestsFromName(__name__)
 | 
					    return unittest.TestLoader().loadTestsFromName(__name__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user