From 6e2ddd26c545f5b01c612f245a291b92d4c7ac9d Mon Sep 17 00:00:00 2001
From: Changaco
Date: Mon, 23 Sep 2019 10:55:21 +0200
Subject: [PATCH] integrate connection caching into the base pool class

This commit also fixes two preexisting issues:

- the status of a connection wasn't checked in `getconn`, so that method
  could return a broken connection
- the `putconn` method didn't open new connections if the total had
  dropped below `minconn`
---
 doc/src/pool.rst   |  23 +++--
 lib/pool.py        | 217 ++++++++++++---------------------------
 tests/__init__.py  |   2 +
 tests/test_pool.py | 172 +++++++++++------------------------
 4 files changed, 129 insertions(+), 285 deletions(-)

diff --git a/doc/src/pool.rst b/doc/src/pool.rst
index b2622155..e1cf06aa 100644
--- a/doc/src/pool.rst
+++ b/doc/src/pool.rst
@@ -12,7 +12,7 @@ Creating new PostgreSQL connections can be an expensive operation. This
 module offers a few pure Python classes implementing simple connection pooling
 directly in the client application.
 
-.. class:: AbstractConnectionPool(minconn, maxconn, \*args, \*\*kwargs)
+.. class:: AbstractConnectionPool(minconn, maxconn, \*args, idle_timeout=0, \*\*kwargs)
 
     Base class implementing generic key-based pooling code.
 
@@ -20,6 +20,15 @@ directly in the client application.
     a maximum of about *maxconn* connections. *\*args* and *\*\*kwargs* are
     passed to the `~psycopg2.connect()` function.
 
+    Connections are kept in the pool for at most *idle_timeout* seconds. There
+    are two special values: zero means that connections are always immediately
+    closed upon their return to the pool; `None` means that connections are kept
+    indefinitely (leaving the server in charge of closing idle connections). The
+    current default value is zero because it replicates the behavior of previous
+    versions; however, the default value may be changed in a future release.
+
+    .. versionadded:: 2.9 the *idle_timeout* argument.
+
     The following methods are expected to be implemented by subclasses:
 
     .. method:: getconn(key=None)
@@ -44,6 +53,14 @@ directly in the client application.
         Note that all the connections are closed, including ones
         eventually in use by the application.
 
+    .. method:: prune
+
+        Drop all expired connections from the pool.
+
+        You can call this method periodically to clean up the pool.
+
+        .. versionadded:: 2.9
+
 
 The following classes are `AbstractConnectionPool` subclasses ready to
 be used.
@@ -58,7 +75,3 @@ be used.
 .. autoclass:: ThreadedConnectionPool
 
     .. note:: This pool class can be safely used in multi-threaded applications.
-
-.. autoclass:: CachingConnectionPool
-
-    .. note:: Expired connections are cleaned up on any call to putconn.
diff --git a/lib/pool.py b/lib/pool.py
index cce4a790..1f1b44f1 100644
--- a/lib/pool.py
+++ b/lib/pool.py
@@ -24,6 +24,12 @@ This module implements thread-safe (and not) connection pools.
 # FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
 # License for more details.
 
+try:
+    from time import monotonic
+except ImportError:
+    # Python < 3.3 has no time.monotonic; fall back to wall-clock time
+    from time import time as monotonic
+
 import psycopg2
 from psycopg2 import extensions as _ext
 
@@ -44,6 +50,7 @@ class AbstractConnectionPool(object):
         """
         self.minconn = int(minconn)
         self.maxconn = int(maxconn)
+        self.idle_timeout = kwargs.pop('idle_timeout', 0)
         self.closed = False
 
         self._args = args
@@ -53,6 +60,7 @@ class AbstractConnectionPool(object):
         self._used = {}
         self._rused = {} # id(conn) -> key map
         self._keys = 0
+        self._return_times = {}
 
         for i in range(self.minconn):
             self._connect()
@@ -64,6 +72,7 @@ class AbstractConnectionPool(object):
             self._used[key] = conn
             self._rused[id(conn)] = key
         else:
+            self._return_times[id(conn)] = monotonic()
             self._pool.append(conn)
         return conn
 
@@ -82,14 +91,26 @@
         if key in self._used:
             return self._used[key]
 
-        if self._pool:
-            self._used[key] = conn = self._pool.pop()
-            self._rused[id(conn)] = key
-            return conn
-        else:
-            if len(self._used) == self.maxconn:
-                raise PoolError("connection pool exhausted")
-            return self._connect(key)
+        while True:
+            try:
+                conn = self._pool.pop()
+            except IndexError:
+                if len(self._used) >= self.maxconn:
+                    raise PoolError("connection pool exhausted")
+                conn = self._connect(key)
+            else:
+                idle_since = self._return_times.pop(id(conn), 0)
+                close = (
+                    conn.info.transaction_status != _ext.TRANSACTION_STATUS_IDLE or
+                    (self.idle_timeout and idle_since < (monotonic() - self.idle_timeout))
+                )
+                if close:
+                    conn.close()
+                    continue
+            break
+        self._used[key] = conn
+        self._rused[id(conn)] = key
+        return conn
 
     def _putconn(self, conn, key=None, close=False):
         """Put away a connection."""
@@ -101,7 +122,9 @@
         if key is None:
             raise PoolError("trying to put unkeyed connection")
 
-        if len(self._pool) < self.minconn and not close:
+        if close or (self.idle_timeout == 0 and len(self._pool) >= self.minconn):
+            conn.close()
+        else:
             # Return the connection into a consistent state before putting
             # it back into the pool
             if not conn.closed:
@@ -112,13 +135,13 @@
                 elif status != _ext.TRANSACTION_STATUS_IDLE:
                     # connection in error or in transaction
                     conn.rollback()
+                    self._return_times[id(conn)] = monotonic()
                     self._pool.append(conn)
                 else:
                     # regular idle connection
+                    self._return_times[id(conn)] = monotonic()
                     self._pool.append(conn)
             # If the connection is closed, we just discard it.
-        else:
-            conn.close()
 
         # here we check for the presence of key because it can happen that a
         # thread tries to put back a connection after a call to close
@@ -126,6 +149,10 @@
             del self._used[key]
             del self._rused[id(conn)]
 
+        # Open new connections if we've dropped below minconn.
+        while (len(self._pool) + len(self._used)) < self.minconn:
+            self._connect()
+
     def _closeall(self):
         """Close all connections.
 
@@ -142,6 +169,20 @@ class AbstractConnectionPool(object):
                 pass
         self.closed = True
 
+    def _prune(self):
+        """Drop all expired connections from the pool."""
+        if self.idle_timeout is None:
+            return
+        threshold = monotonic() - self.idle_timeout
+        for conn in list(self._pool):
+            if self._return_times.get(id(conn), 0) < threshold:
+                try:
+                    self._pool.remove(conn)
+                except ValueError:
+                    continue
+                self._return_times.pop(id(conn), None)
+                conn.close()
+
 
 class SimpleConnectionPool(AbstractConnectionPool):
     """A connection pool that can't be shared across different threads."""
@@ -149,6 +190,7 @@
     getconn = AbstractConnectionPool._getconn
     putconn = AbstractConnectionPool._putconn
     closeall = AbstractConnectionPool._closeall
+    prune = AbstractConnectionPool._prune
 
 
 class ThreadedConnectionPool(AbstractConnectionPool):
@@ -185,157 +227,10 @@
         finally:
             self._lock.release()
 
-
-class CachingConnectionPool(AbstractConnectionPool):
-    """A connection pool that works with the threading module and caches connections"""
-
-    #---------------------------------------------------------------------------
-    def __init__(self, minconn, maxconn, lifetime = 3600, *args, **kwargs):
-        """Initialize the threading lock."""
-        import threading
-        from datetime import datetime, timedelta
-
-        AbstractConnectionPool.__init__(
-            self, minconn, maxconn, *args, **kwargs)
-        self._lock = threading.Lock()
-        self._lifetime = lifetime
-
-        #Initalize function to get expiration time.
-        self._expiration_time = lambda: datetime.now() + timedelta(seconds = lifetime)
-
-        # A dictionary to hold connection ID's and when they should be removed from the pool
-        # Keys are id(connection) and vlaues are expiration time
-        # Storing the expiration time on the connection object itself might be
-        # preferable, if possible.
-        self._expirations = {}
-
-    # Override the _putconn function to put the connection back into the pool even if we are over minconn, and to run the _prune command.
-    #---------------------------------------------------------------------------
-    def _putconn(self, conn, key=None, close=False):
-        """Put away a connection."""
-        if self.closed:
-            raise PoolError("connection pool is closed")
-        if key is None:
-            key = self._rused.get(id(conn))
-
-        if not key:
-            raise PoolError("trying to put unkeyed connection")
-
-        if len(self._pool) < self.maxconn and not close:
-            # Return the connection into a consistent state before putting
-            # it back into the pool
-            if not conn.closed:
-                status = conn.get_transaction_status()
-                if status == _ext.TRANSACTION_STATUS_UNKNOWN:
-                    # server connection lost
-                    conn.close()
-                    try:
-                        del self._expirations[id(conn)]
-                    except KeyError:
-                        pass
-                elif status != _ext.TRANSACTION_STATUS_IDLE:
-                    # connection in error or in transaction
-                    conn.rollback()
-                    self._pool.append(conn)
-                else:
-                    # regular idle connection
-                    self._pool.append(conn)
-            # If the connection is closed, we just discard it.
-            else:
-                try:
-                    del self._expirations[id(conn)]
-                except KeyError:
-                    pass
-        else:
-            conn.close()
-            #remove this connection from the expiration list
-            try:
-                del self._expirations[id(conn)]
-            except KeyError:
-                pass #not in the expiration list for some reason, can't remove it.
-
-        # here we check for the presence of key because it can happen that a
-        # thread tries to put back a connection after a call to close
-        if not self.closed or key in self._used:
-            del self._used[key]
-            del self._rused[id(conn)]
-
-        # remove any expired connections from the pool
-        self._prune()
-
-    #---------------------------------------------------------------------------
-    def getconn(self, key=None):
-        """Get a free connection and assign it to 'key' if not None."""
+    def prune(self):
+        """Drop all expired connections from the pool."""
         self._lock.acquire()
         try:
-            conn = self._getconn(key)
-            #Add expiration time
-            self._expirations[id(conn)] = self._expiration_time()
-            return conn
+            self._prune()
         finally:
             self._lock.release()
-
-    #---------------------------------------------------------------------------
-    def putconn(self, conn=None, key=None, close=False):
-        """Put away an unused connection."""
-        self._lock.acquire()
-        try:
-            self._putconn(conn, key, close)
-        finally:
-            self._lock.release()
-
-    #---------------------------------------------------------------------------
-    def closeall(self):
-        """Close all connections (even the one currently in use.)"""
-        self._lock.acquire()
-        try:
-            self._closeall()
-        finally:
-            self._lock.release()
-
-    #---------------------------------------------------------------------------
-    def _prune(self):
-        """Remove any expired connections from the connection pool."""
-        from datetime import datetime
-        junk_expirations = []
-        for obj_id, exp_time in self._expirations.items():
-            if exp_time > datetime.now(): # Not expired, move on.
-                continue;
-
-            del_idx = None
-            #find index of connection in _pool. May not be there if connection is in use
-            for index, conn in enumerate(self._pool):
-                if id(conn) == obj_id:
-                    conn.close()
-                    junk_expirations.append(obj_id)
-                    del_idx = index
-                    break
-            else:
-                # See if this connection is used. If not, we need to remove
-                # the reference to it.
-                for conn in self._used.values():
-                    if id(conn) == obj_id:
-                        break #found it, so just move on. Don't expire the
-                              # connection till we are done with it.
-                else:
-                    # This connection doesn't exist any more, so get rid
-                    # of the reference to the expiration.
-                    # Can't delete here because we'd be changing the item
-                    # we are itterating over.
-                    junk_expirations.append(obj_id)
-
-            # Delete connection from pool if expired
-            if del_idx is not None:
-                del self._pool[del_idx]
-
-        # Remove any junk expirations
-        for item in junk_expirations:
-            # Should be safe enough, since it existed in the loop above
-            del self._expirations[item]
-
-        # Make sure we still have at least minconn connections
-        # Connections may be available or used
-        total_conns = len(self._pool) + len(self._used)
-        if total_conns < self.minconn:
-            for i in range(self.minconn - total_conns):
-                self._connect()
diff --git a/tests/__init__.py b/tests/__init__.py
index cad2a9a4..555b771f 100755
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -48,6 +48,7 @@ from . import test_ipaddress
 from . import test_lobject
 from . import test_module
 from . import test_notify
+from . import test_pool
 from . import test_psycopg2_dbapi20
 from . import test_quote
 from . import test_replication
@@ -93,6 +94,7 @@ def test_suite():
     suite.addTest(test_lobject.test_suite())
     suite.addTest(test_module.test_suite())
     suite.addTest(test_notify.test_suite())
+    suite.addTest(test_pool.test_suite())
     suite.addTest(test_psycopg2_dbapi20.test_suite())
     suite.addTest(test_quote.test_suite())
     suite.addTest(test_replication.test_suite())
diff --git a/tests/test_pool.py b/tests/test_pool.py
index 67de6d72..f22ddbf0 100644
--- a/tests/test_pool.py
+++ b/tests/test_pool.py
@@ -20,33 +20,22 @@
 # FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
 # License for more details.
 
-from datetime import datetime, timedelta
+import unittest
+
 import psycopg2
 import psycopg2.pool as psycopg_pool
 import psycopg2.extensions as _ext
-from testutils import (unittest, ConnectingTestCase, skip_before_postgres,
-    skip_if_no_namedtuple, skip_if_no_getrefcount, slow, skip_if_no_superuser,
-    skip_if_windows)
+from .testutils import ConnectingTestCase
 
-from testconfig import dsn, dbname
+from .testconfig import dsn, dbname
 
 
 class PoolTests(ConnectingTestCase):
     def test_caching_pool_get_conn(self):
         """Test the call to getconn. Should just return an open connection."""
         lifetime = 30
-        pool = psycopg_pool.CachingConnectionPool(0, 1, lifetime, dsn)
+        pool = psycopg_pool.SimpleConnectionPool(0, 1, dsn, idle_timeout=lifetime)
         conn = pool.getconn()
-        expected_expires = datetime.now() + timedelta(seconds = lifetime)
-
-        #Verify we have one entry in the expiration table
-        self.assertEqual(len(pool._expirations), 1)
-        actual_expires = pool._expirations[id(conn)]
-
-        # there may be some slight variation between when we created the connection
-        # and our "expected" expiration.
-        # Should be negligable, however
-        self.assertAlmostEqual(expected_expires, actual_expires, delta = timedelta(seconds = 1))
 
         #make sure we got an open connection
         self.assertFalse(conn.closed)
@@ -54,26 +43,24 @@ class PoolTests(ConnectingTestCase):
         #Try again. We should get an error, since we only allowed one connection
         self.assertRaises(psycopg2.pool.PoolError, pool.getconn)
 
-        # Put the connection back, then get it again. The expiration time should increment
-        # If this test is consistantly failing, we may need to add a "sleep" to force
-        # some real time between connections, but as long as the precision of
-        # datetime is high enough, this should work. All we care is that new_expires
-        # is greater than the original expiration time
+        # Put the connection back; the return time should be set.
         pool.putconn(conn)
-        conn = pool.getconn()
-        new_expires = pool._expirations[id(conn)]
-        self.assertGreater(new_expires, actual_expires)
+        self.assertIn(id(conn), pool._return_times)
+
+        # Get the connection back.
+        new_conn = pool.getconn()
+        self.assertIs(new_conn, conn)
 
     def test_caching_pool_prune(self):
         """Test the prune function to make sure it closes conenctions and
         removes them from the pool"""
-        pool = psycopg_pool.CachingConnectionPool(0, 3, 30, dsn)
+        pool = psycopg_pool.SimpleConnectionPool(0, 3, dsn, idle_timeout=30)
         # Get a connection that we use, so it can't be pruned.
         sticky_conn = pool.getconn()
         self.assertFalse(sticky_conn in pool._pool)
         self.assertTrue(sticky_conn in pool._used.values())
         self.assertFalse(sticky_conn.closed)
-        self.assertTrue(id(sticky_conn) in pool._expirations)
+        self.assertFalse(id(sticky_conn) in pool._return_times)
 
         # create a second connection that is put back into the pool, available to be pruned.
         conn = pool.getconn()
@@ -89,21 +76,19 @@
         self.assertTrue(conn in pool._pool)
         self.assertFalse(conn in pool._used.values())
         self.assertFalse(conn.closed)
-        self.assertTrue(id(conn) in pool._expirations)
+        self.assertTrue(id(conn) in pool._return_times)
 
         self.assertTrue(new_conn in pool._pool)
         self.assertFalse(new_conn in pool._used.values())
         self.assertFalse(new_conn.closed)
-        self.assertTrue(id(new_conn) in pool._expirations)
+        self.assertTrue(id(new_conn) in pool._return_times)
 
         self.assertNotEqual(conn, sticky_conn)
         self.assertNotEqual(new_conn, conn)
 
         #Make the connections expire a minute ago (but not new_con)
-        old_expire = datetime.now() - timedelta(minutes = 1)
-
-        pool._expirations[id(conn)] = old_expire
-        pool._expirations[id(sticky_conn)] = old_expire
+        pool._return_times[id(conn)] -= 60
+        pool._return_times[id(sticky_conn)] = pool._return_times[id(conn)]
 
         #prune connections
         pool._prune()
@@ -112,34 +97,17 @@
         # but the used connection isn't
        self.assertFalse(conn in pool._pool)
         self.assertTrue(conn.closed)
-        self.assertFalse(id(conn) in pool._expirations)
+        self.assertFalse(id(conn) in pool._return_times)
 
         self.assertFalse(sticky_conn.closed)
-        self.assertTrue(id(sticky_conn) in pool._expirations)
+        self.assertTrue(id(sticky_conn) in pool._return_times)
 
         # The un-expired connection should still exist and be open
         self.assertFalse(new_conn.closed)
-        self.assertTrue(id(new_conn) in pool._expirations)
-
-    def test_caching_pool_prune_missing_connection(self):
-        pool = psycopg_pool.CachingConnectionPool(0, 1, 30, dsn)
-        conn = pool.getconn(key = "test")
-
-        self.assertTrue("test" in pool._used)
-
-        #connection got lost somehow.
-        del pool._used["test"]
-
-        #expire this connection
-        old_expire = datetime.now() - timedelta(minutes = 1)
-
-        pool._expirations[id(conn)] = old_expire
-
-        # and prune
-        pool._prune()
+        self.assertTrue(id(new_conn) in pool._return_times)
 
     def test_caching_pool_prune_below_min(self):
-        pool = psycopg_pool.CachingConnectionPool(1, 1, 30, dsn)
+        pool = psycopg_pool.SimpleConnectionPool(1, 1, dsn, idle_timeout=30)
 
         conn = pool.getconn()
         self.assertFalse(conn in pool._pool)
@@ -156,7 +124,7 @@
 
     def test_caching_pool_putconn_normal(self):
-        pool = psycopg_pool.CachingConnectionPool(0, 1, 30, dsn)
+        pool = psycopg_pool.SimpleConnectionPool(0, 1, dsn, idle_timeout=30)
         conn = pool.getconn()
 
         self.assertFalse(conn in pool._pool)
@@ -164,50 +132,39 @@
         self.assertTrue(conn in pool._pool)
 
     def test_caching_pool_putconn_closecon(self):
-        pool = psycopg_pool.CachingConnectionPool(0, 1, 30, dsn)
+        pool = psycopg_pool.SimpleConnectionPool(0, 1, dsn, idle_timeout=30)
         conn = pool.getconn()
 
         self.assertFalse(conn in pool._pool)
         pool.putconn(conn, close = True)
         self.assertFalse(conn in pool._pool)
-        self.assertFalse(id(conn) in pool._expirations)
+        self.assertFalse(id(conn) in pool._return_times)
 
-    def test_caching_pool_putconn_closecon_noexp(self):
-        pool = psycopg_pool.CachingConnectionPool(0, 1, 30, dsn)
-        conn = pool.getconn()
-        self.assertFalse(conn in pool._pool)
-
-        # Something went haywire with the prune, and the expiration information
-        # for this connection got lost.
-        del pool._expirations[id(conn)]
-        self.assertFalse(id(conn) in pool._expirations)
-
-        # Should still work without error
-        pool.putconn(conn, close = True)
-        self.assertFalse(conn in pool._pool)
-
-    def test_caching_pool_putconn_expired(self):
-        pool = psycopg_pool.CachingConnectionPool(0, 1, 30, dsn)
+    def test_caching_pool_getconn_expired(self):
+        pool = psycopg_pool.SimpleConnectionPool(0, 1, dsn, idle_timeout=30)
         conn = pool.getconn()
 
         #expire the connection
-        pool._expirations[id(conn)] = datetime.now() - timedelta(minutes = 1)
         pool.putconn(conn)
+        pool._return_times[id(conn)] -= 60
 
         #connection should be discarded
+        new_conn = pool.getconn()
+        self.assertIsNot(new_conn, conn)
         self.assertFalse(conn in pool._pool)
-        self.assertFalse(id(conn) in pool._expirations)
+        self.assertFalse(id(conn) in pool._return_times)
         self.assertTrue(conn.closed)
 
     def test_caching_pool_putconn_unkeyed(self):
-        pool = psycopg_pool.CachingConnectionPool(0, 1, 30, dsn)
+        pool = psycopg_pool.SimpleConnectionPool(0, 1, dsn, idle_timeout=30)
 
-        #Test put with empty key
+        #Test put with missing key
         conn = pool.getconn()
-        self.assertRaises(psycopg_pool.PoolError, pool.putconn, conn, '')
+        del pool._rused[id(conn)]
+        self.assertRaises(psycopg_pool.PoolError, pool.putconn, conn)
 
     def test_caching_pool_putconn_errorState(self):
-        pool = psycopg_pool.CachingConnectionPool(0, 1, 30, dsn)
+        pool = psycopg_pool.SimpleConnectionPool(0, 1, dsn, idle_timeout=30)
         conn = pool.getconn()
 
         #Get connection into transaction state
@@ -228,54 +185,26 @@
         self.assertTrue(conn in pool._pool)
 
     def test_caching_pool_putconn_closed(self):
-        pool = psycopg_pool.CachingConnectionPool(0, 1, 30, dsn)
+        pool = psycopg_pool.SimpleConnectionPool(0, 1, dsn, idle_timeout=30)
         conn = pool.getconn()
 
-        #Open connection with expiration
+        # The connection should be open and shouldn't have a return time.
         self.assertFalse(conn.closed)
-        self.assertTrue(id(conn) in pool._expirations)
+        self.assertFalse(id(conn) in pool._return_times)
 
         conn.close()
 
-        # Now should be closed, but still have expiration entry
+        # Now should be closed
         self.assertTrue(conn.closed)
-        self.assertTrue(id(conn) in pool._expirations)
 
         pool.putconn(conn)
 
-        # we should not have an expiration any more
-        self.assertFalse(id(conn) in pool._expirations)
-
-        # and the connection should have been discarded
-        self.assertFalse(conn in pool._pool)
-
-    def test_caching_pool_putconn_closed_noexp(self):
-        pool = psycopg_pool.CachingConnectionPool(0, 1, 30, dsn)
-        conn = pool.getconn()
-
-        #Open connection with expiration
-        self.assertFalse(conn.closed)
-        self.assertTrue(id(conn) in pool._expirations)
-
-        conn.close()
-
-        # Now should be closed, but still have expiration entry
-        self.assertTrue(conn.closed)
-        self.assertTrue(id(conn) in pool._expirations)
-
-        # Delete the expiration entry to simulate confusion
-        del pool._expirations[id(conn)]
-
-        # we should not have an expiration any more
-        self.assertFalse(id(conn) in pool._expirations)
-
-        pool.putconn(conn)
-
-        # and the connection should have been discarded, without error
+        # the connection should have been discarded
         self.assertFalse(conn in pool._pool)
+        self.assertFalse(id(conn) in pool._return_times)
 
     def test_caching_pool_caching(self):
-        pool = psycopg_pool.CachingConnectionPool(0, 10, 30, dsn)
+        pool = psycopg_pool.SimpleConnectionPool(0, 10, dsn, idle_timeout=30)
 
         # Get a connection to use to check the number of connections
         check_conn = pool.getconn()
@@ -293,8 +222,6 @@
         conn2 = pool.getconn()
         conn3 = pool.getconn()
 
-        self.assertEqual(len(pool._expirations), 3)
-
         self.assertNotEqual(conn2, conn3)
 
         # Verify that we have the expected number of connections to the DB server now
@@ -324,13 +251,12 @@
         self.assertEqual(total_cons_after_get, total_cons)
 
     def test_caching_pool_closeall(self):
-        pool = psycopg_pool.CachingConnectionPool(0, 10, 30, dsn)
+        pool = psycopg_pool.SimpleConnectionPool(0, 10, dsn, idle_timeout=30)
         conn1 = pool.getconn()
         conn2 = pool.getconn()
         pool.putconn(conn2)
 
         self.assertEqual(len(pool._pool), 1) #1 in use, 1 put back
-        self.assertEqual(len(pool._expirations), 2) # We have two expirations for two connections
         self.assertEqual(len(pool._used), 1) # and we have one used connection
 
         # Both connections should be open at this point
@@ -348,8 +274,16 @@
         # self.assertEqual(len(pool._used), 0)
         # self.assertEqual(len(pool._pool), 0)
 
-        # To maintain consistancy with existing code, closeall doesn't mess with the _expirations dict either
-        # self.assertEqual(len(pool._expirations), 0)
+        # To maintain consistency with existing code, closeall doesn't mess with the _return_times dict either
+        # self.assertEqual(len(pool._return_times), 0)
 
         #We should get an error if we try to put conn1 back now
         self.assertRaises(psycopg2.pool.PoolError, pool.putconn, conn1)
+
+
+def test_suite():
+    return unittest.TestLoader().loadTestsFromName(__name__)
+
+
+if __name__ == "__main__":
+    unittest.main()
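Example usage of the API added by this patch. This is only a minimal sketch:
it assumes a psycopg2 build containing this change and a reachable server,
and the DSN, timeout value and query below are placeholders. The pool class,
the idle_timeout argument and the prune() method come from the diff above;
everything else is illustrative.

    from psycopg2.pool import ThreadedConnectionPool

    # Keep at least 2 and at most 10 connections; connections returned to the
    # pool are dropped once they have been idle for more than 5 minutes.
    pool = ThreadedConnectionPool(2, 10, "dbname=test user=postgres",
                                  idle_timeout=300)

    conn = pool.getconn()
    try:
        cur = conn.cursor()
        cur.execute("SELECT 1")
        cur.fetchone()
        cur.close()
    finally:
        # putconn() records the return time that the idle_timeout check uses.
        pool.putconn(conn)

    # Optionally call prune() periodically (e.g. from a maintenance thread)
    # to close connections that have sat idle in the pool for too long.
    pool.prune()

    pool.closeall()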