Fixed error in register_type()

This commit is contained in:
Federico Di Gregorio 2009-03-02 10:59:52 +01:00
parent a3ce636be0
commit 5b04203c9f
5 changed files with 62 additions and 14 deletions

View File

@ -1,5 +1,9 @@
2009-03-02 Federico Di Gregorio <fog@initd.org>
* Applied patch from Menno Smits to avoid problems with DictCursor
when the query is executed by a named cursor. Also added Menno's
tests and uniformed other DictCursor tests.
* psycopg/psycopgmodule.c: fixed unwanted exception when passing an
explicit None to register_type().

View File

@ -19,12 +19,14 @@ and classes untill a better place in the distribution is found.
import os
import time
import re as regex
try:
import logging
except:
logging = None
from psycopg2 import DATETIME, DataError
from psycopg2 import extensions as _ext
from psycopg2.extensions import cursor as _cursor
from psycopg2.extensions import connection as _connection
@ -46,26 +48,29 @@ class DictCursorBase(_cursor):
self.row_factory = row_factory
def fetchone(self):
res = _cursor.fetchone(self)
if self._query_executed:
self._build_index()
return _cursor.fetchone(self)
return res
def fetchmany(self, size=None):
res = _cursor.fetchmany(self, size)
if self._query_executed:
self._build_index()
return _cursor.fetchmany(self, size)
return res
def fetchall(self):
res = _cursor.fetchall(self)
if self._query_executed:
self._build_index()
return _cursor.fetchall(self)
return res
def next(self):
if self._query_executed:
self._build_index()
res = _cursor.fetchone(self)
if res is None:
raise StopIteration()
if self._query_executed:
self._build_index()
return res
class DictConnection(_connection):
@ -74,7 +79,7 @@ class DictConnection(_connection):
if name is None:
return _connection.cursor(self, cursor_factory=DictCursor)
else:
return _connection.cursor(self, name, cursor_factory=DictCursor)
return _connection.cursor(self, name, cursor_factory=DictCursor)
class DictCursor(DictCursorBase):
"""A cursor that keeps a list of column name -> index mappings."""
@ -302,12 +307,12 @@ try:
__str__ = getquoted
def register_uuid(oid=None):
def register_uuid(oid=None, conn_or_curs=None):
"""Create the UUID type and an uuid.UUID adapter."""
if not oid: oid = 2950
_ext.UUID = _ext.new_type((oid, ), "UUID",
lambda data, cursor: data and uuid.UUID(data) or None)
_ext.register_type(_ext.UUID)
_ext.register_type(_ext.UUID, conn_or_curs)
_ext.register_adapter(uuid.UUID, UUID_adapter)
return _ext.UUID
@ -346,12 +351,12 @@ class Inet(object):
def __str__(self):
return str(self.addr)
def register_inet(oid=None):
def register_inet(oid=None, conn_or_curs=None):
"""Create the INET type and an Inet adapter."""
if not oid: oid = 869
_ext.INET = _ext.new_type((oid, ), "INET",
lambda data, cursor: data and Inet(data) or None)
_ext.register_type(_ext.INET)
_ext.register_type(_ext.INET, conn_or_curs)
return _ext.INET

View File

@ -24,6 +24,7 @@ import test_quote
import test_connection
import test_transaction
import types_basic
import types_extras
import test_lobject
def test_suite():
@ -36,6 +37,7 @@ def test_suite():
suite.addTest(test_connection.test_suite())
suite.addTest(test_transaction.test_suite())
suite.addTest(types_basic.test_suite())
suite.addTest(types_extras.test_suite())
suite.addTest(test_lobject.test_suite())
return suite

43
tests/extras_dictcursor.py Executable file → Normal file
View File

@ -27,18 +27,55 @@ class ExtrasDictCursorTests(unittest.TestCase):
self.conn = psycopg2.connect(tests.dsn)
curs = self.conn.cursor()
curs.execute("CREATE TEMPORARY TABLE ExtrasDictCursorTests (foo text)")
curs.execute("INSERT INTO ExtrasDictCursorTests VALUES ('bar')")
self.conn.commit()
def tearDown(self):
self.conn.close()
def testDictCursor(self):
def testDictCursorWithPlainCursorFetchOne(self):
self._testWithPlainCursor(lambda curs: curs.fetchone())
def testDictCursorWithPlainCursorFetchMany(self):
self._testWithPlainCursor(lambda curs: curs.fetchmany(100)[0])
def testDictCursorWithPlainCursorFetchAll(self):
self._testWithPlainCursor(lambda curs: curs.fetchall()[0])
def testDictCursorWithPlainCursorIter(self):
def getter(curs):
for row in curs:
return row
self._testWithPlainCursor(getter)
def testDictCursorWithNamedCursorFetchOne(self):
self._testWithNamedCursor(lambda curs: curs.fetchone())
def testDictCursorWithNamedCursorFetchMany(self):
self._testWithNamedCursor(lambda curs: curs.fetchmany(100)[0])
def testDictCursorWithNamedCursorFetchAll(self):
self._testWithNamedCursor(lambda curs: curs.fetchall()[0])
def testDictCursorWithNamedCursorIter(self):
def getter(curs):
for row in curs:
return row
self._testWithNamedCursor(getter)
def _testWithPlainCursor(self, getter):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
curs.execute("INSERT INTO ExtrasDictCursorTests VALUES ('bar')")
curs.execute("SELECT * FROM ExtrasDictCursorTests")
row = curs.fetchone()
row = getter(curs)
self.failUnless(row['foo'] == 'bar')
self.failUnless(row[0] == 'bar')
def _testWithNamedCursor(self, getter):
curs = self.conn.cursor('aname', cursor_factory=psycopg2.extras.DictCursor)
curs.execute("SELECT * FROM ExtrasDictCursorTests")
row = getter(curs)
self.failUnless(row['foo'] == 'bar')
self.failUnless(row[0] == 'bar')
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)

View File

@ -25,7 +25,7 @@ import psycopg2.extras
import tests
class TypesBasicTests(unittest.TestCase):
class TypesExtrasTests(unittest.TestCase):
"""Test that all type conversions are working."""
def setUp(self):