diff --git a/lib/extras.py b/lib/extras.py index b1d4d9e4..1560edd3 100644 --- a/lib/extras.py +++ b/lib/extras.py @@ -104,8 +104,7 @@ class DictCursorBase(_cursor): class DictConnection(_connection): """A connection that uses `DictCursor` automatically.""" def cursor(self, *args, **kwargs): - if 'cursor_factory' not in kwargs: - kwargs['cursor_factory'] = DictCursor + kwargs.setdefault('cursor_factory', DictCursor) return _connection.cursor(self, *args, **kwargs) class DictCursor(DictCursorBase): @@ -196,8 +195,7 @@ class DictRow(list): class RealDictConnection(_connection): """A connection that uses `RealDictCursor` automatically.""" def cursor(self, *args, **kwargs): - if 'cursor_factory' not in kwargs: - kwargs['cursor_factory'] = RealDictCursor + kwargs.setdefault('cursor_factory', RealDictCursor) return _connection.cursor(self, *args, **kwargs) class RealDictCursor(DictCursorBase): @@ -252,8 +250,7 @@ class RealDictRow(dict): class NamedTupleConnection(_connection): """A connection that uses `NamedTupleCursor` automatically.""" def cursor(self, *args, **kwargs): - if 'cursor_factory' not in kwargs: - kwargs['cursor_factory'] = NamedTupleCursor + kwargs.setdefault('cursor_factory', NamedTupleCursor) return _connection.cursor(self, *args, **kwargs) class NamedTupleCursor(_cursor): @@ -348,7 +345,7 @@ class LoggingConnection(_connection): self.log = self._logtologger else: self.log = self._logtofile - + def filter(self, msg, curs): """Filter the query before logging it. @@ -357,26 +354,24 @@ class LoggingConnection(_connection): just does nothing. """ return msg - + def _logtofile(self, msg, curs): msg = self.filter(msg, curs) if msg: self._logobj.write(msg + os.linesep) - + def _logtologger(self, msg, curs): msg = self.filter(msg, curs) if msg: self._logobj.debug(msg) - + def _check(self): if not hasattr(self, '_logobj'): raise self.ProgrammingError( "LoggingConnection object has not been initialize()d") - - def cursor(self, name=None): + + def cursor(self, *args, **kwargs): self._check() - if name is None: - return _connection.cursor(self, cursor_factory=LoggingCursor) - else: - return _connection.cursor(self, name, cursor_factory=LoggingCursor) + kwargs.setdefault('cursor_factory', LoggingCursor) + return _connection.cursor(self, *args, **kwargs) class LoggingCursor(_cursor): """A cursor that logs queries using its connection logging facilities.""" @@ -389,19 +384,19 @@ class LoggingCursor(_cursor): def callproc(self, procname, vars=None): try: - return _cursor.callproc(self, procname, vars) + return _cursor.callproc(self, procname, vars) finally: self.connection.log(self.query, self) class MinTimeLoggingConnection(LoggingConnection): """A connection that logs queries based on execution time. - + This is just an example of how to sub-class `LoggingConnection` to provide some extra filtering for the logged queries. Both the `inizialize()` and `filter()` methods are overwritten to make sure that only queries executing for more than ``mintime`` ms are logged. - + Note that this connection uses the specialized cursor `MinTimeLoggingCursor`. """ @@ -414,20 +409,17 @@ class MinTimeLoggingConnection(LoggingConnection): if t > self._mintime: return msg + os.linesep + " (execution time: %d ms)" % t - def cursor(self, name=None): - self._check() - if name is None: - return _connection.cursor(self, cursor_factory=MinTimeLoggingCursor) - else: - return _connection.cursor(self, name, cursor_factory=MinTimeLoggingCursor) - + def cursor(self, *args, **kwargs): + kwargs.setdefault('cursor_factory', MinTimeLoggingCursor) + return LoggingConnection.cursor(self, *args, **kwargs) + class MinTimeLoggingCursor(LoggingCursor): """The cursor sub-class companion to `MinTimeLoggingConnection`.""" def execute(self, query, vars=None): self.timestamp = time.time() return LoggingCursor.execute(self, query, vars) - + def callproc(self, procname, vars=None): self.timestamp = time.time() return LoggingCursor.execute(self, procname, vars)