integrate connection caching into the base pool class

This commit also fixes two preexisting issues:

- the status of a connection wasn't checked in `getconn`, so that method could return a broken connection
- the `putconn` method didn't open new connections if the total had dropped below `minconn`
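A rough sketch of the second fix from the caller's side (assuming this patch is applied and a reachable database; the DSN is a placeholder, and the assertion peeks at the private `_pool` attribute purely for illustration):

    from psycopg2.pool import SimpleConnectionPool

    pool = SimpleConnectionPool(1, 2, "dbname=test")   # placeholder DSN
    conn = pool.getconn()
    conn.close()                  # simulate a connection that died while in use
    pool.putconn(conn)            # the dead connection is discarded...
    assert len(pool._pool) == 1   # ...and a new one is opened to stay at minconn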
Changaco 2019-09-23 10:55:21 +02:00
parent 456f6b660b
commit 6e2ddd26c5
4 changed files with 129 additions and 285 deletions


@ -12,7 +12,7 @@ Creating new PostgreSQL connections can be an expensive operation. This
module offers a few pure Python classes implementing simple connection pooling
directly in the client application.
.. class:: AbstractConnectionPool(minconn, maxconn, \*args, \*\*kwargs)
.. class:: AbstractConnectionPool(minconn, maxconn, \*args, idle_timeout=0, \*\*kwargs)
Base class implementing generic key-based pooling code.
@ -20,6 +20,15 @@ directly in the client application.
a maximum of about *maxconn* connections. *\*args* and *\*\*kwargs* are
passed to the `~psycopg2.connect()` function.
Connections are kept in the pool for at most *idle_timeout* seconds. There
are two special values: zero means that connections are always immediately
closed upon their return to the pool; `None` means that connections are kept
indefinitely (leaving the server in charge of closing idle connections). The
current default value is zero because it replicates the behavior of previous
versions; however, the default value may be changed in a future release.
.. versionadded:: 2.9 the *idle_timeout* argument.
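A minimal usage sketch (the connection parameters are placeholders, and one of the concrete pool classes below stands in for the abstract base)::

    from psycopg2.pool import SimpleConnectionPool

    pool = SimpleConnectionPool(
        1, 5, "dbname=test user=postgres",   # placeholder connection parameters
        idle_timeout=600)                    # drop connections idle for over 10 minutes

    conn = pool.getconn()
    try:
        with conn.cursor() as cur:
            cur.execute("SELECT 1")
    finally:
        pool.putconn(conn)   # idle_timeout starts counting from this return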
The following methods are expected to be implemented by subclasses:
.. method:: getconn(key=None)
@ -44,6 +53,14 @@ directly in the client application.
Note that all the connections are closed, including ones
eventually in use by the application.
.. method:: prune()
Drop all expired connections from the pool.
You can call this method periodically to clean up the pool.
.. versionadded:: 2.9
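One way to call it periodically is a self re-arming daemon timer; this is only a sketch, best paired with `ThreadedConnectionPool`, whose `prune` takes the pool's lock::

    import threading

    def prune_periodically(pool, interval=60.0):
        # drop expired idle connections, then schedule the next run
        pool.prune()
        timer = threading.Timer(interval, prune_periodically, args=(pool, interval))
        timer.daemon = True
        timer.start()

    # call prune_periodically(pool) once after creating the pool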
The following classes are `AbstractConnectionPool` subclasses ready to
be used.
@ -58,7 +75,3 @@ be used.
.. autoclass:: ThreadedConnectionPool
.. note:: This pool class can be safely used in multi-threaded applications.
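For instance, a single pool instance can be shared by worker threads (a usage sketch assuming this patch; the DSN and query are placeholders)::

    import threading
    from psycopg2.pool import ThreadedConnectionPool

    pool = ThreadedConnectionPool(1, 10, "dbname=test", idle_timeout=300)

    def worker():
        conn = pool.getconn()
        try:
            with conn.cursor() as cur:
                cur.execute("SELECT pg_backend_pid()")
                cur.fetchone()
        finally:
            pool.putconn(conn)

    threads = [threading.Thread(target=worker) for _ in range(4)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()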
.. autoclass:: CachingConnectionPool
.. note:: Expired connections are cleaned up on any call to putconn.


@ -24,6 +24,12 @@ This module implements thread-safe (and not) connection pools.
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
try:
from time import process_time
except ImportError:
# For Python < 3.3
from time import clock as process_time
import psycopg2
from psycopg2 import extensions as _ext
@ -44,6 +50,7 @@ class AbstractConnectionPool(object):
"""
self.minconn = int(minconn)
self.maxconn = int(maxconn)
self.idle_timeout = kwargs.pop('idle_timeout', 0)
self.closed = False
self._args = args
@ -53,6 +60,7 @@ class AbstractConnectionPool(object):
self._used = {}
self._rused = {} # id(conn) -> key map
self._keys = 0
self._return_times = {}
for i in range(self.minconn):
self._connect()
@ -64,6 +72,7 @@ class AbstractConnectionPool(object):
self._used[key] = conn
self._rused[id(conn)] = key
else:
self._return_times[id(conn)] = process_time()
self._pool.append(conn)
return conn
@ -82,14 +91,26 @@ class AbstractConnectionPool(object):
if key in self._used:
return self._used[key]
if self._pool:
self._used[key] = conn = self._pool.pop()
self._rused[id(conn)] = key
return conn
else:
if len(self._used) == self.maxconn:
raise PoolError("connection pool exhausted")
return self._connect(key)
while True:
try:
conn = self._pool.pop()
except IndexError:
if len(self._used) >= self.maxconn:
raise PoolError("connection pool exhausted")
conn = self._connect(key)
else:
idle_since = self._return_times.pop(id(conn), 0)
close = (
conn.info.transaction_status != _ext.TRANSACTION_STATUS_IDLE or
self.idle_timeout and idle_since < (process_time() - self.idle_timeout)
)
if close:
conn.close()
continue
break
self._used[key] = conn
self._rused[id(conn)] = key
return conn
def _putconn(self, conn, key=None, close=False):
"""Put away a connection."""
@ -101,7 +122,9 @@ class AbstractConnectionPool(object):
if key is None:
raise PoolError("trying to put unkeyed connection")
if len(self._pool) < self.minconn and not close:
if close or self.idle_timeout == 0 and len(self._pool) >= self.minconn:
conn.close()
else:
# Return the connection into a consistent state before putting
# it back into the pool
if not conn.closed:
@ -112,13 +135,13 @@ class AbstractConnectionPool(object):
elif status != _ext.TRANSACTION_STATUS_IDLE:
# connection in error or in transaction
conn.rollback()
self._return_times[id(conn)] = process_time()
self._pool.append(conn)
else:
# regular idle connection
self._return_times[id(conn)] = process_time()
self._pool.append(conn)
# If the connection is closed, we just discard it.
else:
conn.close()
# here we check for the presence of key because it can happen that a
# thread tries to put back a connection after a call to close
@ -126,6 +149,10 @@ class AbstractConnectionPool(object):
del self._used[key]
del self._rused[id(conn)]
# Open new connections if we've dropped below minconn.
while (len(self._pool) + len(self._used)) < self.minconn:
self._connect()
def _closeall(self):
"""Close all connections.
@ -142,6 +169,20 @@ class AbstractConnectionPool(object):
pass
self.closed = True
def _prune(self):
"""Drop all expired connections from the pool."""
if self.idle_timeout is None:
return
threshold = process_time() - self.idle_timeout
for conn in list(self._pool):
if self._return_times.get(id(conn), 0) < threshold:
try:
self._pool.remove(conn)
except ValueError:
continue
self._return_times.pop(id(conn), None)
conn.close()
class SimpleConnectionPool(AbstractConnectionPool):
"""A connection pool that can't be shared across different threads."""
@ -149,6 +190,7 @@ class SimpleConnectionPool(AbstractConnectionPool):
getconn = AbstractConnectionPool._getconn
putconn = AbstractConnectionPool._putconn
closeall = AbstractConnectionPool._closeall
prune = AbstractConnectionPool._prune
class ThreadedConnectionPool(AbstractConnectionPool):
@ -185,157 +227,10 @@ class ThreadedConnectionPool(AbstractConnectionPool):
finally:
self._lock.release()
class CachingConnectionPool(AbstractConnectionPool):
"""A connection pool that works with the threading module and caches connections"""
#---------------------------------------------------------------------------
def __init__(self, minconn, maxconn, lifetime = 3600, *args, **kwargs):
"""Initialize the threading lock."""
import threading
from datetime import datetime, timedelta
AbstractConnectionPool.__init__(
self, minconn, maxconn, *args, **kwargs)
self._lock = threading.Lock()
self._lifetime = lifetime
# Initialize function to get expiration time.
self._expiration_time = lambda: datetime.now() + timedelta(seconds = lifetime)
# A dictionary to hold connection ID's and when they should be removed from the pool
# Keys are id(connection) and values are expiration times
# Storing the expiration time on the connection object itself might be
# preferable, if possible.
self._expirations = {}
# Override the _putconn function to put the connection back into the pool even if we are over minconn, and to run the _prune command.
#---------------------------------------------------------------------------
def _putconn(self, conn, key=None, close=False):
"""Put away a connection."""
if self.closed:
raise PoolError("connection pool is closed")
if key is None:
key = self._rused.get(id(conn))
if not key:
raise PoolError("trying to put unkeyed connection")
if len(self._pool) < self.maxconn and not close:
# Return the connection into a consistent state before putting
# it back into the pool
if not conn.closed:
status = conn.get_transaction_status()
if status == _ext.TRANSACTION_STATUS_UNKNOWN:
# server connection lost
conn.close()
try:
del self._expirations[id(conn)]
except KeyError:
pass
elif status != _ext.TRANSACTION_STATUS_IDLE:
# connection in error or in transaction
conn.rollback()
self._pool.append(conn)
else:
# regular idle connection
self._pool.append(conn)
# If the connection is closed, we just discard it.
else:
try:
del self._expirations[id(conn)]
except KeyError:
pass
else:
conn.close()
#remove this connection from the expiration list
try:
del self._expirations[id(conn)]
except KeyError:
pass #not in the expiration list for some reason, can't remove it.
# here we check for the presence of key because it can happen that a
# thread tries to put back a connection after a call to close
if not self.closed or key in self._used:
del self._used[key]
del self._rused[id(conn)]
# remove any expired connections from the pool
self._prune()
#---------------------------------------------------------------------------
def getconn(self, key=None):
"""Get a free connection and assign it to 'key' if not None."""
def prune(self):
"""Drop all expired connections from the pool."""
self._lock.acquire()
try:
conn = self._getconn(key)
#Add expiration time
self._expirations[id(conn)] = self._expiration_time()
return conn
self._prune()
finally:
self._lock.release()
#---------------------------------------------------------------------------
def putconn(self, conn=None, key=None, close=False):
"""Put away an unused connection."""
self._lock.acquire()
try:
self._putconn(conn, key, close)
finally:
self._lock.release()
#---------------------------------------------------------------------------
def closeall(self):
"""Close all connections (even the one currently in use.)"""
self._lock.acquire()
try:
self._closeall()
finally:
self._lock.release()
#---------------------------------------------------------------------------
def _prune(self):
"""Remove any expired connections from the connection pool."""
from datetime import datetime
junk_expirations = []
for obj_id, exp_time in self._expirations.items():
if exp_time > datetime.now(): # Not expired, move on.
continue;
del_idx = None
#find index of connection in _pool. May not be there if connection is in use
for index, conn in enumerate(self._pool):
if id(conn) == obj_id:
conn.close()
junk_expirations.append(obj_id)
del_idx = index
break
else:
# See if this connection is used. If not, we need to remove
# the reference to it.
for conn in self._used.values():
if id(conn) == obj_id:
break #found it, so just move on. Don't expire the
# connection till we are done with it.
else:
# This connection doesn't exist any more, so get rid
# of the reference to the expiration.
# Can't delete here because we'd be changing the item
# we are iterating over.
junk_expirations.append(obj_id)
# Delete connection from pool if expired
if del_idx is not None:
del self._pool[del_idx]
# Remove any junk expirations
for item in junk_expirations:
# Should be safe enough, since it existed in the loop above
del self._expirations[item]
# Make sure we still have at least minconn connections
# Connections may be available or used
total_conns = len(self._pool) + len(self._used)
if total_conns < self.minconn:
for i in range(self.minconn - total_conns):
self._connect()
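
Callers of the removed `CachingConnectionPool` would migrate roughly as follows, mirroring the test changes below (pool sizes and DSN are placeholders):

    from psycopg2.pool import ThreadedConnectionPool

    # before this change (class removed above):
    #   pool = CachingConnectionPool(0, 10, 3600, "dbname=test")
    # after: any pool class accepts the lifetime as the idle_timeout keyword
    pool = ThreadedConnectionPool(0, 10, "dbname=test", idle_timeout=3600)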


@ -48,6 +48,7 @@ from . import test_ipaddress
from . import test_lobject
from . import test_module
from . import test_notify
from . import test_pool
from . import test_psycopg2_dbapi20
from . import test_quote
from . import test_replication
@ -93,6 +94,7 @@ def test_suite():
suite.addTest(test_lobject.test_suite())
suite.addTest(test_module.test_suite())
suite.addTest(test_notify.test_suite())
suite.addTest(test_pool.test_suite())
suite.addTest(test_psycopg2_dbapi20.test_suite())
suite.addTest(test_quote.test_suite())
suite.addTest(test_replication.test_suite())


@ -20,33 +20,22 @@
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
from datetime import datetime, timedelta
import unittest
import psycopg2
import psycopg2.pool as psycopg_pool
import psycopg2.extensions as _ext
from testutils import (unittest, ConnectingTestCase, skip_before_postgres,
skip_if_no_namedtuple, skip_if_no_getrefcount, slow, skip_if_no_superuser,
skip_if_windows)
from .testutils import ConnectingTestCase
from testconfig import dsn, dbname
from .testconfig import dsn, dbname
class PoolTests(ConnectingTestCase):
def test_caching_pool_get_conn(self):
"""Test the call to getconn. Should just return an open connection."""
lifetime = 30
pool = psycopg_pool.CachingConnectionPool(0, 1, lifetime, dsn)
pool = psycopg_pool.SimpleConnectionPool(0, 1, dsn, idle_timeout=lifetime)
conn = pool.getconn()
expected_expires = datetime.now() + timedelta(seconds = lifetime)
#Verify we have one entry in the expiration table
self.assertEqual(len(pool._expirations), 1)
actual_expires = pool._expirations[id(conn)]
# there may be some slight variation between when we created the connection
# and our "expected" expiration.
# Should be negligible, however
self.assertAlmostEqual(expected_expires, actual_expires, delta = timedelta(seconds = 1))
#make sure we got an open connection
self.assertFalse(conn.closed)
@ -54,26 +43,24 @@ class PoolTests(ConnectingTestCase):
#Try again. We should get an error, since we only allowed one connection
self.assertRaises(psycopg2.pool.PoolError, pool.getconn)
# Put the connection back, then get it again. The expiration time should increment
# If this test is consistently failing, we may need to add a "sleep" to force
# some real time between connections, but as long as the precision of
# datetime is high enough, this should work. All we care about is that new_expires
# is greater than the original expiration time
# Put the connection back, the return time should be set.
pool.putconn(conn)
conn = pool.getconn()
new_expires = pool._expirations[id(conn)]
self.assertGreater(new_expires, actual_expires)
self.assertIn(id(conn), pool._return_times)
# Get the connection back.
new_conn = pool.getconn()
self.assertIs(new_conn, conn)
def test_caching_pool_prune(self):
"""Test the prune function to make sure it closes conenctions and removes them from the pool"""
pool = psycopg_pool.CachingConnectionPool(0, 3, 30, dsn)
pool = psycopg_pool.SimpleConnectionPool(0, 3, dsn, idle_timeout=30)
# Get a connection that we use, so it can't be pruned.
sticky_conn = pool.getconn()
self.assertFalse(sticky_conn in pool._pool)
self.assertTrue(sticky_conn in pool._used.values())
self.assertFalse(sticky_conn.closed)
self.assertTrue(id(sticky_conn) in pool._expirations)
self.assertFalse(id(sticky_conn) in pool._return_times)
# create a second connection that is put back into the pool, available to be pruned.
conn = pool.getconn()
@ -89,21 +76,19 @@ class PoolTests(ConnectingTestCase):
self.assertTrue(conn in pool._pool)
self.assertFalse(conn in pool._used.values())
self.assertFalse(conn.closed)
self.assertTrue(id(conn) in pool._expirations)
self.assertTrue(id(conn) in pool._return_times)
self.assertTrue(new_conn in pool._pool)
self.assertFalse(new_conn in pool._used.values())
self.assertFalse(new_conn.closed)
self.assertTrue(id(new_conn) in pool._expirations)
self.assertTrue(id(new_conn) in pool._return_times)
self.assertNotEqual(conn, sticky_conn)
self.assertNotEqual(new_conn, conn)
# Make the connections expire a minute ago (but not new_conn)
old_expire = datetime.now() - timedelta(minutes = 1)
pool._expirations[id(conn)] = old_expire
pool._expirations[id(sticky_conn)] = old_expire
pool._return_times[id(conn)] -= 60
pool._return_times[id(sticky_conn)] = pool._return_times[id(conn)]
#prune connections
pool._prune()
@ -112,34 +97,17 @@ class PoolTests(ConnectingTestCase):
# but the used connection isn't
self.assertFalse(conn in pool._pool)
self.assertTrue(conn.closed)
self.assertFalse(id(conn) in pool._expirations)
self.assertFalse(id(conn) in pool._return_times)
self.assertFalse(sticky_conn.closed)
self.assertTrue(id(sticky_conn) in pool._expirations)
self.assertTrue(id(sticky_conn) in pool._return_times)
# The un-expired connection should still exist and be open
self.assertFalse(new_conn.closed)
self.assertTrue(id(new_conn) in pool._expirations)
def test_caching_pool_prune_missing_connection(self):
pool = psycopg_pool.CachingConnectionPool(0, 1, 30, dsn)
conn = pool.getconn(key = "test")
self.assertTrue("test" in pool._used)
#connection got lost somehow.
del pool._used["test"]
#expire this connection
old_expire = datetime.now() - timedelta(minutes = 1)
pool._expirations[id(conn)] = old_expire
# and prune
pool._prune()
self.assertTrue(id(new_conn) in pool._return_times)
def test_caching_pool_prune_below_min(self):
pool = psycopg_pool.CachingConnectionPool(1, 1, 30, dsn)
pool = psycopg_pool.SimpleConnectionPool(1, 1, dsn, idle_timeout=30)
conn = pool.getconn()
self.assertFalse(conn in pool._pool)
@ -156,7 +124,7 @@ class PoolTests(ConnectingTestCase):
def test_caching_pool_putconn_normal(self):
pool = psycopg_pool.CachingConnectionPool(0, 1, 30, dsn)
pool = psycopg_pool.SimpleConnectionPool(0, 1, dsn, idle_timeout=30)
conn = pool.getconn()
self.assertFalse(conn in pool._pool)
@ -164,50 +132,39 @@ class PoolTests(ConnectingTestCase):
self.assertTrue(conn in pool._pool)
def test_caching_pool_putconn_closecon(self):
pool = psycopg_pool.CachingConnectionPool(0, 1, 30, dsn)
pool = psycopg_pool.SimpleConnectionPool(0, 1, dsn, idle_timeout=30)
conn = pool.getconn()
self.assertFalse(conn in pool._pool)
pool.putconn(conn, close = True)
self.assertFalse(conn in pool._pool)
self.assertFalse(id(conn) in pool._expirations)
self.assertFalse(id(conn) in pool._return_times)
def test_caching_pool_putconn_closecon_noexp(self):
pool = psycopg_pool.CachingConnectionPool(0, 1, 30, dsn)
conn = pool.getconn()
self.assertFalse(conn in pool._pool)
# Something went haywire with the prune, and the expiration information
# for this connection got lost.
del pool._expirations[id(conn)]
self.assertFalse(id(conn) in pool._expirations)
# Should still work without error
pool.putconn(conn, close = True)
self.assertFalse(conn in pool._pool)
def test_caching_pool_putconn_expired(self):
pool = psycopg_pool.CachingConnectionPool(0, 1, 30, dsn)
def test_caching_pool_getconn_expired(self):
pool = psycopg_pool.SimpleConnectionPool(0, 1, dsn, idle_timeout=30)
conn = pool.getconn()
#expire the connection
pool._expirations[id(conn)] = datetime.now() - timedelta(minutes = 1)
pool.putconn(conn)
pool._return_times[id(conn)] -= 60
#connection should be discarded
new_conn = pool.getconn()
self.assertIsNot(new_conn, conn)
self.assertFalse(conn in pool._pool)
self.assertFalse(id(conn) in pool._expirations)
self.assertFalse(id(conn) in pool._return_times)
self.assertTrue(conn.closed)
def test_caching_pool_putconn_unkeyed(self):
pool = psycopg_pool.CachingConnectionPool(0, 1, 30, dsn)
pool = psycopg_pool.SimpleConnectionPool(0, 1, dsn, idle_timeout=30)
#Test put with empty key
#Test put with missing key
conn = pool.getconn()
self.assertRaises(psycopg_pool.PoolError, pool.putconn, conn, '')
del pool._rused[id(conn)]
self.assertRaises(psycopg_pool.PoolError, pool.putconn, conn)
def test_caching_pool_putconn_errorState(self):
pool = psycopg_pool.CachingConnectionPool(0, 1, 30, dsn)
pool = psycopg_pool.SimpleConnectionPool(0, 1, dsn, idle_timeout=30)
conn = pool.getconn()
#Get connection into transaction state
@ -228,54 +185,26 @@ class PoolTests(ConnectingTestCase):
self.assertTrue(conn in pool._pool)
def test_caching_pool_putconn_closed(self):
pool = psycopg_pool.CachingConnectionPool(0, 1, 30, dsn)
pool = psycopg_pool.SimpleConnectionPool(0, 1, dsn, idle_timeout=30)
conn = pool.getconn()
#Open connection with expiration
# The connection should be open and shouldn't have a return time.
self.assertFalse(conn.closed)
self.assertTrue(id(conn) in pool._expirations)
self.assertFalse(id(conn) in pool._return_times)
conn.close()
# Now should be closed, but still have expiration entry
# Now should be closed
self.assertTrue(conn.closed)
self.assertTrue(id(conn) in pool._expirations)
pool.putconn(conn)
# we should not have an expiration any more
self.assertFalse(id(conn) in pool._expirations)
# and the connection should have been discarded
self.assertFalse(conn in pool._pool)
def test_caching_pool_putconn_closed_noexp(self):
pool = psycopg_pool.CachingConnectionPool(0, 1, 30, dsn)
conn = pool.getconn()
#Open connection with expiration
self.assertFalse(conn.closed)
self.assertTrue(id(conn) in pool._expirations)
conn.close()
# Now should be closed, but still have expiration entry
self.assertTrue(conn.closed)
self.assertTrue(id(conn) in pool._expirations)
# Delete the expiration entry to simulate confusion
del pool._expirations[id(conn)]
# we should not have an expiration any more
self.assertFalse(id(conn) in pool._expirations)
pool.putconn(conn)
# and the connection should have been discarded, without error
# the connection should have been discarded
self.assertFalse(conn in pool._pool)
self.assertFalse(id(conn) in pool._return_times)
def test_caching_pool_caching(self):
pool = psycopg_pool.CachingConnectionPool(0, 10, 30, dsn)
pool = psycopg_pool.SimpleConnectionPool(0, 10, dsn, idle_timeout=30)
# Get a connection to use to check the number of connections
check_conn = pool.getconn()
@ -293,8 +222,6 @@ class PoolTests(ConnectingTestCase):
conn2 = pool.getconn()
conn3 = pool.getconn()
self.assertEqual(len(pool._expirations), 3)
self.assertNotEqual(conn2, conn3)
# Verify that we have the expected number of connections to the DB server now
@ -324,13 +251,12 @@ class PoolTests(ConnectingTestCase):
self.assertEqual(total_cons_after_get, total_cons)
def test_caching_pool_closeall(self):
pool = psycopg_pool.CachingConnectionPool(0, 10, 30, dsn)
pool = psycopg_pool.SimpleConnectionPool(0, 10, dsn, idle_timeout=30)
conn1 = pool.getconn()
conn2 = pool.getconn()
pool.putconn(conn2)
self.assertEqual(len(pool._pool), 1) #1 in use, 1 put back
self.assertEqual(len(pool._expirations), 2) # We have two expirations for two connections
self.assertEqual(len(pool._used), 1) # and we have one used connection
# Both connections should be open at this point
@ -348,8 +274,16 @@ class PoolTests(ConnectingTestCase):
# self.assertEqual(len(pool._used), 0)
# self.assertEqual(len(pool._pool), 0)
# To maintain consistency with existing code, closeall doesn't mess with the _expirations dict either
# self.assertEqual(len(pool._expirations), 0)
# To maintain consistency with existing code, closeall doesn't mess with the _return_times dict either
# self.assertEqual(len(pool._return_times), 0)
#We should get an error if we try to put conn1 back now
self.assertRaises(psycopg2.pool.PoolError, pool.putconn, conn1)
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()