Added inet support

This commit is contained in:
Federico Di Gregorio 2008-09-24 01:27:52 +02:00
parent 56f6001d6d
commit 4c8e80038e
3 changed files with 50 additions and 2 deletions

View File

@ -1,3 +1,7 @@
2008-09-24 Federico Di Gregorio <fog@initd.org>
* lib/extras.py: added inet support and related tests.
2008-09-23 Federico Di Gregorio <fog@initd.org>
* Applied patch from Brian Sutherland that fixes NULL

View File

@ -311,6 +311,7 @@ try:
_ext.register_adapter(uuid.UUID, UUID_adapter)
return _ext.UUID
except ImportError, e:
def register_uuid(oid=None):
"""Create the UUID type and an uuid.UUID adapter.
@ -321,4 +322,38 @@ except ImportError, e:
raise e
# a type, dbtype and adapter for PostgreSQL inet type
class Inet(object):
"""Wrap a string to allow for correct SQL-quoting of inet values.
Note that this adapter does NOT check the passed value to make
sure it really is an inet-compatible address but DOES call adapt()
on it to make sure it is impossible to execute an SQL-injection
by passing an evil value to the initializer.
"""
def __init__(self, addr):
self.addr
def prepare(self, conn):
self._conn = conn
def getquoted(self):
obj = adapt(self.addr)
if hasattr(obj, 'prepare'):
obj.prepare(self._conn)
return obj.getquoted()+"::inet"
def __str__(self):
return str(self.addr)
def register_inet(oid=None):
"""Create the INET type and an Inet adapter."""
if not oid: oid = 869
_ext.INET = _ext.new_type((oid, ), "INET",
lambda data, cursor: data and Inet(data) or None)
_ext.register_type(_ext.INET)
return _ext.INET
__all__ = [ k for k in locals().keys() if not k.startswith('_') ]

View File

@ -49,6 +49,15 @@ class TypesBasicTests(unittest.TestCase):
s = self.execute("SELECT NULL::uuid AS foo")
self.failUnless(s is None)
def testINET(self):
psycopg2.extras.register_inet()
i = "192.168.1.0/24";
s = self.execute("SELECT %s AS foo", (i,))
self.failUnless(i == s)
# must survive NULL cast to inet
s = self.execute("SELECT NULL::inet AS foo")
self.failUnless(s is None)
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)