Fixed bug in RealDictCursor when prefetching

This commit is contained in:
Federico Di Gregorio 2009-05-09 14:44:59 +02:00
parent 3935c019fe
commit e1fae0fcac
4 changed files with 71 additions and 24 deletions

View File

@ -1,5 +1,8 @@
2009-05-10 Federico Di Gregorio <fog@initd.org> 2009-05-10 Federico Di Gregorio <fog@initd.org>
* lib/extras.py: fixed crash in fetchone() when prefetching using
a RealDictCursor.
* psycopg/cursor_ext.c: now raise correct exception when fetching * psycopg/cursor_ext.c: now raise correct exception when fetching
using a custom row factory results in an error. using a custom row factory results in an error.

View File

@ -45,32 +45,47 @@ class DictCursorBase(_cursor):
"DictCursorBase can't be instantiated without a row factory.") "DictCursorBase can't be instantiated without a row factory.")
_cursor.__init__(self, *args, **kwargs) _cursor.__init__(self, *args, **kwargs)
self._query_executed = 0 self._query_executed = 0
self._prefetch = 0
self.row_factory = row_factory self.row_factory = row_factory
def fetchone(self): def fetchone(self):
res = _cursor.fetchone(self) if self._prefetch:
res = _cursor.fetchone(self)
if self._query_executed: if self._query_executed:
self._build_index() self._build_index()
if not self._prefetch:
res = _cursor.fetchone(self)
return res return res
def fetchmany(self, size=None): def fetchmany(self, size=None):
res = _cursor.fetchmany(self, size) if self._prefetch:
res = _cursor.fetchmany(self, size)
if self._query_executed: if self._query_executed:
self._build_index() self._build_index()
if not self._prefetch:
res = _cursor.fetchmany(self, size)
return res return res
def fetchall(self): def fetchall(self):
res = _cursor.fetchall(self) if self._prefetch:
res = _cursor.fetchall(self)
if self._query_executed: if self._query_executed:
self._build_index() self._build_index()
if not self._prefetch:
res = _cursor.fetchall(self)
return res return res
def next(self): def next(self):
res = _cursor.fetchone(self) if self._prefetch:
if res is None: res = _cursor.fetchone(self)
raise StopIteration() if res is None:
raise StopIteration()
if self._query_executed: if self._query_executed:
self._build_index() self._build_index()
if not self._prefetch:
res = _cursor.fetchone(self)
if res is None:
raise StopIteration()
return res return res
class DictConnection(_connection): class DictConnection(_connection):
@ -87,6 +102,7 @@ class DictCursor(DictCursorBase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
kwargs['row_factory'] = DictRow kwargs['row_factory'] = DictRow
DictCursorBase.__init__(self, *args, **kwargs) DictCursorBase.__init__(self, *args, **kwargs)
self._prefetch = 1
def execute(self, query, vars=None, async=0): def execute(self, query, vars=None, async=0):
self.index = {} self.index = {}
@ -175,6 +191,7 @@ class RealDictCursor(DictCursorBase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
kwargs['row_factory'] = RealDictRow kwargs['row_factory'] = RealDictRow
DictCursorBase.__init__(self, *args, **kwargs) DictCursorBase.__init__(self, *args, **kwargs)
self._prefetch = 0
def execute(self, query, vars=None, async=0): def execute(self, query, vars=None, async=0):
self.column_mapping = [] self.column_mapping = []
@ -188,20 +205,20 @@ class RealDictCursor(DictCursorBase):
def _build_index(self): def _build_index(self):
if self._query_executed == 1 and self.description: if self._query_executed == 1 and self.description:
for item in self.description: for i in range(len(self.description)):
self.column_mapping.append(item[0]) self.column_mapping.append(self.description[i][0])
self._query_executed = 0 self._query_executed = 0
class RealDictRow(dict): class RealDictRow(dict):
__slots__ = ('_column_mapping',) __slots__ = ('_column_mapping')
def __init__(self, cursor): def __init__(self, cursor):
dict.__init__(self) dict.__init__(self)
self._column_mapping = cursor.column_mapping self._column_mapping = cursor.column_mapping
def __setitem__(self, name, value): def __setitem__(self, name, value):
if type(name) == type(0): if type(name) == int:
name = self._column_mapping[name] name = self._column_mapping[name]
return dict.__setitem__(self, name, value) return dict.__setitem__(self, name, value)

View File

@ -222,7 +222,7 @@ class psycopg_build_ext(build_ext):
define_macros.append(("PG_VERSION_HEX", "0x%02X%02X%02X" % define_macros.append(("PG_VERSION_HEX", "0x%02X%02X%02X" %
(int(pgmajor), int(pgminor), int(pgpatch)))) (int(pgmajor), int(pgminor), int(pgpatch))))
except (Warning, w): except Warning, w:
if self.pg_config == self.DEFAULT_PG_CONFIG: if self.pg_config == self.DEFAULT_PG_CONFIG:
sys.stderr.write("Warning: %s" % str(w)) sys.stderr.write("Warning: %s" % str(w))
else: else:

View File

@ -48,6 +48,21 @@ class ExtrasDictCursorTests(unittest.TestCase):
return row return row
self._testWithPlainCursor(getter) self._testWithPlainCursor(getter)
def testDictCursorWithPlainCursorRealFetchOne(self):
self._testWithPlainCursorReal(lambda curs: curs.fetchone())
def testDictCursorWithPlainCursorRealFetchMany(self):
self._testWithPlainCursorReal(lambda curs: curs.fetchmany(100)[0])
def testDictCursorWithPlainCursorRealFetchAll(self):
self._testWithPlainCursorReal(lambda curs: curs.fetchall()[0])
def testDictCursorWithPlainCursorRealIter(self):
def getter(curs):
for row in curs:
return row
self._testWithPlainCursorReal(getter)
def testDictCursorWithNamedCursorFetchOne(self): def testDictCursorWithNamedCursorFetchOne(self):
self._testWithNamedCursor(lambda curs: curs.fetchone()) self._testWithNamedCursor(lambda curs: curs.fetchone())
@ -77,6 +92,18 @@ class ExtrasDictCursorTests(unittest.TestCase):
self.failUnless(row['foo'] == 'bar') self.failUnless(row['foo'] == 'bar')
self.failUnless(row[0] == 'bar') self.failUnless(row[0] == 'bar')
def _testWithPlainCursorReal(self, getter):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
curs.execute("SELECT * FROM ExtrasDictCursorTests")
row = getter(curs)
self.failUnless(row['foo'] == 'bar')
def _testWithNamedCursorReal(self, getter):
curs = self.conn.cursor('aname', cursor_factory=psycopg2.extras.RealDictCursor)
curs.execute("SELECT * FROM ExtrasDictCursorTests")
row = getter(curs)
self.failUnless(row['foo'] == 'bar')
def test_suite(): def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__) return unittest.TestLoader().loadTestsFromName(__name__)