Using super() in the connection/cursor subclasses

This opens to collaborative subclassing (e.g. you may want to have a
logging namedtuple cursor...)
This commit is contained in:
Daniele Varrazzo 2012-09-28 02:51:58 +01:00
parent 74e6efd717
commit 387b7b6b36

View File

@ -54,46 +54,46 @@ class DictCursorBase(_cursor):
else:
raise NotImplementedError(
"DictCursorBase can't be instantiated without a row factory.")
_cursor.__init__(self, *args, **kwargs)
super(DictCursorBase, self).__init__(*args, **kwargs)
self._query_executed = 0
self._prefetch = 0
self.row_factory = row_factory
def fetchone(self):
if self._prefetch:
res = _cursor.fetchone(self)
res = super(DictCursorBase, self).fetchone()
if self._query_executed:
self._build_index()
if not self._prefetch:
res = _cursor.fetchone(self)
res = super(DictCursorBase, self).fetchone()
return res
def fetchmany(self, size=None):
if self._prefetch:
res = _cursor.fetchmany(self, size)
res = super(DictCursorBase, self).fetchmany(size)
if self._query_executed:
self._build_index()
if not self._prefetch:
res = _cursor.fetchmany(self, size)
res = super(DictCursorBase, self).fetchmany(size)
return res
def fetchall(self):
if self._prefetch:
res = _cursor.fetchall(self)
res = super(DictCursorBase, self).fetchall()
if self._query_executed:
self._build_index()
if not self._prefetch:
res = _cursor.fetchall(self)
res = super(DictCursorBase, self).fetchall()
return res
def __iter__(self):
if self._prefetch:
res = _cursor.__iter__(self)
res = super(DictCursorBase, self).__iter__()
first = res.next()
if self._query_executed:
self._build_index()
if not self._prefetch:
res = _cursor.__iter__(self)
res = super(DictCursorBase, self).__iter__()
first = res.next()
yield first
@ -105,25 +105,25 @@ class DictConnection(_connection):
"""A connection that uses `DictCursor` automatically."""
def cursor(self, *args, **kwargs):
kwargs.setdefault('cursor_factory', DictCursor)
return _connection.cursor(self, *args, **kwargs)
return super(DictConnection, self).cursor(*args, **kwargs)
class DictCursor(DictCursorBase):
"""A cursor that keeps a list of column name -> index mappings."""
def __init__(self, *args, **kwargs):
kwargs['row_factory'] = DictRow
DictCursorBase.__init__(self, *args, **kwargs)
super(DictCursor, self).__init__(*args, **kwargs)
self._prefetch = 1
def execute(self, query, vars=None):
self.index = {}
self._query_executed = 1
return DictCursorBase.execute(self, query, vars)
return super(DictCursor, self).execute(query, vars)
def callproc(self, procname, vars=None):
self.index = {}
self._query_executed = 1
return DictCursorBase.callproc(self, procname, vars)
return super(DictCursor, self).callproc(procname, vars)
def _build_index(self):
if self._query_executed == 1 and self.description:
@ -196,7 +196,7 @@ class RealDictConnection(_connection):
"""A connection that uses `RealDictCursor` automatically."""
def cursor(self, *args, **kwargs):
kwargs.setdefault('cursor_factory', RealDictCursor)
return _connection.cursor(self, *args, **kwargs)
return super(RealDictConnection, self).cursor(*args, **kwargs)
class RealDictCursor(DictCursorBase):
"""A cursor that uses a real dict as the base type for rows.
@ -206,21 +206,20 @@ class RealDictCursor(DictCursorBase):
to access database rows both as a dictionary and a list, then use
the generic `DictCursor` instead of `!RealDictCursor`.
"""
def __init__(self, *args, **kwargs):
kwargs['row_factory'] = RealDictRow
DictCursorBase.__init__(self, *args, **kwargs)
super(RealDictCursor, self).__init__(*args, **kwargs)
self._prefetch = 0
def execute(self, query, vars=None):
self.column_mapping = []
self._query_executed = 1
return DictCursorBase.execute(self, query, vars)
return super(RealDictCursor, self).execute(query, vars)
def callproc(self, procname, vars=None):
self.column_mapping = []
self._query_executed = 1
return DictCursorBase.callproc(self, procname, vars)
return super(RealDictCursor, self).callproc(procname, vars)
def _build_index(self):
if self._query_executed == 1 and self.description:
@ -251,7 +250,7 @@ class NamedTupleConnection(_connection):
"""A connection that uses `NamedTupleCursor` automatically."""
def cursor(self, *args, **kwargs):
kwargs.setdefault('cursor_factory', NamedTupleCursor)
return _connection.cursor(self, *args, **kwargs)
return super(NamedTupleConnection, self).cursor(*args, **kwargs)
class NamedTupleCursor(_cursor):
"""A cursor that generates results as `~collections.namedtuple`.
@ -273,18 +272,18 @@ class NamedTupleCursor(_cursor):
def execute(self, query, vars=None):
self.Record = None
return _cursor.execute(self, query, vars)
return super(NamedTupleCursor, self).execute(query, vars)
def executemany(self, query, vars):
self.Record = None
return _cursor.executemany(self, query, vars)
return super(NamedTupleCursor, self).executemany(query, vars)
def callproc(self, procname, vars=None):
self.Record = None
return _cursor.callproc(self, procname, vars)
return super(NamedTupleCursor, self).callproc(procname, vars)
def fetchone(self):
t = _cursor.fetchone(self)
t = super(NamedTupleCursor, self).fetchone()
if t is not None:
nt = self.Record
if nt is None:
@ -292,21 +291,21 @@ class NamedTupleCursor(_cursor):
return nt(*t)
def fetchmany(self, size=None):
ts = _cursor.fetchmany(self, size)
ts = super(NamedTupleCursor, self).fetchmany(size)
nt = self.Record
if nt is None:
nt = self.Record = self._make_nt()
return [nt(*t) for t in ts]
def fetchall(self):
ts = _cursor.fetchall(self)
ts = super(NamedTupleCursor, self).fetchall()
nt = self.Record
if nt is None:
nt = self.Record = self._make_nt()
return [nt(*t) for t in ts]
def __iter__(self):
it = _cursor.__iter__(self)
it = super(NamedTupleCursor, self).__iter__()
t = it.next()
nt = self.Record
@ -371,20 +370,20 @@ class LoggingConnection(_connection):
def cursor(self, *args, **kwargs):
self._check()
kwargs.setdefault('cursor_factory', LoggingCursor)
return _connection.cursor(self, *args, **kwargs)
return super(LoggingConnection, self).cursor(*args, **kwargs)
class LoggingCursor(_cursor):
"""A cursor that logs queries using its connection logging facilities."""
def execute(self, query, vars=None):
try:
return _cursor.execute(self, query, vars)
return super(LoggingCursor, self).execute(query, vars)
finally:
self.connection.log(self.query, self)
def callproc(self, procname, vars=None):
try:
return _cursor.callproc(self, procname, vars)
return super(LoggingCursor, self).callproc(procname, vars)
finally:
self.connection.log(self.query, self)