This commit is contained in:
Christian Zagrodnick 2013-03-19 07:52:24 -07:00
commit fde232831b
4 changed files with 82 additions and 34 deletions

5
NEWS
View File

@ -1,3 +1,8 @@
What's new in psycopg XXXXX
---------------------------
- Fixed multi-thread connection initialization for ZPsycopgDA.
What's new in psycopg 2.4.6
---------------------------

View File

@ -47,11 +47,13 @@ class DB(TM, dbi_db.DB):
self.calls = 0
self.make_mappings()
def getconn(self, init=True):
def getconn(self):
# if init is False we are trying to get hold on an already existing
# connection, so we avoid to (re)initialize it risking errors.
conn = pool.getconn(self.dsn)
if init:
conn = pool.getconn(self.dsn, init=self.init_conn)
return conn
def init_conn(self, conn):
# use set_session where available as in these versions
# set_isolation_level generates an extra query.
if psycopg2.__version__ >= '2.4.2':
@ -61,22 +63,18 @@ class DB(TM, dbi_db.DB):
conn.set_client_encoding(self.encoding)
for tc in self.typecasts:
register_type(tc, conn)
return conn
def putconn(self, close=False):
try:
conn = pool.getconn(self.dsn, False)
except AttributeError:
pass
conn = pool.getconn(self.dsn, create_pool=False, init=self.init_conn)
pool.putconn(self.dsn, conn, close)
def getcursor(self):
conn = self.getconn(False)
conn = self.getconn()
return conn.cursor()
def _finish(self, *ignored):
try:
conn = self.getconn(False)
conn = self.getconn()
conn.commit()
self.putconn()
except AttributeError:
@ -84,7 +82,7 @@ class DB(TM, dbi_db.DB):
def _abort(self, *ignored):
try:
conn = self.getconn(False)
conn = self.getconn()
conn.rollback()
self.putconn()
except AttributeError:

View File

@ -26,7 +26,7 @@ from psycopg2.pool import PoolError
class AbstractConnectionPool(object):
"""Generic key-based pooling code."""
def __init__(self, minconn, maxconn, *args, **kwargs):
def __init__(self, minconn, maxconn, init, *args, **kwargs):
"""Initialize the connection pool.
New 'minconn' connections are created immediately calling 'connfunc'
@ -35,6 +35,7 @@ class AbstractConnectionPool(object):
"""
self.minconn = minconn
self.maxconn = maxconn
self.init = init
self.closed = False
self._args = args
@ -56,6 +57,8 @@ class AbstractConnectionPool(object):
self._rused[id(conn)] = key
else:
self._pool.append(conn)
if self.init:
self.init(conn)
return conn
def _getkey(self):
@ -125,11 +128,11 @@ class PersistentConnectionPool(AbstractConnectionPool):
single connection from the pool.
"""
def __init__(self, minconn, maxconn, *args, **kwargs):
def __init__(self, minconn, maxconn, init, *args, **kwargs):
"""Initialize the threading lock."""
import threading
AbstractConnectionPool.__init__(
self, minconn, maxconn, *args, **kwargs)
self, minconn, maxconn, init, *args, **kwargs)
self._lock = threading.Lock()
# we we'll need the thread module, to determine thread ids, so we
@ -168,12 +171,12 @@ class PersistentConnectionPool(AbstractConnectionPool):
_connections_pool = {}
_connections_lock = threading.Lock()
def getpool(dsn, create=True):
def getpool(dsn, create=True, init=None):
_connections_lock.acquire()
try:
if not _connections_pool.has_key(dsn) and create:
_connections_pool[dsn] = \
PersistentConnectionPool(4, 200, dsn)
PersistentConnectionPool(4, 200, init, dsn)
finally:
_connections_lock.release()
return _connections_pool[dsn]
@ -186,8 +189,8 @@ def flushpool(dsn):
finally:
_connections_lock.release()
def getconn(dsn, create=True):
return getpool(dsn, create=create).getconn()
def getconn(dsn, create_pool=True, init=None):
return getpool(dsn, create=create_pool, init=init).getconn()
def putconn(dsn, conn, close=False):
getpool(dsn).putconn(conn, close=close)

42
ZPsycopgDA/test_da.py Normal file
View File

@ -0,0 +1,42 @@
# zopectl run script to test the DA/threading behavior
#
# Usage: bin/zopectl run test_da.py "dbname=xxx"
#
from Products.ZPsycopgDA.DA import ZDATETIME
from Products.ZPsycopgDA.db import DB
import sys
import threading
dsn = sys.argv[1]
typecasts = [ZDATETIME]
def DA_connect():
db = DB(dsn, tilevel=2, typecasts=typecasts)
db.open()
return db
def assert_casts(conn, name):
connection = conn.getcursor().connection
if (connection.string_types ==
{1114: ZDATETIME, 1184: ZDATETIME}):
print '%s pass\n' % name
else:
print '%s fail (%s)\n' % (name, connection.string_types)
def test_connect(name):
assert_casts(conn1, name)
conn1 = DA_connect()
t1 = threading.Thread(target=test_connect, args=('t1',))
t1.start()
t2 = threading.Thread(target=test_connect, args=('t2',))
t2.start()
t1.join()
t2.join()