diff --git a/ChangeLog b/ChangeLog index 309e1fc5..e1693f84 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,7 @@ +2008-09-24 Federico Di Gregorio + + * lib/extras.py: added inet support and related tests. + 2008-09-23 Federico Di Gregorio * Applied patch from Brian Sutherland that fixes NULL diff --git a/lib/extras.py b/lib/extras.py index 67a67f20..4ca825da 100644 --- a/lib/extras.py +++ b/lib/extras.py @@ -306,11 +306,12 @@ try: """Create the UUID type and an uuid.UUID adapter.""" if not oid: oid = 2950 _ext.UUID = _ext.new_type((oid, ), "UUID", - lambda data, cursor: data and uuid.UUID(data) or None) + lambda data, cursor: data and uuid.UUID(data) or None) _ext.register_type(_ext.UUID) _ext.register_adapter(uuid.UUID, UUID_adapter) return _ext.UUID - + + except ImportError, e: def register_uuid(oid=None): """Create the UUID type and an uuid.UUID adapter. @@ -321,4 +322,38 @@ except ImportError, e: raise e +# a type, dbtype and adapter for PostgreSQL inet type + +class Inet(object): + """Wrap a string to allow for correct SQL-quoting of inet values. + + Note that this adapter does NOT check the passed value to make + sure it really is an inet-compatible address but DOES call adapt() + on it to make sure it is impossible to execute an SQL-injection + by passing an evil value to the initializer. + """ + def __init__(self, addr): + self.addr + + def prepare(self, conn): + self._conn = conn + + def getquoted(self): + obj = adapt(self.addr) + if hasattr(obj, 'prepare'): + obj.prepare(self._conn) + return obj.getquoted()+"::inet" + + def __str__(self): + return str(self.addr) + +def register_inet(oid=None): + """Create the INET type and an Inet adapter.""" + if not oid: oid = 869 + _ext.INET = _ext.new_type((oid, ), "INET", + lambda data, cursor: data and Inet(data) or None) + _ext.register_type(_ext.INET) + return _ext.INET + + __all__ = [ k for k in locals().keys() if not k.startswith('_') ] diff --git a/tests/types_extras.py b/tests/types_extras.py index 421a7294..1c0d0d7e 100644 --- a/tests/types_extras.py +++ b/tests/types_extras.py @@ -49,6 +49,15 @@ class TypesBasicTests(unittest.TestCase): s = self.execute("SELECT NULL::uuid AS foo") self.failUnless(s is None) + def testINET(self): + psycopg2.extras.register_inet() + i = "192.168.1.0/24"; + s = self.execute("SELECT %s AS foo", (i,)) + self.failUnless(i == s) + # must survive NULL cast to inet + s = self.execute("SELECT NULL::inet AS foo") + self.failUnless(s is None) + def test_suite(): return unittest.TestLoader().loadTestsFromName(__name__)