From 9f9da182f1d5c188c306161f8b96263ce333168b Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Fri, 28 Sep 2012 02:51:58 +0100 Subject: [PATCH] Using super() in the connection/cursor subclasses This opens to collaborative subclassing (e.g. you may want to have a logging namedtuple cursor...) --- lib/extras.py | 57 +++++++++++++++++++++++++-------------------------- 1 file changed, 28 insertions(+), 29 deletions(-) diff --git a/lib/extras.py b/lib/extras.py index eed8b326..5b45bdd1 100644 --- a/lib/extras.py +++ b/lib/extras.py @@ -54,46 +54,46 @@ class DictCursorBase(_cursor): else: raise NotImplementedError( "DictCursorBase can't be instantiated without a row factory.") - _cursor.__init__(self, *args, **kwargs) + super(DictCursorBase, self).__init__(*args, **kwargs) self._query_executed = 0 self._prefetch = 0 self.row_factory = row_factory def fetchone(self): if self._prefetch: - res = _cursor.fetchone(self) + res = super(DictCursorBase, self).fetchone() if self._query_executed: self._build_index() if not self._prefetch: - res = _cursor.fetchone(self) + res = super(DictCursorBase, self).fetchone() return res def fetchmany(self, size=None): if self._prefetch: - res = _cursor.fetchmany(self, size) + res = super(DictCursorBase, self).fetchmany(size) if self._query_executed: self._build_index() if not self._prefetch: - res = _cursor.fetchmany(self, size) + res = super(DictCursorBase, self).fetchmany(size) return res def fetchall(self): if self._prefetch: - res = _cursor.fetchall(self) + res = super(DictCursorBase, self).fetchall() if self._query_executed: self._build_index() if not self._prefetch: - res = _cursor.fetchall(self) + res = super(DictCursorBase, self).fetchall() return res def __iter__(self): if self._prefetch: - res = _cursor.__iter__(self) + res = super(DictCursorBase, self).__iter__() first = res.next() if self._query_executed: self._build_index() if not self._prefetch: - res = _cursor.__iter__(self) + res = super(DictCursorBase, self).__iter__() first = res.next() yield first @@ -105,25 +105,25 @@ class DictConnection(_connection): """A connection that uses `DictCursor` automatically.""" def cursor(self, *args, **kwargs): kwargs.setdefault('cursor_factory', DictCursor) - return _connection.cursor(self, *args, **kwargs) + return super(DictConnection, self).cursor(*args, **kwargs) class DictCursor(DictCursorBase): """A cursor that keeps a list of column name -> index mappings.""" def __init__(self, *args, **kwargs): kwargs['row_factory'] = DictRow - DictCursorBase.__init__(self, *args, **kwargs) + super(DictCursor, self).__init__(*args, **kwargs) self._prefetch = 1 def execute(self, query, vars=None): self.index = {} self._query_executed = 1 - return DictCursorBase.execute(self, query, vars) + return super(DictCursor, self).execute(query, vars) def callproc(self, procname, vars=None): self.index = {} self._query_executed = 1 - return DictCursorBase.callproc(self, procname, vars) + return super(DictCursor, self).callproc(procname, vars) def _build_index(self): if self._query_executed == 1 and self.description: @@ -196,7 +196,7 @@ class RealDictConnection(_connection): """A connection that uses `RealDictCursor` automatically.""" def cursor(self, *args, **kwargs): kwargs.setdefault('cursor_factory', RealDictCursor) - return _connection.cursor(self, *args, **kwargs) + return super(RealDictConnection, self).cursor(*args, **kwargs) class RealDictCursor(DictCursorBase): """A cursor that uses a real dict as the base type for rows. @@ -206,21 +206,20 @@ class RealDictCursor(DictCursorBase): to access database rows both as a dictionary and a list, then use the generic `DictCursor` instead of `!RealDictCursor`. """ - def __init__(self, *args, **kwargs): kwargs['row_factory'] = RealDictRow - DictCursorBase.__init__(self, *args, **kwargs) + super(RealDictCursor, self).__init__(*args, **kwargs) self._prefetch = 0 def execute(self, query, vars=None): self.column_mapping = [] self._query_executed = 1 - return DictCursorBase.execute(self, query, vars) + return super(RealDictCursor, self).execute(query, vars) def callproc(self, procname, vars=None): self.column_mapping = [] self._query_executed = 1 - return DictCursorBase.callproc(self, procname, vars) + return super(RealDictCursor, self).callproc(procname, vars) def _build_index(self): if self._query_executed == 1 and self.description: @@ -251,7 +250,7 @@ class NamedTupleConnection(_connection): """A connection that uses `NamedTupleCursor` automatically.""" def cursor(self, *args, **kwargs): kwargs.setdefault('cursor_factory', NamedTupleCursor) - return _connection.cursor(self, *args, **kwargs) + return super(NamedTupleConnection, self).cursor(*args, **kwargs) class NamedTupleCursor(_cursor): """A cursor that generates results as `~collections.namedtuple`. @@ -273,18 +272,18 @@ class NamedTupleCursor(_cursor): def execute(self, query, vars=None): self.Record = None - return _cursor.execute(self, query, vars) + return super(NamedTupleCursor, self).execute(query, vars) def executemany(self, query, vars): self.Record = None - return _cursor.executemany(self, query, vars) + return super(NamedTupleCursor, self).executemany(query, vars) def callproc(self, procname, vars=None): self.Record = None - return _cursor.callproc(self, procname, vars) + return super(NamedTupleCursor, self).callproc(procname, vars) def fetchone(self): - t = _cursor.fetchone(self) + t = super(NamedTupleCursor, self).fetchone() if t is not None: nt = self.Record if nt is None: @@ -292,21 +291,21 @@ class NamedTupleCursor(_cursor): return nt._make(t) def fetchmany(self, size=None): - ts = _cursor.fetchmany(self, size) + ts = super(NamedTupleCursor, self).fetchmany(size) nt = self.Record if nt is None: nt = self.Record = self._make_nt() return map(nt._make, ts) def fetchall(self): - ts = _cursor.fetchall(self) + ts = super(NamedTupleCursor, self).fetchall() nt = self.Record if nt is None: nt = self.Record = self._make_nt() return map(nt._make, ts) def __iter__(self): - it = _cursor.__iter__(self) + it = super(NamedTupleCursor, self).__iter__() t = it.next() nt = self.Record @@ -371,20 +370,20 @@ class LoggingConnection(_connection): def cursor(self, *args, **kwargs): self._check() kwargs.setdefault('cursor_factory', LoggingCursor) - return _connection.cursor(self, *args, **kwargs) + return super(LoggingConnection, self).cursor(*args, **kwargs) class LoggingCursor(_cursor): """A cursor that logs queries using its connection logging facilities.""" def execute(self, query, vars=None): try: - return _cursor.execute(self, query, vars) + return super(LoggingCursor, self).execute(query, vars) finally: self.connection.log(self.query, self) def callproc(self, procname, vars=None): try: - return _cursor.callproc(self, procname, vars) + return super(LoggingCursor, self).callproc(procname, vars) finally: self.connection.log(self.query, self)