Using super() in the connection/cursor subclasses

This opens to collaborative subclassing (e.g. you may want to have a
logging namedtuple cursor...)
This commit is contained in:
Daniele Varrazzo 2012-09-28 02:51:58 +01:00
parent 74e6efd717
commit 387b7b6b36

View File

@ -54,46 +54,46 @@ class DictCursorBase(_cursor):
else: else:
raise NotImplementedError( raise NotImplementedError(
"DictCursorBase can't be instantiated without a row factory.") "DictCursorBase can't be instantiated without a row factory.")
_cursor.__init__(self, *args, **kwargs) super(DictCursorBase, self).__init__(*args, **kwargs)
self._query_executed = 0 self._query_executed = 0
self._prefetch = 0 self._prefetch = 0
self.row_factory = row_factory self.row_factory = row_factory
def fetchone(self): def fetchone(self):
if self._prefetch: if self._prefetch:
res = _cursor.fetchone(self) res = super(DictCursorBase, self).fetchone()
if self._query_executed: if self._query_executed:
self._build_index() self._build_index()
if not self._prefetch: if not self._prefetch:
res = _cursor.fetchone(self) res = super(DictCursorBase, self).fetchone()
return res return res
def fetchmany(self, size=None): def fetchmany(self, size=None):
if self._prefetch: if self._prefetch:
res = _cursor.fetchmany(self, size) res = super(DictCursorBase, self).fetchmany(size)
if self._query_executed: if self._query_executed:
self._build_index() self._build_index()
if not self._prefetch: if not self._prefetch:
res = _cursor.fetchmany(self, size) res = super(DictCursorBase, self).fetchmany(size)
return res return res
def fetchall(self): def fetchall(self):
if self._prefetch: if self._prefetch:
res = _cursor.fetchall(self) res = super(DictCursorBase, self).fetchall()
if self._query_executed: if self._query_executed:
self._build_index() self._build_index()
if not self._prefetch: if not self._prefetch:
res = _cursor.fetchall(self) res = super(DictCursorBase, self).fetchall()
return res return res
def __iter__(self): def __iter__(self):
if self._prefetch: if self._prefetch:
res = _cursor.__iter__(self) res = super(DictCursorBase, self).__iter__()
first = res.next() first = res.next()
if self._query_executed: if self._query_executed:
self._build_index() self._build_index()
if not self._prefetch: if not self._prefetch:
res = _cursor.__iter__(self) res = super(DictCursorBase, self).__iter__()
first = res.next() first = res.next()
yield first yield first
@ -105,25 +105,25 @@ class DictConnection(_connection):
"""A connection that uses `DictCursor` automatically.""" """A connection that uses `DictCursor` automatically."""
def cursor(self, *args, **kwargs): def cursor(self, *args, **kwargs):
kwargs.setdefault('cursor_factory', DictCursor) kwargs.setdefault('cursor_factory', DictCursor)
return _connection.cursor(self, *args, **kwargs) return super(DictConnection, self).cursor(*args, **kwargs)
class DictCursor(DictCursorBase): class DictCursor(DictCursorBase):
"""A cursor that keeps a list of column name -> index mappings.""" """A cursor that keeps a list of column name -> index mappings."""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
kwargs['row_factory'] = DictRow kwargs['row_factory'] = DictRow
DictCursorBase.__init__(self, *args, **kwargs) super(DictCursor, self).__init__(*args, **kwargs)
self._prefetch = 1 self._prefetch = 1
def execute(self, query, vars=None): def execute(self, query, vars=None):
self.index = {} self.index = {}
self._query_executed = 1 self._query_executed = 1
return DictCursorBase.execute(self, query, vars) return super(DictCursor, self).execute(query, vars)
def callproc(self, procname, vars=None): def callproc(self, procname, vars=None):
self.index = {} self.index = {}
self._query_executed = 1 self._query_executed = 1
return DictCursorBase.callproc(self, procname, vars) return super(DictCursor, self).callproc(procname, vars)
def _build_index(self): def _build_index(self):
if self._query_executed == 1 and self.description: if self._query_executed == 1 and self.description:
@ -196,7 +196,7 @@ class RealDictConnection(_connection):
"""A connection that uses `RealDictCursor` automatically.""" """A connection that uses `RealDictCursor` automatically."""
def cursor(self, *args, **kwargs): def cursor(self, *args, **kwargs):
kwargs.setdefault('cursor_factory', RealDictCursor) kwargs.setdefault('cursor_factory', RealDictCursor)
return _connection.cursor(self, *args, **kwargs) return super(RealDictConnection, self).cursor(*args, **kwargs)
class RealDictCursor(DictCursorBase): class RealDictCursor(DictCursorBase):
"""A cursor that uses a real dict as the base type for rows. """A cursor that uses a real dict as the base type for rows.
@ -206,21 +206,20 @@ class RealDictCursor(DictCursorBase):
to access database rows both as a dictionary and a list, then use to access database rows both as a dictionary and a list, then use
the generic `DictCursor` instead of `!RealDictCursor`. the generic `DictCursor` instead of `!RealDictCursor`.
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
kwargs['row_factory'] = RealDictRow kwargs['row_factory'] = RealDictRow
DictCursorBase.__init__(self, *args, **kwargs) super(RealDictCursor, self).__init__(*args, **kwargs)
self._prefetch = 0 self._prefetch = 0
def execute(self, query, vars=None): def execute(self, query, vars=None):
self.column_mapping = [] self.column_mapping = []
self._query_executed = 1 self._query_executed = 1
return DictCursorBase.execute(self, query, vars) return super(RealDictCursor, self).execute(query, vars)
def callproc(self, procname, vars=None): def callproc(self, procname, vars=None):
self.column_mapping = [] self.column_mapping = []
self._query_executed = 1 self._query_executed = 1
return DictCursorBase.callproc(self, procname, vars) return super(RealDictCursor, self).callproc(procname, vars)
def _build_index(self): def _build_index(self):
if self._query_executed == 1 and self.description: if self._query_executed == 1 and self.description:
@ -251,7 +250,7 @@ class NamedTupleConnection(_connection):
"""A connection that uses `NamedTupleCursor` automatically.""" """A connection that uses `NamedTupleCursor` automatically."""
def cursor(self, *args, **kwargs): def cursor(self, *args, **kwargs):
kwargs.setdefault('cursor_factory', NamedTupleCursor) kwargs.setdefault('cursor_factory', NamedTupleCursor)
return _connection.cursor(self, *args, **kwargs) return super(NamedTupleConnection, self).cursor(*args, **kwargs)
class NamedTupleCursor(_cursor): class NamedTupleCursor(_cursor):
"""A cursor that generates results as `~collections.namedtuple`. """A cursor that generates results as `~collections.namedtuple`.
@ -273,18 +272,18 @@ class NamedTupleCursor(_cursor):
def execute(self, query, vars=None): def execute(self, query, vars=None):
self.Record = None self.Record = None
return _cursor.execute(self, query, vars) return super(NamedTupleCursor, self).execute(query, vars)
def executemany(self, query, vars): def executemany(self, query, vars):
self.Record = None self.Record = None
return _cursor.executemany(self, query, vars) return super(NamedTupleCursor, self).executemany(query, vars)
def callproc(self, procname, vars=None): def callproc(self, procname, vars=None):
self.Record = None self.Record = None
return _cursor.callproc(self, procname, vars) return super(NamedTupleCursor, self).callproc(procname, vars)
def fetchone(self): def fetchone(self):
t = _cursor.fetchone(self) t = super(NamedTupleCursor, self).fetchone()
if t is not None: if t is not None:
nt = self.Record nt = self.Record
if nt is None: if nt is None:
@ -292,21 +291,21 @@ class NamedTupleCursor(_cursor):
return nt(*t) return nt(*t)
def fetchmany(self, size=None): def fetchmany(self, size=None):
ts = _cursor.fetchmany(self, size) ts = super(NamedTupleCursor, self).fetchmany(size)
nt = self.Record nt = self.Record
if nt is None: if nt is None:
nt = self.Record = self._make_nt() nt = self.Record = self._make_nt()
return [nt(*t) for t in ts] return [nt(*t) for t in ts]
def fetchall(self): def fetchall(self):
ts = _cursor.fetchall(self) ts = super(NamedTupleCursor, self).fetchall()
nt = self.Record nt = self.Record
if nt is None: if nt is None:
nt = self.Record = self._make_nt() nt = self.Record = self._make_nt()
return [nt(*t) for t in ts] return [nt(*t) for t in ts]
def __iter__(self): def __iter__(self):
it = _cursor.__iter__(self) it = super(NamedTupleCursor, self).__iter__()
t = it.next() t = it.next()
nt = self.Record nt = self.Record
@ -371,20 +370,20 @@ class LoggingConnection(_connection):
def cursor(self, *args, **kwargs): def cursor(self, *args, **kwargs):
self._check() self._check()
kwargs.setdefault('cursor_factory', LoggingCursor) kwargs.setdefault('cursor_factory', LoggingCursor)
return _connection.cursor(self, *args, **kwargs) return super(LoggingConnection, self).cursor(*args, **kwargs)
class LoggingCursor(_cursor): class LoggingCursor(_cursor):
"""A cursor that logs queries using its connection logging facilities.""" """A cursor that logs queries using its connection logging facilities."""
def execute(self, query, vars=None): def execute(self, query, vars=None):
try: try:
return _cursor.execute(self, query, vars) return super(LoggingCursor, self).execute(query, vars)
finally: finally:
self.connection.log(self.query, self) self.connection.log(self.query, self)
def callproc(self, procname, vars=None): def callproc(self, procname, vars=None):
try: try:
return _cursor.callproc(self, procname, vars) return super(LoggingCursor, self).callproc(procname, vars)
finally: finally:
self.connection.log(self.query, self) self.connection.log(self.query, self)