pep8 enforced for all files under lib, except for E501 (I cut a single line to < 120 characters), and E701 multiple statements on one line (colon)

This commit is contained in:
Mads Jensen 2015-12-26 21:48:05 +01:00
parent 452fd56e04
commit 6aff2af0d8
9 changed files with 131 additions and 102 deletions

View File

@ -82,13 +82,14 @@ else:
import re import re
def _param_escape(s,
re_escape=re.compile(r"([\\'])"), def _param_escape(s, re_escape=re.compile(r"([\\'])"),
re_space=re.compile(r'\s')): re_space=re.compile(r'\s')):
""" """
Apply the escaping rule required by PQconnectdb Apply the escaping rule required by PQconnectdb
""" """
if not s: return "''" if not s:
return "''"
s = re_escape.sub(r'\\\1', s) s = re_escape.sub(r'\\\1', s)
if re_space.search(s): if re_space.search(s):
@ -99,9 +100,8 @@ def _param_escape(s,
del re del re
def connect(dsn=None, def connect(dsn=None, database=None, user=None, password=None, host=None, port=None,
database=None, user=None, password=None, host=None, port=None, connection_factory=None, cursor_factory=None, async=False, **kwargs):
connection_factory=None, cursor_factory=None, async=False, **kwargs):
""" """
Create a new database connection. Create a new database connection.
@ -152,14 +152,14 @@ def connect(dsn=None,
if dsn is not None and items: if dsn is not None and items:
raise TypeError( raise TypeError(
"'%s' is an invalid keyword argument when the dsn is specified" "'%s' is an invalid keyword argument when the dsn is specified"
% items[0][0]) % items[0][0])
if dsn is None: if dsn is None:
if not items: if not items:
raise TypeError('missing dsn and no parameters') raise TypeError('missing dsn and no parameters')
else: else:
dsn = " ".join(["%s=%s" % (k, _param_escape(str(v))) dsn = " ".join(["%s=%s" % (k, _param_escape(str(v)))
for (k, v) in items]) for (k, v) in items])
conn = _connect(dsn, connection_factory=connection_factory, async=async) conn = _connect(dsn, connection_factory=connection_factory, async=async)
if cursor_factory is not None: if cursor_factory is not None:

View File

@ -34,7 +34,7 @@ from psycopg2._psycopg import new_type, new_array_type, register_type
# import the best json implementation available # import the best json implementation available
if sys.version_info[:2] >= (2,6): if sys.version_info[:2] >= (2, 6):
import json import json
else: else:
try: try:
@ -51,6 +51,7 @@ JSONARRAY_OID = 199
JSONB_OID = 3802 JSONB_OID = 3802
JSONBARRAY_OID = 3807 JSONBARRAY_OID = 3807
class Json(object): class Json(object):
""" """
An `~psycopg2.extensions.ISQLQuote` wrapper to adapt a Python object to An `~psycopg2.extensions.ISQLQuote` wrapper to adapt a Python object to
@ -106,7 +107,7 @@ class Json(object):
def register_json(conn_or_curs=None, globally=False, loads=None, def register_json(conn_or_curs=None, globally=False, loads=None,
oid=None, array_oid=None, name='json'): oid=None, array_oid=None, name='json'):
"""Create and register typecasters converting :sql:`json` type to Python objects. """Create and register typecasters converting :sql:`json` type to Python objects.
:param conn_or_curs: a connection or cursor used to find the :sql:`json` :param conn_or_curs: a connection or cursor used to find the :sql:`json`
@ -143,6 +144,7 @@ def register_json(conn_or_curs=None, globally=False, loads=None,
return JSON, JSONARRAY return JSON, JSONARRAY
def register_default_json(conn_or_curs=None, globally=False, loads=None): def register_default_json(conn_or_curs=None, globally=False, loads=None):
""" """
Create and register :sql:`json` typecasters for PostgreSQL 9.2 and following. Create and register :sql:`json` typecasters for PostgreSQL 9.2 and following.
@ -153,7 +155,8 @@ def register_default_json(conn_or_curs=None, globally=False, loads=None):
All the parameters have the same meaning of `register_json()`. All the parameters have the same meaning of `register_json()`.
""" """
return register_json(conn_or_curs=conn_or_curs, globally=globally, return register_json(conn_or_curs=conn_or_curs, globally=globally,
loads=loads, oid=JSON_OID, array_oid=JSONARRAY_OID) loads=loads, oid=JSON_OID, array_oid=JSONARRAY_OID)
def register_default_jsonb(conn_or_curs=None, globally=False, loads=None): def register_default_jsonb(conn_or_curs=None, globally=False, loads=None):
""" """
@ -165,7 +168,9 @@ def register_default_jsonb(conn_or_curs=None, globally=False, loads=None):
meaning of `register_json()`. meaning of `register_json()`.
""" """
return register_json(conn_or_curs=conn_or_curs, globally=globally, return register_json(conn_or_curs=conn_or_curs, globally=globally,
loads=loads, oid=JSONB_OID, array_oid=JSONBARRAY_OID, name='jsonb') loads=loads, oid=JSONB_OID, array_oid=JSONBARRAY_OID,
name='jsonb')
def _create_json_typecasters(oid, array_oid, loads=None, name='JSON'): def _create_json_typecasters(oid, array_oid, loads=None, name='JSON'):
"""Create typecasters for json data type.""" """Create typecasters for json data type."""
@ -188,6 +193,7 @@ def _create_json_typecasters(oid, array_oid, loads=None, name='JSON'):
return JSON, JSONARRAY return JSON, JSONARRAY
def _get_json_oids(conn_or_curs, name='json'): def _get_json_oids(conn_or_curs, name='json'):
# lazy imports # lazy imports
from psycopg2.extensions import STATUS_IN_TRANSACTION from psycopg2.extensions import STATUS_IN_TRANSACTION
@ -202,9 +208,8 @@ def _get_json_oids(conn_or_curs, name='json'):
typarray = conn.server_version >= 80300 and "typarray" or "NULL" typarray = conn.server_version >= 80300 and "typarray" or "NULL"
# get the oid for the hstore # get the oid for the hstore
curs.execute( curs.execute("SELECT t.oid, %s FROM pg_type t WHERE t.typname = %%s;"
"SELECT t.oid, %s FROM pg_type t WHERE t.typname = %%s;" % typarray, (name,))
% typarray, (name,))
r = curs.fetchone() r = curs.fetchone()
# revert the status of the connection as before the command # revert the status of the connection as before the command
@ -215,6 +220,3 @@ def _get_json_oids(conn_or_curs, name='json'):
raise conn.ProgrammingError("%s data type not found" % name) raise conn.ProgrammingError("%s data type not found" % name)
return r return r

View File

@ -30,6 +30,7 @@ from psycopg2._psycopg import ProgrammingError, InterfaceError
from psycopg2.extensions import ISQLQuote, adapt, register_adapter, b from psycopg2.extensions import ISQLQuote, adapt, register_adapter, b
from psycopg2.extensions import new_type, new_array_type, register_type from psycopg2.extensions import new_type, new_array_type, register_type
class Range(object): class Range(object):
"""Python representation for a PostgreSQL |range|_ type. """Python representation for a PostgreSQL |range|_ type.
@ -57,7 +58,8 @@ class Range(object):
if self._bounds is None: if self._bounds is None:
return "%s(empty=True)" % self.__class__.__name__ return "%s(empty=True)" % self.__class__.__name__
else: else:
return "%s(%r, %r, %r)" % (self.__class__.__name__, return "%s(%r, %r, %r)" % (
self.__class__.__name__,
self._lower, self._upper, self._bounds) self._lower, self._upper, self._bounds)
@property @property
@ -124,8 +126,8 @@ class Range(object):
if not isinstance(other, Range): if not isinstance(other, Range):
return False return False
return (self._lower == other._lower return (self._lower == other._lower
and self._upper == other._upper and self._upper == other._upper
and self._bounds == other._bounds) and self._bounds == other._bounds)
def __ne__(self, other): def __ne__(self, other):
return not self.__eq__(other) return not self.__eq__(other)
@ -248,7 +250,7 @@ class RangeAdapter(object):
upper = b('NULL') upper = b('NULL')
return b(self.name + '(') + lower + b(', ') + upper \ return b(self.name + '(') + lower + b(', ') + upper \
+ b(", '%s')" % r._bounds) + b(", '%s')" % r._bounds)
class RangeCaster(object): class RangeCaster(object):
@ -318,7 +320,7 @@ class RangeCaster(object):
if conn.server_version < 90200: if conn.server_version < 90200:
raise ProgrammingError("range types not available in version %s" raise ProgrammingError("range types not available in version %s"
% conn.server_version) % conn.server_version)
# Store the transaction status of the connection to revert it after use # Store the transaction status of the connection to revert it after use
conn_status = conn.status conn_status = conn.status
@ -349,8 +351,7 @@ where typname = %s and ns.nspname = %s;
rec = curs.fetchone() rec = curs.fetchone()
# revert the status of the connection as before the command # revert the status of the connection as before the command
if (conn_status != STATUS_IN_TRANSACTION if (conn_status != STATUS_IN_TRANSACTION and not conn.autocommit):
and not conn.autocommit):
conn.rollback() conn.rollback()
if not rec: if not rec:
@ -359,8 +360,8 @@ where typname = %s and ns.nspname = %s;
type, subtype, array = rec type, subtype, array = rec
return RangeCaster(name, pyrange, return RangeCaster(
oid=type, subtype_oid=subtype, array_oid=array) name, pyrange, oid=type, subtype_oid=subtype, array_oid=array)
_re_range = re.compile(r""" _re_range = re.compile(r"""
( \(|\[ ) # lower bound flag ( \(|\[ ) # lower bound flag
@ -425,14 +426,17 @@ class NumericRange(Range):
""" """
pass pass
class DateRange(Range): class DateRange(Range):
"""Represents :sql:`daterange` values.""" """Represents :sql:`daterange` values."""
pass pass
class DateTimeRange(Range): class DateTimeRange(Range):
"""Represents :sql:`tsrange` values.""" """Represents :sql:`tsrange` values."""
pass pass
class DateTimeTZRange(Range): class DateTimeTZRange(Range):
"""Represents :sql:`tstzrange` values.""" """Represents :sql:`tstzrange` values."""
pass pass
@ -475,27 +479,25 @@ register_adapter(NumericRange, NumberRangeAdapter)
# note: the adapter is registered more than once, but this is harmless. # note: the adapter is registered more than once, but this is harmless.
int4range_caster = RangeCaster(NumberRangeAdapter, NumericRange, int4range_caster = RangeCaster(NumberRangeAdapter, NumericRange,
oid=3904, subtype_oid=23, array_oid=3905) oid=3904, subtype_oid=23, array_oid=3905)
int4range_caster._register() int4range_caster._register()
int8range_caster = RangeCaster(NumberRangeAdapter, NumericRange, int8range_caster = RangeCaster(NumberRangeAdapter, NumericRange,
oid=3926, subtype_oid=20, array_oid=3927) oid=3926, subtype_oid=20, array_oid=3927)
int8range_caster._register() int8range_caster._register()
numrange_caster = RangeCaster(NumberRangeAdapter, NumericRange, numrange_caster = RangeCaster(NumberRangeAdapter, NumericRange,
oid=3906, subtype_oid=1700, array_oid=3907) oid=3906, subtype_oid=1700, array_oid=3907)
numrange_caster._register() numrange_caster._register()
daterange_caster = RangeCaster('daterange', DateRange, daterange_caster = RangeCaster('daterange', DateRange,
oid=3912, subtype_oid=1082, array_oid=3913) oid=3912, subtype_oid=1082, array_oid=3913)
daterange_caster._register() daterange_caster._register()
tsrange_caster = RangeCaster('tsrange', DateTimeRange, tsrange_caster = RangeCaster('tsrange', DateTimeRange,
oid=3908, subtype_oid=1114, array_oid=3909) oid=3908, subtype_oid=1114, array_oid=3909)
tsrange_caster._register() tsrange_caster._register()
tstzrange_caster = RangeCaster('tstzrange', DateTimeTZRange, tstzrange_caster = RangeCaster('tstzrange', DateTimeTZRange,
oid=3910, subtype_oid=1184, array_oid=3911) oid=3910, subtype_oid=1184, array_oid=3911)
tstzrange_caster._register() tstzrange_caster._register()

View File

@ -29,6 +29,7 @@ This module contains symbolic names for all PostgreSQL error codes.
# http://www.postgresql.org/docs/current/static/errcodes-appendix.html # http://www.postgresql.org/docs/current/static/errcodes-appendix.html
# #
def lookup(code, _cache={}): def lookup(code, _cache={}):
"""Lookup an error code or class code and return its symbolic name. """Lookup an error code or class code and return its symbolic name.

View File

@ -56,7 +56,9 @@ try:
except ImportError: except ImportError:
pass pass
from psycopg2._psycopg import adapt, adapters, encodings, connection, cursor, lobject, Xid, libpq_version, parse_dsn, quote_ident from psycopg2._psycopg import (
adapt, adapters, encodings, connection, cursor, lobject, Xid, libpq_version,
parse_dsn, quote_ident)
from psycopg2._psycopg import string_types, binary_types, new_type, new_array_type, register_type from psycopg2._psycopg import string_types, binary_types, new_type, new_array_type, register_type
from psycopg2._psycopg import ISQLQuote, Notify, Diagnostics, Column from psycopg2._psycopg import ISQLQuote, Notify, Diagnostics, Column
@ -68,32 +70,32 @@ except ImportError:
pass pass
"""Isolation level values.""" """Isolation level values."""
ISOLATION_LEVEL_AUTOCOMMIT = 0 ISOLATION_LEVEL_AUTOCOMMIT = 0
ISOLATION_LEVEL_READ_UNCOMMITTED = 4 ISOLATION_LEVEL_READ_UNCOMMITTED = 4
ISOLATION_LEVEL_READ_COMMITTED = 1 ISOLATION_LEVEL_READ_COMMITTED = 1
ISOLATION_LEVEL_REPEATABLE_READ = 2 ISOLATION_LEVEL_REPEATABLE_READ = 2
ISOLATION_LEVEL_SERIALIZABLE = 3 ISOLATION_LEVEL_SERIALIZABLE = 3
"""psycopg connection status values.""" """psycopg connection status values."""
STATUS_SETUP = 0 STATUS_SETUP = 0
STATUS_READY = 1 STATUS_READY = 1
STATUS_BEGIN = 2 STATUS_BEGIN = 2
STATUS_SYNC = 3 # currently unused STATUS_SYNC = 3 # currently unused
STATUS_ASYNC = 4 # currently unused STATUS_ASYNC = 4 # currently unused
STATUS_PREPARED = 5 STATUS_PREPARED = 5
# This is a useful mnemonic to check if the connection is in a transaction # This is a useful mnemonic to check if the connection is in a transaction
STATUS_IN_TRANSACTION = STATUS_BEGIN STATUS_IN_TRANSACTION = STATUS_BEGIN
"""psycopg asynchronous connection polling values""" """psycopg asynchronous connection polling values"""
POLL_OK = 0 POLL_OK = 0
POLL_READ = 1 POLL_READ = 1
POLL_WRITE = 2 POLL_WRITE = 2
POLL_ERROR = 3 POLL_ERROR = 3
"""Backend transaction status values.""" """Backend transaction status values."""
TRANSACTION_STATUS_IDLE = 0 TRANSACTION_STATUS_IDLE = 0
TRANSACTION_STATUS_ACTIVE = 1 TRANSACTION_STATUS_ACTIVE = 1
TRANSACTION_STATUS_INTRANS = 2 TRANSACTION_STATUS_INTRANS = 2
TRANSACTION_STATUS_INERROR = 3 TRANSACTION_STATUS_INERROR = 3
TRANSACTION_STATUS_UNKNOWN = 4 TRANSACTION_STATUS_UNKNOWN = 4
@ -108,6 +110,7 @@ else:
def b(s): def b(s):
return s.encode('utf8') return s.encode('utf8')
def register_adapter(typ, callable): def register_adapter(typ, callable):
"""Register 'callable' as an ISQLQuote adapter for type 'typ'.""" """Register 'callable' as an ISQLQuote adapter for type 'typ'."""
adapters[(typ, ISQLQuote)] = callable adapters[(typ, ISQLQuote)] = callable

View File

@ -106,6 +106,7 @@ class DictConnection(_connection):
kwargs.setdefault('cursor_factory', DictCursor) kwargs.setdefault('cursor_factory', DictCursor)
return super(DictConnection, self).cursor(*args, **kwargs) return super(DictConnection, self).cursor(*args, **kwargs)
class DictCursor(DictCursorBase): class DictCursor(DictCursorBase):
"""A cursor that keeps a list of column name -> index mappings.""" """A cursor that keeps a list of column name -> index mappings."""
@ -130,6 +131,7 @@ class DictCursor(DictCursorBase):
self.index[self.description[i][0]] = i self.index[self.description[i][0]] = i
self._query_executed = 0 self._query_executed = 0
class DictRow(list): class DictRow(list):
"""A row object that allow by-column-name access to data.""" """A row object that allow by-column-name access to data."""
@ -192,10 +194,10 @@ class DictRow(list):
# drop the crusty Py2 methods # drop the crusty Py2 methods
if _sys.version_info[0] > 2: if _sys.version_info[0] > 2:
items = iteritems; del iteritems items = iteritems
keys = iterkeys; del iterkeys keys = iterkeys
values = itervalues; del itervalues values = itervalues
del has_key del itervalues, has_key, iteritems, iterkeys
class RealDictConnection(_connection): class RealDictConnection(_connection):
@ -204,6 +206,7 @@ class RealDictConnection(_connection):
kwargs.setdefault('cursor_factory', RealDictCursor) kwargs.setdefault('cursor_factory', RealDictCursor)
return super(RealDictConnection, self).cursor(*args, **kwargs) return super(RealDictConnection, self).cursor(*args, **kwargs)
class RealDictCursor(DictCursorBase): class RealDictCursor(DictCursorBase):
"""A cursor that uses a real dict as the base type for rows. """A cursor that uses a real dict as the base type for rows.
@ -233,6 +236,7 @@ class RealDictCursor(DictCursorBase):
self.column_mapping.append(self.description[i][0]) self.column_mapping.append(self.description[i][0])
self._query_executed = 0 self._query_executed = 0
class RealDictRow(dict): class RealDictRow(dict):
"""A `!dict` subclass representing a data record.""" """A `!dict` subclass representing a data record."""
@ -265,6 +269,7 @@ class NamedTupleConnection(_connection):
kwargs.setdefault('cursor_factory', NamedTupleCursor) kwargs.setdefault('cursor_factory', NamedTupleCursor)
return super(NamedTupleConnection, self).cursor(*args, **kwargs) return super(NamedTupleConnection, self).cursor(*args, **kwargs)
class NamedTupleCursor(_cursor): class NamedTupleCursor(_cursor):
"""A cursor that generates results as `~collections.namedtuple`. """A cursor that generates results as `~collections.namedtuple`.
@ -369,11 +374,13 @@ class LoggingConnection(_connection):
def _logtofile(self, msg, curs): def _logtofile(self, msg, curs):
msg = self.filter(msg, curs) msg = self.filter(msg, curs)
if msg: self._logobj.write(msg + _os.linesep) if msg:
self._logobj.write(msg + _os.linesep)
def _logtologger(self, msg, curs): def _logtologger(self, msg, curs):
msg = self.filter(msg, curs) msg = self.filter(msg, curs)
if msg: self._logobj.debug(msg) if msg:
self._logobj.debug(msg)
def _check(self): def _check(self):
if not hasattr(self, '_logobj'): if not hasattr(self, '_logobj'):
@ -385,6 +392,7 @@ class LoggingConnection(_connection):
kwargs.setdefault('cursor_factory', LoggingCursor) kwargs.setdefault('cursor_factory', LoggingCursor)
return super(LoggingConnection, self).cursor(*args, **kwargs) return super(LoggingConnection, self).cursor(*args, **kwargs)
class LoggingCursor(_cursor): class LoggingCursor(_cursor):
"""A cursor that logs queries using its connection logging facilities.""" """A cursor that logs queries using its connection logging facilities."""
@ -425,6 +433,7 @@ class MinTimeLoggingConnection(LoggingConnection):
kwargs.setdefault('cursor_factory', MinTimeLoggingCursor) kwargs.setdefault('cursor_factory', MinTimeLoggingCursor)
return LoggingConnection.cursor(self, *args, **kwargs) return LoggingConnection.cursor(self, *args, **kwargs)
class MinTimeLoggingCursor(LoggingCursor): class MinTimeLoggingCursor(LoggingCursor):
"""The cursor sub-class companion to `MinTimeLoggingConnection`.""" """The cursor sub-class companion to `MinTimeLoggingConnection`."""
@ -459,6 +468,7 @@ class UUID_adapter(object):
def __str__(self): def __str__(self):
return "'%s'::uuid" % self._uuid return "'%s'::uuid" % self._uuid
def register_uuid(oids=None, conn_or_curs=None): def register_uuid(oids=None, conn_or_curs=None):
"""Create the UUID type and an uuid.UUID adapter. """Create the UUID type and an uuid.UUID adapter.
@ -480,8 +490,9 @@ def register_uuid(oids=None, conn_or_curs=None):
oid1 = oids oid1 = oids
oid2 = 2951 oid2 = 2951
_ext.UUID = _ext.new_type((oid1, ), "UUID", _ext.UUID = _ext.new_type(
lambda data, cursor: data and uuid.UUID(data) or None) (oid1, ), "UUID",
lambda data, cursor: data and uuid.UUID(data) or None)
_ext.UUIDARRAY = _ext.new_array_type((oid2,), "UUID[]", _ext.UUID) _ext.UUIDARRAY = _ext.new_array_type((oid2,), "UUID[]", _ext.UUID)
_ext.register_type(_ext.UUID, conn_or_curs) _ext.register_type(_ext.UUID, conn_or_curs)
@ -523,6 +534,7 @@ class Inet(object):
def __str__(self): def __str__(self):
return str(self.addr) return str(self.addr)
def register_inet(oid=None, conn_or_curs=None): def register_inet(oid=None, conn_or_curs=None):
"""Create the INET type and an Inet adapter. """Create the INET type and an Inet adapter.
@ -542,7 +554,7 @@ def register_inet(oid=None, conn_or_curs=None):
oid2 = 1041 oid2 = 1041
_ext.INET = _ext.new_type((oid1, ), "INET", _ext.INET = _ext.new_type((oid1, ), "INET",
lambda data, cursor: data and Inet(data) or None) lambda data, cursor: data and Inet(data) or None)
_ext.INETARRAY = _ext.new_array_type((oid2, ), "INETARRAY", _ext.INET) _ext.INETARRAY = _ext.new_array_type((oid2, ), "INETARRAY", _ext.INET)
_ext.register_type(_ext.INET, conn_or_curs) _ext.register_type(_ext.INET, conn_or_curs)
@ -736,14 +748,14 @@ WHERE typname = 'hstore';
rv1.append(oids[1]) rv1.append(oids[1])
# revert the status of the connection as before the command # revert the status of the connection as before the command
if (conn_status != _ext.STATUS_IN_TRANSACTION if (conn_status != _ext.STATUS_IN_TRANSACTION and not conn.autocommit):
and not conn.autocommit): conn.rollback()
conn.rollback()
return tuple(rv0), tuple(rv1) return tuple(rv0), tuple(rv1)
def register_hstore(conn_or_curs, globally=False, unicode=False, def register_hstore(conn_or_curs, globally=False, unicode=False,
oid=None, array_oid=None): oid=None, array_oid=None):
"""Register adapter and typecaster for `!dict`\-\ |hstore| conversions. """Register adapter and typecaster for `!dict`\-\ |hstore| conversions.
:param conn_or_curs: a connection or cursor: the typecaster will be :param conn_or_curs: a connection or cursor: the typecaster will be
@ -822,8 +834,8 @@ class CompositeCaster(object):
self.oid = oid self.oid = oid
self.array_oid = array_oid self.array_oid = array_oid
self.attnames = [ a[0] for a in attrs ] self.attnames = [a[0] for a in attrs]
self.atttypes = [ a[1] for a in attrs ] self.atttypes = [a[1] for a in attrs]
self._create_type(name, self.attnames) self._create_type(name, self.attnames)
self.typecaster = _ext.new_type((oid,), name, self.parse) self.typecaster = _ext.new_type((oid,), name, self.parse)
if array_oid: if array_oid:
@ -842,8 +854,8 @@ class CompositeCaster(object):
"expecting %d components for the type %s, %d found instead" % "expecting %d components for the type %s, %d found instead" %
(len(self.atttypes), self.name, len(tokens))) (len(self.atttypes), self.name, len(tokens)))
values = [ curs.cast(oid, token) values = [curs.cast(oid, token)
for oid, token in zip(self.atttypes, tokens) ] for oid, token in zip(self.atttypes, tokens)]
return self.make(values) return self.make(values)
@ -927,8 +939,7 @@ ORDER BY attnum;
recs = curs.fetchall() recs = curs.fetchall()
# revert the status of the connection as before the command # revert the status of the connection as before the command
if (conn_status != _ext.STATUS_IN_TRANSACTION if (conn_status != _ext.STATUS_IN_TRANSACTION and not conn.autocommit):
and not conn.autocommit):
conn.rollback() conn.rollback()
if not recs: if not recs:
@ -937,10 +948,11 @@ ORDER BY attnum;
type_oid = recs[0][0] type_oid = recs[0][0]
array_oid = recs[0][1] array_oid = recs[0][1]
type_attrs = [ (r[2], r[3]) for r in recs ] type_attrs = [(r[2], r[3]) for r in recs]
return self(tname, type_oid, type_attrs, return self(tname, type_oid, type_attrs,
array_oid=array_oid, schema=schema) array_oid=array_oid, schema=schema)
def register_composite(name, conn_or_curs, globally=False, factory=None): def register_composite(name, conn_or_curs, globally=False, factory=None):
"""Register a typecaster to convert a composite type into a tuple. """Register a typecaster to convert a composite type into a tuple.

View File

@ -51,7 +51,7 @@ class AbstractConnectionPool(object):
self._pool = [] self._pool = []
self._used = {} self._used = {}
self._rused = {} # id(conn) -> key map self._rused = {} # id(conn) -> key map
self._keys = 0 self._keys = 0
for i in range(self.minconn): for i in range(self.minconn):
@ -74,8 +74,10 @@ class AbstractConnectionPool(object):
def _getconn(self, key=None): def _getconn(self, key=None):
"""Get a free connection and assign it to 'key' if not None.""" """Get a free connection and assign it to 'key' if not None."""
if self.closed: raise PoolError("connection pool is closed") if self.closed:
if key is None: key = self._getkey() raise PoolError("connection pool is closed")
if key is None:
key = self._getkey()
if key in self._used: if key in self._used:
return self._used[key] return self._used[key]
@ -91,8 +93,10 @@ class AbstractConnectionPool(object):
def _putconn(self, conn, key=None, close=False): def _putconn(self, conn, key=None, close=False):
"""Put away a connection.""" """Put away a connection."""
if self.closed: raise PoolError("connection pool is closed") if self.closed:
if key is None: key = self._rused.get(id(conn)) raise PoolError("connection pool is closed")
if key is None:
key = self._rused.get(id(conn))
if not key: if not key:
raise PoolError("trying to put unkeyed connection") raise PoolError("trying to put unkeyed connection")
@ -129,7 +133,8 @@ class AbstractConnectionPool(object):
an already closed connection. If you call .closeall() make sure an already closed connection. If you call .closeall() make sure
your code can deal with it. your code can deal with it.
""" """
if self.closed: raise PoolError("connection pool is closed") if self.closed:
raise PoolError("connection pool is closed")
for conn in self._pool + list(self._used.values()): for conn in self._pool + list(self._used.values()):
try: try:
conn.close() conn.close()
@ -143,7 +148,7 @@ class SimpleConnectionPool(AbstractConnectionPool):
getconn = AbstractConnectionPool._getconn getconn = AbstractConnectionPool._getconn
putconn = AbstractConnectionPool._putconn putconn = AbstractConnectionPool._putconn
closeall = AbstractConnectionPool._closeall closeall = AbstractConnectionPool._closeall
class ThreadedConnectionPool(AbstractConnectionPool): class ThreadedConnectionPool(AbstractConnectionPool):
@ -195,7 +200,7 @@ class PersistentConnectionPool(AbstractConnectionPool):
"""Initialize the threading lock.""" """Initialize the threading lock."""
import warnings import warnings
warnings.warn("deprecated: use ZPsycopgDA.pool implementation", warnings.warn("deprecated: use ZPsycopgDA.pool implementation",
DeprecationWarning) DeprecationWarning)
import threading import threading
AbstractConnectionPool.__init__( AbstractConnectionPool.__init__(
@ -204,7 +209,7 @@ class PersistentConnectionPool(AbstractConnectionPool):
# we we'll need the thread module, to determine thread ids, so we # we we'll need the thread module, to determine thread ids, so we
# import it here and copy it in an instance variable # import it here and copy it in an instance variable
import thread as _thread # work around for 2to3 bug - see ticket #348 import thread as _thread # work around for 2to3 bug - see ticket #348
self.__thread = _thread self.__thread = _thread
def getconn(self): def getconn(self):
@ -221,7 +226,8 @@ class PersistentConnectionPool(AbstractConnectionPool):
key = self.__thread.get_ident() key = self.__thread.get_ident()
self._lock.acquire() self._lock.acquire()
try: try:
if not conn: conn = self._used[key] if not conn:
conn = self._used[key]
self._putconn(conn, key, close) self._putconn(conn, key, close)
finally: finally:
self._lock.release() self._lock.release()

View File

@ -36,6 +36,7 @@ from psycopg2 import *
import psycopg2.extensions as _ext import psycopg2.extensions as _ext
_2connect = connect _2connect = connect
def connect(*args, **kwargs): def connect(*args, **kwargs):
"""connect(dsn, ...) -> new psycopg 1.1.x compatible connection object""" """connect(dsn, ...) -> new psycopg 1.1.x compatible connection object"""
kwargs['connection_factory'] = connection kwargs['connection_factory'] = connection
@ -43,6 +44,7 @@ def connect(*args, **kwargs):
conn.set_isolation_level(_ext.ISOLATION_LEVEL_READ_COMMITTED) conn.set_isolation_level(_ext.ISOLATION_LEVEL_READ_COMMITTED)
return conn return conn
class connection(_2connection): class connection(_2connection):
"""psycopg 1.1.x connection.""" """psycopg 1.1.x connection."""
@ -92,4 +94,3 @@ class cursor(_2cursor):
for row in rows: for row in rows:
res.append(self.__build_dict(row)) res.append(self.__build_dict(row))
return res return res

View File

@ -31,6 +31,7 @@ import time
ZERO = datetime.timedelta(0) ZERO = datetime.timedelta(0)
class FixedOffsetTimezone(datetime.tzinfo): class FixedOffsetTimezone(datetime.tzinfo):
"""Fixed offset in minutes east from UTC. """Fixed offset in minutes east from UTC.
@ -52,7 +53,7 @@ class FixedOffsetTimezone(datetime.tzinfo):
def __init__(self, offset=None, name=None): def __init__(self, offset=None, name=None):
if offset is not None: if offset is not None:
self._offset = datetime.timedelta(minutes = offset) self._offset = datetime.timedelta(minutes=offset)
if name is not None: if name is not None:
self._name = name self._name = name
@ -85,7 +86,7 @@ class FixedOffsetTimezone(datetime.tzinfo):
else: else:
seconds = self._offset.seconds + self._offset.days * 86400 seconds = self._offset.seconds + self._offset.days * 86400
hours, seconds = divmod(seconds, 3600) hours, seconds = divmod(seconds, 3600)
minutes = seconds/60 minutes = seconds / 60
if minutes: if minutes:
return "%+03d:%d" % (hours, minutes) return "%+03d:%d" % (hours, minutes)
else: else:
@ -95,13 +96,14 @@ class FixedOffsetTimezone(datetime.tzinfo):
return ZERO return ZERO
STDOFFSET = datetime.timedelta(seconds = -time.timezone) STDOFFSET = datetime.timedelta(seconds=-time.timezone)
if time.daylight: if time.daylight:
DSTOFFSET = datetime.timedelta(seconds = -time.altzone) DSTOFFSET = datetime.timedelta(seconds=-time.altzone)
else: else:
DSTOFFSET = STDOFFSET DSTOFFSET = STDOFFSET
DSTDIFF = DSTOFFSET - STDOFFSET DSTDIFF = DSTOFFSET - STDOFFSET
class LocalTimezone(datetime.tzinfo): class LocalTimezone(datetime.tzinfo):
"""Platform idea of local timezone. """Platform idea of local timezone.