Fixed bug in RealDictCursor when prefetching

This commit is contained in:
Federico Di Gregorio 2009-05-09 14:44:59 +02:00
parent 3935c019fe
commit e1fae0fcac
4 changed files with 71 additions and 24 deletions

View File

@ -1,5 +1,8 @@
2009-05-10 Federico Di Gregorio <fog@initd.org> 2009-05-10 Federico Di Gregorio <fog@initd.org>
* lib/extras.py: fixed crash in fetchone() when prefetching using
a RealDictCursor.
* psycopg/cursor_ext.c: now raise correct exception when fetching * psycopg/cursor_ext.c: now raise correct exception when fetching
using a custom row factory results in an error. using a custom row factory results in an error.

View File

@ -45,32 +45,47 @@ class DictCursorBase(_cursor):
"DictCursorBase can't be instantiated without a row factory.") "DictCursorBase can't be instantiated without a row factory.")
_cursor.__init__(self, *args, **kwargs) _cursor.__init__(self, *args, **kwargs)
self._query_executed = 0 self._query_executed = 0
self._prefetch = 0
self.row_factory = row_factory self.row_factory = row_factory
def fetchone(self): def fetchone(self):
res = _cursor.fetchone(self) if self._prefetch:
res = _cursor.fetchone(self)
if self._query_executed: if self._query_executed:
self._build_index() self._build_index()
if not self._prefetch:
res = _cursor.fetchone(self)
return res return res
def fetchmany(self, size=None): def fetchmany(self, size=None):
res = _cursor.fetchmany(self, size) if self._prefetch:
res = _cursor.fetchmany(self, size)
if self._query_executed: if self._query_executed:
self._build_index() self._build_index()
if not self._prefetch:
res = _cursor.fetchmany(self, size)
return res return res
def fetchall(self): def fetchall(self):
res = _cursor.fetchall(self) if self._prefetch:
res = _cursor.fetchall(self)
if self._query_executed: if self._query_executed:
self._build_index() self._build_index()
if not self._prefetch:
res = _cursor.fetchall(self)
return res return res
def next(self): def next(self):
res = _cursor.fetchone(self) if self._prefetch:
if res is None: res = _cursor.fetchone(self)
raise StopIteration() if res is None:
raise StopIteration()
if self._query_executed: if self._query_executed:
self._build_index() self._build_index()
if not self._prefetch:
res = _cursor.fetchone(self)
if res is None:
raise StopIteration()
return res return res
class DictConnection(_connection): class DictConnection(_connection):
@ -87,22 +102,23 @@ class DictCursor(DictCursorBase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
kwargs['row_factory'] = DictRow kwargs['row_factory'] = DictRow
DictCursorBase.__init__(self, *args, **kwargs) DictCursorBase.__init__(self, *args, **kwargs)
self._prefetch = 1
def execute(self, query, vars=None, async=0): def execute(self, query, vars=None, async=0):
self.index = {} self.index = {}
self._query_executed = 1 self._query_executed = 1
return _cursor.execute(self, query, vars, async) return _cursor.execute(self, query, vars, async)
def callproc(self, procname, vars=None): def callproc(self, procname, vars=None):
self.index = {} self.index = {}
self._query_executed = 1 self._query_executed = 1
return _cursor.callproc(self, procname, vars) return _cursor.callproc(self, procname, vars)
def _build_index(self): def _build_index(self):
if self._query_executed == 1 and self.description: if self._query_executed == 1 and self.description:
for i in range(len(self.description)): for i in range(len(self.description)):
self.index[self.description[i][0]] = i self.index[self.description[i][0]] = i
self._query_executed = 0 self._query_executed = 0
class DictRow(list): class DictRow(list):
"""A row object that allow by-colun-name access to data.""" """A row object that allow by-colun-name access to data."""
@ -123,7 +139,7 @@ class DictRow(list):
for n, v in self._index.items(): for n, v in self._index.items():
res.append((n, list.__getitem__(self, v))) res.append((n, list.__getitem__(self, v)))
return res return res
def keys(self): def keys(self):
return self._index.keys() return self._index.keys()
@ -175,33 +191,34 @@ class RealDictCursor(DictCursorBase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
kwargs['row_factory'] = RealDictRow kwargs['row_factory'] = RealDictRow
DictCursorBase.__init__(self, *args, **kwargs) DictCursorBase.__init__(self, *args, **kwargs)
self._prefetch = 0
def execute(self, query, vars=None, async=0): def execute(self, query, vars=None, async=0):
self.column_mapping = [] self.column_mapping = []
self._query_executed = 1 self._query_executed = 1
return _cursor.execute(self, query, vars, async) return _cursor.execute(self, query, vars, async)
def callproc(self, procname, vars=None): def callproc(self, procname, vars=None):
self.column_mapping = [] self.column_mapping = []
self._query_executed = 1 self._query_executed = 1
return _cursor.callproc(self, procname, vars) return _cursor.callproc(self, procname, vars)
def _build_index(self): def _build_index(self):
if self._query_executed == 1 and self.description: if self._query_executed == 1 and self.description:
for item in self.description: for i in range(len(self.description)):
self.column_mapping.append(item[0]) self.column_mapping.append(self.description[i][0])
self._query_executed = 0 self._query_executed = 0
class RealDictRow(dict): class RealDictRow(dict):
__slots__ = ('_column_mapping',) __slots__ = ('_column_mapping')
def __init__(self, cursor): def __init__(self, cursor):
dict.__init__(self) dict.__init__(self)
self._column_mapping = cursor.column_mapping self._column_mapping = cursor.column_mapping
def __setitem__(self, name, value): def __setitem__(self, name, value):
if type(name) == type(0): if type(name) == int:
name = self._column_mapping[name] name = self._column_mapping[name]
return dict.__setitem__(self, name, value) return dict.__setitem__(self, name, value)
@ -211,7 +228,7 @@ class LoggingConnection(_connection):
def initialize(self, logobj): def initialize(self, logobj):
"""Initialize the connection to log to `logobj`. """Initialize the connection to log to `logobj`.
The `logobj` parameter can be an open file object or a Logger instance The `logobj` parameter can be an open file object or a Logger instance
from the standard logging module. from the standard logging module.
""" """
@ -223,7 +240,7 @@ class LoggingConnection(_connection):
def filter(self, msg, curs): def filter(self, msg, curs):
"""Filter the query before logging it. """Filter the query before logging it.
This is the method to overwrite to filter unwanted queries out of the This is the method to overwrite to filter unwanted queries out of the
log or to add some extra data to the output. The default implementation log or to add some extra data to the output. The default implementation
just does nothing. just does nothing.
@ -265,7 +282,7 @@ class LoggingCursor(_cursor):
finally: finally:
self.connection.log(self.query, self) self.connection.log(self.query, self)
class MinTimeLoggingConnection(LoggingConnection): class MinTimeLoggingConnection(LoggingConnection):
"""A connection that logs queries based on execution time. """A connection that logs queries based on execution time.
@ -279,7 +296,7 @@ class MinTimeLoggingConnection(LoggingConnection):
def initialize(self, logobj, mintime=0): def initialize(self, logobj, mintime=0):
LoggingConnection.initialize(self, logobj) LoggingConnection.initialize(self, logobj)
self._mintime = mintime self._mintime = mintime
def filter(self, msg, curs): def filter(self, msg, curs):
t = (time.time() - curs.timestamp) * 1000 t = (time.time() - curs.timestamp) * 1000
if t > self._mintime: if t > self._mintime:

View File

@ -222,7 +222,7 @@ class psycopg_build_ext(build_ext):
define_macros.append(("PG_VERSION_HEX", "0x%02X%02X%02X" % define_macros.append(("PG_VERSION_HEX", "0x%02X%02X%02X" %
(int(pgmajor), int(pgminor), int(pgpatch)))) (int(pgmajor), int(pgminor), int(pgpatch))))
except (Warning, w): except Warning, w:
if self.pg_config == self.DEFAULT_PG_CONFIG: if self.pg_config == self.DEFAULT_PG_CONFIG:
sys.stderr.write("Warning: %s" % str(w)) sys.stderr.write("Warning: %s" % str(w))
else: else:

View File

@ -47,7 +47,22 @@ class ExtrasDictCursorTests(unittest.TestCase):
for row in curs: for row in curs:
return row return row
self._testWithPlainCursor(getter) self._testWithPlainCursor(getter)
def testDictCursorWithPlainCursorRealFetchOne(self):
self._testWithPlainCursorReal(lambda curs: curs.fetchone())
def testDictCursorWithPlainCursorRealFetchMany(self):
self._testWithPlainCursorReal(lambda curs: curs.fetchmany(100)[0])
def testDictCursorWithPlainCursorRealFetchAll(self):
self._testWithPlainCursorReal(lambda curs: curs.fetchall()[0])
def testDictCursorWithPlainCursorRealIter(self):
def getter(curs):
for row in curs:
return row
self._testWithPlainCursorReal(getter)
def testDictCursorWithNamedCursorFetchOne(self): def testDictCursorWithNamedCursorFetchOne(self):
self._testWithNamedCursor(lambda curs: curs.fetchone()) self._testWithNamedCursor(lambda curs: curs.fetchone())
@ -77,6 +92,18 @@ class ExtrasDictCursorTests(unittest.TestCase):
self.failUnless(row['foo'] == 'bar') self.failUnless(row['foo'] == 'bar')
self.failUnless(row[0] == 'bar') self.failUnless(row[0] == 'bar')
def _testWithPlainCursorReal(self, getter):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
curs.execute("SELECT * FROM ExtrasDictCursorTests")
row = getter(curs)
self.failUnless(row['foo'] == 'bar')
def _testWithNamedCursorReal(self, getter):
curs = self.conn.cursor('aname', cursor_factory=psycopg2.extras.RealDictCursor)
curs.execute("SELECT * FROM ExtrasDictCursorTests")
row = getter(curs)
self.failUnless(row['foo'] == 'bar')
def test_suite(): def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__) return unittest.TestLoader().loadTestsFromName(__name__)