Fixed bug in RealDictCursor when prefetching

This commit is contained in:
Federico Di Gregorio 2009-05-09 14:44:59 +02:00
parent 3935c019fe
commit e1fae0fcac
4 changed files with 71 additions and 24 deletions

View File

@ -1,5 +1,8 @@
2009-05-10 Federico Di Gregorio <fog@initd.org>
* lib/extras.py: fixed crash in fetchone() when prefetching using
a RealDictCursor.
* psycopg/cursor_ext.c: now raise correct exception when fetching
using a custom row factory results in an error.

View File

@ -45,32 +45,47 @@ class DictCursorBase(_cursor):
"DictCursorBase can't be instantiated without a row factory.")
_cursor.__init__(self, *args, **kwargs)
self._query_executed = 0
self._prefetch = 0
self.row_factory = row_factory
def fetchone(self):
res = _cursor.fetchone(self)
if self._prefetch:
res = _cursor.fetchone(self)
if self._query_executed:
self._build_index()
if not self._prefetch:
res = _cursor.fetchone(self)
return res
def fetchmany(self, size=None):
res = _cursor.fetchmany(self, size)
if self._prefetch:
res = _cursor.fetchmany(self, size)
if self._query_executed:
self._build_index()
if not self._prefetch:
res = _cursor.fetchmany(self, size)
return res
def fetchall(self):
res = _cursor.fetchall(self)
if self._prefetch:
res = _cursor.fetchall(self)
if self._query_executed:
self._build_index()
if not self._prefetch:
res = _cursor.fetchall(self)
return res
def next(self):
res = _cursor.fetchone(self)
if res is None:
raise StopIteration()
if self._prefetch:
res = _cursor.fetchone(self)
if res is None:
raise StopIteration()
if self._query_executed:
self._build_index()
if not self._prefetch:
res = _cursor.fetchone(self)
if res is None:
raise StopIteration()
return res
class DictConnection(_connection):
@ -87,22 +102,23 @@ class DictCursor(DictCursorBase):
def __init__(self, *args, **kwargs):
kwargs['row_factory'] = DictRow
DictCursorBase.__init__(self, *args, **kwargs)
self._prefetch = 1
def execute(self, query, vars=None, async=0):
self.index = {}
self._query_executed = 1
return _cursor.execute(self, query, vars, async)
def callproc(self, procname, vars=None):
self.index = {}
self._query_executed = 1
return _cursor.callproc(self, procname, vars)
return _cursor.callproc(self, procname, vars)
def _build_index(self):
if self._query_executed == 1 and self.description:
for i in range(len(self.description)):
self.index[self.description[i][0]] = i
self._query_executed = 0
self._query_executed = 0
class DictRow(list):
"""A row object that allow by-colun-name access to data."""
@ -123,7 +139,7 @@ class DictRow(list):
for n, v in self._index.items():
res.append((n, list.__getitem__(self, v)))
return res
def keys(self):
return self._index.keys()
@ -175,33 +191,34 @@ class RealDictCursor(DictCursorBase):
def __init__(self, *args, **kwargs):
kwargs['row_factory'] = RealDictRow
DictCursorBase.__init__(self, *args, **kwargs)
self._prefetch = 0
def execute(self, query, vars=None, async=0):
self.column_mapping = []
self._query_executed = 1
return _cursor.execute(self, query, vars, async)
def callproc(self, procname, vars=None):
self.column_mapping = []
self._query_executed = 1
return _cursor.callproc(self, procname, vars)
return _cursor.callproc(self, procname, vars)
def _build_index(self):
if self._query_executed == 1 and self.description:
for item in self.description:
self.column_mapping.append(item[0])
self._query_executed = 0
for i in range(len(self.description)):
self.column_mapping.append(self.description[i][0])
self._query_executed = 0
class RealDictRow(dict):
__slots__ = ('_column_mapping',)
__slots__ = ('_column_mapping')
def __init__(self, cursor):
dict.__init__(self)
self._column_mapping = cursor.column_mapping
def __setitem__(self, name, value):
if type(name) == type(0):
if type(name) == int:
name = self._column_mapping[name]
return dict.__setitem__(self, name, value)
@ -211,7 +228,7 @@ class LoggingConnection(_connection):
def initialize(self, logobj):
"""Initialize the connection to log to `logobj`.
The `logobj` parameter can be an open file object or a Logger instance
from the standard logging module.
"""
@ -223,7 +240,7 @@ class LoggingConnection(_connection):
def filter(self, msg, curs):
"""Filter the query before logging it.
This is the method to overwrite to filter unwanted queries out of the
log or to add some extra data to the output. The default implementation
just does nothing.
@ -265,7 +282,7 @@ class LoggingCursor(_cursor):
finally:
self.connection.log(self.query, self)
class MinTimeLoggingConnection(LoggingConnection):
"""A connection that logs queries based on execution time.
@ -279,7 +296,7 @@ class MinTimeLoggingConnection(LoggingConnection):
def initialize(self, logobj, mintime=0):
LoggingConnection.initialize(self, logobj)
self._mintime = mintime
def filter(self, msg, curs):
t = (time.time() - curs.timestamp) * 1000
if t > self._mintime:

View File

@ -222,7 +222,7 @@ class psycopg_build_ext(build_ext):
define_macros.append(("PG_VERSION_HEX", "0x%02X%02X%02X" %
(int(pgmajor), int(pgminor), int(pgpatch))))
except (Warning, w):
except Warning, w:
if self.pg_config == self.DEFAULT_PG_CONFIG:
sys.stderr.write("Warning: %s" % str(w))
else:

View File

@ -47,7 +47,22 @@ class ExtrasDictCursorTests(unittest.TestCase):
for row in curs:
return row
self._testWithPlainCursor(getter)
def testDictCursorWithPlainCursorRealFetchOne(self):
self._testWithPlainCursorReal(lambda curs: curs.fetchone())
def testDictCursorWithPlainCursorRealFetchMany(self):
self._testWithPlainCursorReal(lambda curs: curs.fetchmany(100)[0])
def testDictCursorWithPlainCursorRealFetchAll(self):
self._testWithPlainCursorReal(lambda curs: curs.fetchall()[0])
def testDictCursorWithPlainCursorRealIter(self):
def getter(curs):
for row in curs:
return row
self._testWithPlainCursorReal(getter)
def testDictCursorWithNamedCursorFetchOne(self):
self._testWithNamedCursor(lambda curs: curs.fetchone())
@ -77,6 +92,18 @@ class ExtrasDictCursorTests(unittest.TestCase):
self.failUnless(row['foo'] == 'bar')
self.failUnless(row[0] == 'bar')
def _testWithPlainCursorReal(self, getter):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
curs.execute("SELECT * FROM ExtrasDictCursorTests")
row = getter(curs)
self.failUnless(row['foo'] == 'bar')
def _testWithNamedCursorReal(self, getter):
curs = self.conn.cursor('aname', cursor_factory=psycopg2.extras.RealDictCursor)
curs.execute("SELECT * FROM ExtrasDictCursorTests")
row = getter(curs)
self.failUnless(row['foo'] == 'bar')
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)