This commit is contained in:
Christian Zagrodnick 2013-03-19 07:52:24 -07:00
commit fde232831b
4 changed files with 82 additions and 34 deletions

5
NEWS
View File

@ -1,3 +1,8 @@
What's new in psycopg XXXXX
---------------------------
- Fixed multi-thread connection initialization for ZPsycopgDA.
What's new in psycopg 2.4.6 What's new in psycopg 2.4.6
--------------------------- ---------------------------

View File

@ -47,11 +47,13 @@ class DB(TM, dbi_db.DB):
self.calls = 0 self.calls = 0
self.make_mappings() self.make_mappings()
def getconn(self, init=True): def getconn(self):
# if init is False we are trying to get hold on an already existing # if init is False we are trying to get hold on an already existing
# connection, so we avoid to (re)initialize it risking errors. # connection, so we avoid to (re)initialize it risking errors.
conn = pool.getconn(self.dsn) conn = pool.getconn(self.dsn, init=self.init_conn)
if init: return conn
def init_conn(self, conn):
# use set_session where available as in these versions # use set_session where available as in these versions
# set_isolation_level generates an extra query. # set_isolation_level generates an extra query.
if psycopg2.__version__ >= '2.4.2': if psycopg2.__version__ >= '2.4.2':
@ -61,22 +63,18 @@ class DB(TM, dbi_db.DB):
conn.set_client_encoding(self.encoding) conn.set_client_encoding(self.encoding)
for tc in self.typecasts: for tc in self.typecasts:
register_type(tc, conn) register_type(tc, conn)
return conn
def putconn(self, close=False): def putconn(self, close=False):
try: conn = pool.getconn(self.dsn, create_pool=False, init=self.init_conn)
conn = pool.getconn(self.dsn, False)
except AttributeError:
pass
pool.putconn(self.dsn, conn, close) pool.putconn(self.dsn, conn, close)
def getcursor(self): def getcursor(self):
conn = self.getconn(False) conn = self.getconn()
return conn.cursor() return conn.cursor()
def _finish(self, *ignored): def _finish(self, *ignored):
try: try:
conn = self.getconn(False) conn = self.getconn()
conn.commit() conn.commit()
self.putconn() self.putconn()
except AttributeError: except AttributeError:
@ -84,7 +82,7 @@ class DB(TM, dbi_db.DB):
def _abort(self, *ignored): def _abort(self, *ignored):
try: try:
conn = self.getconn(False) conn = self.getconn()
conn.rollback() conn.rollback()
self.putconn() self.putconn()
except AttributeError: except AttributeError:

View File

@ -26,7 +26,7 @@ from psycopg2.pool import PoolError
class AbstractConnectionPool(object): class AbstractConnectionPool(object):
"""Generic key-based pooling code.""" """Generic key-based pooling code."""
def __init__(self, minconn, maxconn, *args, **kwargs): def __init__(self, minconn, maxconn, init, *args, **kwargs):
"""Initialize the connection pool. """Initialize the connection pool.
New 'minconn' connections are created immediately calling 'connfunc' New 'minconn' connections are created immediately calling 'connfunc'
@ -35,6 +35,7 @@ class AbstractConnectionPool(object):
""" """
self.minconn = minconn self.minconn = minconn
self.maxconn = maxconn self.maxconn = maxconn
self.init = init
self.closed = False self.closed = False
self._args = args self._args = args
@ -56,6 +57,8 @@ class AbstractConnectionPool(object):
self._rused[id(conn)] = key self._rused[id(conn)] = key
else: else:
self._pool.append(conn) self._pool.append(conn)
if self.init:
self.init(conn)
return conn return conn
def _getkey(self): def _getkey(self):
@ -125,11 +128,11 @@ class PersistentConnectionPool(AbstractConnectionPool):
single connection from the pool. single connection from the pool.
""" """
def __init__(self, minconn, maxconn, *args, **kwargs): def __init__(self, minconn, maxconn, init, *args, **kwargs):
"""Initialize the threading lock.""" """Initialize the threading lock."""
import threading import threading
AbstractConnectionPool.__init__( AbstractConnectionPool.__init__(
self, minconn, maxconn, *args, **kwargs) self, minconn, maxconn, init, *args, **kwargs)
self._lock = threading.Lock() self._lock = threading.Lock()
# we we'll need the thread module, to determine thread ids, so we # we we'll need the thread module, to determine thread ids, so we
@ -168,12 +171,12 @@ class PersistentConnectionPool(AbstractConnectionPool):
_connections_pool = {} _connections_pool = {}
_connections_lock = threading.Lock() _connections_lock = threading.Lock()
def getpool(dsn, create=True): def getpool(dsn, create=True, init=None):
_connections_lock.acquire() _connections_lock.acquire()
try: try:
if not _connections_pool.has_key(dsn) and create: if not _connections_pool.has_key(dsn) and create:
_connections_pool[dsn] = \ _connections_pool[dsn] = \
PersistentConnectionPool(4, 200, dsn) PersistentConnectionPool(4, 200, init, dsn)
finally: finally:
_connections_lock.release() _connections_lock.release()
return _connections_pool[dsn] return _connections_pool[dsn]
@ -186,8 +189,8 @@ def flushpool(dsn):
finally: finally:
_connections_lock.release() _connections_lock.release()
def getconn(dsn, create=True): def getconn(dsn, create_pool=True, init=None):
return getpool(dsn, create=create).getconn() return getpool(dsn, create=create_pool, init=init).getconn()
def putconn(dsn, conn, close=False): def putconn(dsn, conn, close=False):
getpool(dsn).putconn(conn, close=close) getpool(dsn).putconn(conn, close=close)

42
ZPsycopgDA/test_da.py Normal file
View File

@ -0,0 +1,42 @@
# zopectl run script to test the DA/threading behavior
#
# Usage: bin/zopectl run test_da.py "dbname=xxx"
#
from Products.ZPsycopgDA.DA import ZDATETIME
from Products.ZPsycopgDA.db import DB
import sys
import threading
dsn = sys.argv[1]
typecasts = [ZDATETIME]
def DA_connect():
db = DB(dsn, tilevel=2, typecasts=typecasts)
db.open()
return db
def assert_casts(conn, name):
connection = conn.getcursor().connection
if (connection.string_types ==
{1114: ZDATETIME, 1184: ZDATETIME}):
print '%s pass\n' % name
else:
print '%s fail (%s)\n' % (name, connection.string_types)
def test_connect(name):
assert_casts(conn1, name)
conn1 = DA_connect()
t1 = threading.Thread(target=test_connect, args=('t1',))
t1.start()
t2 = threading.Thread(target=test_connect, args=('t2',))
t2.start()
t1.join()
t2.join()