diff --git a/lib/pool.py b/lib/pool.py index f4c44524..e8fe53bb 100644 --- a/lib/pool.py +++ b/lib/pool.py @@ -189,36 +189,28 @@ class ThreadedConnectionPool(AbstractConnectionPool): class CachingConnectionPool(AbstractConnectionPool): """A connection pool that works with the threading module and caches connections""" + #--------------------------------------------------------------------------- def __init__(self, minconn, maxconn, lifetime = 3600, *args, **kwargs): """Initialize the threading lock.""" import threading + from datetime import datetime, timedelta + AbstractConnectionPool.__init__( self, minconn, maxconn, *args, **kwargs) self._lock = threading.Lock() self._lifetime = lifetime + #Initalize function to get expiration time. + self._expiration_time = lambda: datetime.now() + timedelta(seconds = lifetime) + # A dictionary to hold connection ID's and when they should be removed from the pool # Keys are id(connection) and vlaues are expiration time - # Storing the expiration time on the connection itself might be preferable, if possible. - from collections import OrderedDict - self._expirations = OrderedDict() - - def _connect(self, key=None): - """Create a new connection, assign it to 'key' if not None, - And assign an expiration time""" - from datetime import datetime, timedelta - conn = psycopg2.connect(*self._args, **self._kwargs) - if key is not None: - self._used[key] = conn - self._rused[id(conn)] = key - else: - self._pool.append(conn) - - #Add expiration time - self._expirations[id(conn)] = datetime.now() + timedelta(seconds = self._lifetime) - return conn + # Storing the expiration time on the connection object itself might be + # preferable, if possible. + self._expirations = {} # Override the _putconn function to put the connection back into the pool even if we are over minconn, and to run the _prune command. + #--------------------------------------------------------------------------- def _putconn(self, conn, key=None, close=False): """Put away a connection.""" if self.closed: @@ -271,15 +263,19 @@ class CachingConnectionPool(AbstractConnectionPool): # remove any expired connections from the pool self._prune() - + #--------------------------------------------------------------------------- def getconn(self, key=None): """Get a free connection and assign it to 'key' if not None.""" self._lock.acquire() try: - return self._getconn(key) + conn = self._getconn(key) + #Add expiration time + self._expirations[id(conn)] = self._expiration_time() + return conn finally: self._lock.release() + #--------------------------------------------------------------------------- def putconn(self, conn=None, key=None, close=False): """Put away an unused connection.""" self._lock.acquire() @@ -288,6 +284,7 @@ class CachingConnectionPool(AbstractConnectionPool): finally: self._lock.release() + #--------------------------------------------------------------------------- def closeall(self): """Close all connections (even the one currently in use.)""" self._lock.acquire() @@ -296,14 +293,14 @@ class CachingConnectionPool(AbstractConnectionPool): finally: self._lock.release() + #--------------------------------------------------------------------------- def _prune(self): """Remove any expired connections from the connection pool.""" - from datetime import datetime, timedelta + from datetime import datetime junk_expirations = [] for obj_id, exp_time in self._expirations.items(): - # _expirations is an ordered dict, so results should be in chronological order - if exp_time > datetime.now(): - break; + if exp_time > datetime.now(): # Not expired, move on. + continue; del_idx = None #find index of connection in _pool. May not be there if connection is in use @@ -331,7 +328,6 @@ class CachingConnectionPool(AbstractConnectionPool): if del_idx is not None: del self._pool[del_idx] - # Remove any junk expirations for item in junk_expirations: try: diff --git a/tests/test_pool.py b/tests/test_pool.py index d624490c..b63c8813 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -30,20 +30,15 @@ from testutils import (unittest, ConnectingTestCase, skip_before_postgres, from testconfig import dsn, dbname class PoolTests(ConnectingTestCase): - #---------------------------------------------------------------------- - def test_caching_pool_create_connection(self): - """Test that the _connect function creates and returns a connection""" + def test_caching_pool_get_conn(self): + """Test the call to getconn. Should just return an open connection.""" lifetime = 30 pool = psycopg_pool.CachingConnectionPool(0, 1, lifetime, dsn) + conn = pool.getconn() expected_expires = datetime.now() + timedelta(seconds = lifetime) - conn = pool._connect() #Verify we have one entry in the expiration table self.assertEqual(len(pool._expirations), 1) - - # and that the connection is actually opened - self.assertFalse(conn.closed) - actual_expires = pool._expirations[id(conn)] # there may be some slight variation between when we created the connection @@ -51,20 +46,25 @@ class PoolTests(ConnectingTestCase): # Should be negligable, however self.assertAlmostEqual(expected_expires, actual_expires, delta = timedelta(seconds = 1)) - def test_caching_pool_get_conn(self): - """Test the call to getconn. Should just return an open connection.""" - pool = psycopg_pool.CachingConnectionPool(0, 1, 30, dsn) - conn = pool.getconn() - #make sure we got an open connection self.assertFalse(conn.closed) #Try again. We should get an error, since we only allowed one connection self.assertRaises(psycopg2.pool.PoolError, pool.getconn) + # Put the connection back, then get it again. The expiration time should increment + # If this test is consistantly failing, we may need to add a "sleep" to force + # some real time between connections, but as long as the precision of + # datetime is high enough, this should work. All we care is that new_expires + # is greater than the original expiration time + pool.putconn(conn) + conn = pool.getconn() + new_expires = pool._expirations[id(conn)] + self.assertGreater(new_expires, actual_expires) + def test_caching_pool_prune(self): """Test the prune function to make sure it closes conenctions and removes them from the pool""" - pool = psycopg_pool.CachingConnectionPool(0, 1, 30, dsn) + pool = psycopg_pool.CachingConnectionPool(0, 3, 30, dsn) # Get a connection that we use, so it can't be pruned. sticky_conn = pool.getconn() @@ -73,16 +73,31 @@ class PoolTests(ConnectingTestCase): self.assertFalse(sticky_conn.closed) self.assertTrue(id(sticky_conn) in pool._expirations) - # create a second connection that is left in the pool, available to be pruned. - conn = pool._connect() + # create a second connection that is put back into the pool, available to be pruned. + conn = pool.getconn() + + # create a third connection that is put back into the pool, but won't be expired + new_conn = pool.getconn() + + # Put the connections back in the pool. + pool.putconn(conn) + pool.putconn(new_conn) + + # Verify that everything is in the expected state self.assertTrue(conn in pool._pool) self.assertFalse(conn in pool._used.values()) self.assertFalse(conn.closed) self.assertTrue(id(conn) in pool._expirations) - self.assertNotEqual(conn, sticky_conn) + self.assertTrue(new_conn in pool._pool) + self.assertFalse(new_conn in pool._used.values()) + self.assertFalse(new_conn.closed) + self.assertTrue(id(new_conn) in pool._expirations) - #Make the connections expire a minute ago + self.assertNotEqual(conn, sticky_conn) + self.assertNotEqual(new_conn, conn) + + #Make the connections expire a minute ago (but not new_con) old_expire = datetime.now() - timedelta(minutes = 1) pool._expirations[id(conn)] = old_expire @@ -91,7 +106,8 @@ class PoolTests(ConnectingTestCase): #prune connections pool._prune() - #make sure the unused connection is gone and closed, but the used connection isn't + # make sure the unused expired connection is gone and closed, + # but the used connection isn't self.assertFalse(conn in pool._pool) self.assertTrue(conn.closed) self.assertFalse(id(conn) in pool._expirations) @@ -99,6 +115,10 @@ class PoolTests(ConnectingTestCase): self.assertFalse(sticky_conn.closed) self.assertTrue(id(sticky_conn) in pool._expirations) + # The un-expired connection should still exist and be open + self.assertFalse(new_conn.closed) + self.assertTrue(id(new_conn) in pool._expirations) + def test_caching_pool_putconn(self): pool = psycopg_pool.CachingConnectionPool(0, 1, 30, dsn) conn = pool.getconn() @@ -165,4 +185,32 @@ class PoolTests(ConnectingTestCase): check_cursor.execute(SQL, (dbname, )) total_cons_after_get = check_cursor.fetchone()[0] - self.assertEqual(total_cons_after_get, total_cons) \ No newline at end of file + self.assertEqual(total_cons_after_get, total_cons) + + def test_caching_pool_closeall(self): + pool = psycopg_pool.CachingConnectionPool(0, 10, 30, dsn) + conn1 = pool.getconn() + conn2 = pool.getconn() + pool.putconn(conn2) + + self.assertEqual(len(pool._pool), 1) #1 in use, 1 put back + self.assertEqual(len(pool._expirations), 2) # We have two expirations for two connections + self.assertEqual(len(pool._used), 1) # and we have one used connection + + # Both connections should be open at this point + self.assertFalse(conn1.closed) + self.assertFalse(conn2.closed) + + pool.closeall() + + # Make sure both connections are now closed + self.assertTrue(conn1.closed) + self.assertTrue(conn2.closed) + + # Apparently the closeall command doesn't actually empty _used or _pool, + # it just blindly closes the connections. Fixit? + # self.assertEqual(len(pool._used), 0) + # self.assertEqual(len(pool._pool), 0) + + # To maintain consistancy with existing code, closeall doesn't mess with the _expirations dict either + # self.assertEqual(len(pool._expirations), 0)