diff --git a/lib/pool.py b/lib/pool.py index 8d7c4afb..a5d63b66 100644 --- a/lib/pool.py +++ b/lib/pool.py @@ -32,6 +32,20 @@ class PoolError(psycopg2.Error): pass +class ConnectionContext(object): + + def __init__(self, pool): + self.pool = pool + self._conn = None + + def __enter__(self): + self._conn = self.pool.getconn() + return self._conn + + def __exit__(self, exc_type, exc_value, traceback): + self.pool.putconn(self._conn) + + class AbstractConnectionPool(object): """Generic key-based pooling code.""" @@ -57,6 +71,9 @@ class AbstractConnectionPool(object): for i in range(self.minconn): self._connect() + def __call__(self): + return ConnectionContext(self) + def _connect(self, key=None): """Create a new connection and assign it to 'key' if not None.""" conn = psycopg2.connect(*self._args, **self._kwargs) @@ -76,7 +93,7 @@ class AbstractConnectionPool(object): """Get a free connection and assign it to 'key' if not None.""" if self.closed: raise PoolError("connection pool is closed") if key is None: key = self._getkey() - + if key in self._used: return self._used[key] @@ -88,7 +105,7 @@ class AbstractConnectionPool(object): if len(self._used) == self.maxconn: raise PoolError("connection pool exhausted") return self._connect(key) - + def _putconn(self, conn, key=None, close=False): """Put away a connection.""" if self.closed: raise PoolError("connection pool is closed")