This commit is contained in:
Mads Jensen 2015-12-26 20:49:48 +00:00
commit e2f37eee26
9 changed files with 131 additions and 102 deletions

View File

@ -82,13 +82,14 @@ else:
import re import re
def _param_escape(s,
re_escape=re.compile(r"([\\'])"), def _param_escape(s, re_escape=re.compile(r"([\\'])"),
re_space=re.compile(r'\s')): re_space=re.compile(r'\s')):
""" """
Apply the escaping rule required by PQconnectdb Apply the escaping rule required by PQconnectdb
""" """
if not s: return "''" if not s:
return "''"
s = re_escape.sub(r'\\\1', s) s = re_escape.sub(r'\\\1', s)
if re_space.search(s): if re_space.search(s):
@ -99,8 +100,7 @@ def _param_escape(s,
del re del re
def connect(dsn=None, def connect(dsn=None, database=None, user=None, password=None, host=None, port=None,
database=None, user=None, password=None, host=None, port=None,
connection_factory=None, cursor_factory=None, async=False, **kwargs): connection_factory=None, cursor_factory=None, async=False, **kwargs):
""" """
Create a new database connection. Create a new database connection.

View File

@ -34,7 +34,7 @@ from psycopg2._psycopg import new_type, new_array_type, register_type
# import the best json implementation available # import the best json implementation available
if sys.version_info[:2] >= (2,6): if sys.version_info[:2] >= (2, 6):
import json import json
else: else:
try: try:
@ -51,6 +51,7 @@ JSONARRAY_OID = 199
JSONB_OID = 3802 JSONB_OID = 3802
JSONBARRAY_OID = 3807 JSONBARRAY_OID = 3807
class Json(object): class Json(object):
""" """
An `~psycopg2.extensions.ISQLQuote` wrapper to adapt a Python object to An `~psycopg2.extensions.ISQLQuote` wrapper to adapt a Python object to
@ -143,6 +144,7 @@ def register_json(conn_or_curs=None, globally=False, loads=None,
return JSON, JSONARRAY return JSON, JSONARRAY
def register_default_json(conn_or_curs=None, globally=False, loads=None): def register_default_json(conn_or_curs=None, globally=False, loads=None):
""" """
Create and register :sql:`json` typecasters for PostgreSQL 9.2 and following. Create and register :sql:`json` typecasters for PostgreSQL 9.2 and following.
@ -155,6 +157,7 @@ def register_default_json(conn_or_curs=None, globally=False, loads=None):
return register_json(conn_or_curs=conn_or_curs, globally=globally, return register_json(conn_or_curs=conn_or_curs, globally=globally,
loads=loads, oid=JSON_OID, array_oid=JSONARRAY_OID) loads=loads, oid=JSON_OID, array_oid=JSONARRAY_OID)
def register_default_jsonb(conn_or_curs=None, globally=False, loads=None): def register_default_jsonb(conn_or_curs=None, globally=False, loads=None):
""" """
Create and register :sql:`jsonb` typecasters for PostgreSQL 9.4 and following. Create and register :sql:`jsonb` typecasters for PostgreSQL 9.4 and following.
@ -165,7 +168,9 @@ def register_default_jsonb(conn_or_curs=None, globally=False, loads=None):
meaning of `register_json()`. meaning of `register_json()`.
""" """
return register_json(conn_or_curs=conn_or_curs, globally=globally, return register_json(conn_or_curs=conn_or_curs, globally=globally,
loads=loads, oid=JSONB_OID, array_oid=JSONBARRAY_OID, name='jsonb') loads=loads, oid=JSONB_OID, array_oid=JSONBARRAY_OID,
name='jsonb')
def _create_json_typecasters(oid, array_oid, loads=None, name='JSON'): def _create_json_typecasters(oid, array_oid, loads=None, name='JSON'):
"""Create typecasters for json data type.""" """Create typecasters for json data type."""
@ -188,6 +193,7 @@ def _create_json_typecasters(oid, array_oid, loads=None, name='JSON'):
return JSON, JSONARRAY return JSON, JSONARRAY
def _get_json_oids(conn_or_curs, name='json'): def _get_json_oids(conn_or_curs, name='json'):
# lazy imports # lazy imports
from psycopg2.extensions import STATUS_IN_TRANSACTION from psycopg2.extensions import STATUS_IN_TRANSACTION
@ -202,8 +208,7 @@ def _get_json_oids(conn_or_curs, name='json'):
typarray = conn.server_version >= 80300 and "typarray" or "NULL" typarray = conn.server_version >= 80300 and "typarray" or "NULL"
# get the oid for the hstore # get the oid for the hstore
curs.execute( curs.execute("SELECT t.oid, %s FROM pg_type t WHERE t.typname = %%s;"
"SELECT t.oid, %s FROM pg_type t WHERE t.typname = %%s;"
% typarray, (name,)) % typarray, (name,))
r = curs.fetchone() r = curs.fetchone()
@ -215,6 +220,3 @@ def _get_json_oids(conn_or_curs, name='json'):
raise conn.ProgrammingError("%s data type not found" % name) raise conn.ProgrammingError("%s data type not found" % name)
return r return r

View File

@ -30,6 +30,7 @@ from psycopg2._psycopg import ProgrammingError, InterfaceError
from psycopg2.extensions import ISQLQuote, adapt, register_adapter, b from psycopg2.extensions import ISQLQuote, adapt, register_adapter, b
from psycopg2.extensions import new_type, new_array_type, register_type from psycopg2.extensions import new_type, new_array_type, register_type
class Range(object): class Range(object):
"""Python representation for a PostgreSQL |range|_ type. """Python representation for a PostgreSQL |range|_ type.
@ -57,7 +58,8 @@ class Range(object):
if self._bounds is None: if self._bounds is None:
return "%s(empty=True)" % self.__class__.__name__ return "%s(empty=True)" % self.__class__.__name__
else: else:
return "%s(%r, %r, %r)" % (self.__class__.__name__, return "%s(%r, %r, %r)" % (
self.__class__.__name__,
self._lower, self._upper, self._bounds) self._lower, self._upper, self._bounds)
@property @property
@ -349,8 +351,7 @@ where typname = %s and ns.nspname = %s;
rec = curs.fetchone() rec = curs.fetchone()
# revert the status of the connection as before the command # revert the status of the connection as before the command
if (conn_status != STATUS_IN_TRANSACTION if (conn_status != STATUS_IN_TRANSACTION and not conn.autocommit):
and not conn.autocommit):
conn.rollback() conn.rollback()
if not rec: if not rec:
@ -359,8 +360,8 @@ where typname = %s and ns.nspname = %s;
type, subtype, array = rec type, subtype, array = rec
return RangeCaster(name, pyrange, return RangeCaster(
oid=type, subtype_oid=subtype, array_oid=array) name, pyrange, oid=type, subtype_oid=subtype, array_oid=array)
_re_range = re.compile(r""" _re_range = re.compile(r"""
( \(|\[ ) # lower bound flag ( \(|\[ ) # lower bound flag
@ -425,14 +426,17 @@ class NumericRange(Range):
""" """
pass pass
class DateRange(Range): class DateRange(Range):
"""Represents :sql:`daterange` values.""" """Represents :sql:`daterange` values."""
pass pass
class DateTimeRange(Range): class DateTimeRange(Range):
"""Represents :sql:`tsrange` values.""" """Represents :sql:`tsrange` values."""
pass pass
class DateTimeTZRange(Range): class DateTimeTZRange(Range):
"""Represents :sql:`tstzrange` values.""" """Represents :sql:`tstzrange` values."""
pass pass
@ -497,5 +501,3 @@ tsrange_caster._register()
tstzrange_caster = RangeCaster('tstzrange', DateTimeTZRange, tstzrange_caster = RangeCaster('tstzrange', DateTimeTZRange,
oid=3910, subtype_oid=1184, array_oid=3911) oid=3910, subtype_oid=1184, array_oid=3911)
tstzrange_caster._register() tstzrange_caster._register()

View File

@ -29,6 +29,7 @@ This module contains symbolic names for all PostgreSQL error codes.
# http://www.postgresql.org/docs/current/static/errcodes-appendix.html # http://www.postgresql.org/docs/current/static/errcodes-appendix.html
# #
def lookup(code, _cache={}): def lookup(code, _cache={}):
"""Lookup an error code or class code and return its symbolic name. """Lookup an error code or class code and return its symbolic name.

View File

@ -56,7 +56,9 @@ try:
except ImportError: except ImportError:
pass pass
from psycopg2._psycopg import adapt, adapters, encodings, connection, cursor, lobject, Xid, libpq_version, parse_dsn, quote_ident from psycopg2._psycopg import (
adapt, adapters, encodings, connection, cursor, lobject, Xid, libpq_version,
parse_dsn, quote_ident)
from psycopg2._psycopg import string_types, binary_types, new_type, new_array_type, register_type from psycopg2._psycopg import string_types, binary_types, new_type, new_array_type, register_type
from psycopg2._psycopg import ISQLQuote, Notify, Diagnostics, Column from psycopg2._psycopg import ISQLQuote, Notify, Diagnostics, Column
@ -108,6 +110,7 @@ else:
def b(s): def b(s):
return s.encode('utf8') return s.encode('utf8')
def register_adapter(typ, callable): def register_adapter(typ, callable):
"""Register 'callable' as an ISQLQuote adapter for type 'typ'.""" """Register 'callable' as an ISQLQuote adapter for type 'typ'."""
adapters[(typ, ISQLQuote)] = callable adapters[(typ, ISQLQuote)] = callable

View File

@ -106,6 +106,7 @@ class DictConnection(_connection):
kwargs.setdefault('cursor_factory', DictCursor) kwargs.setdefault('cursor_factory', DictCursor)
return super(DictConnection, self).cursor(*args, **kwargs) return super(DictConnection, self).cursor(*args, **kwargs)
class DictCursor(DictCursorBase): class DictCursor(DictCursorBase):
"""A cursor that keeps a list of column name -> index mappings.""" """A cursor that keeps a list of column name -> index mappings."""
@ -130,6 +131,7 @@ class DictCursor(DictCursorBase):
self.index[self.description[i][0]] = i self.index[self.description[i][0]] = i
self._query_executed = 0 self._query_executed = 0
class DictRow(list): class DictRow(list):
"""A row object that allow by-column-name access to data.""" """A row object that allow by-column-name access to data."""
@ -192,10 +194,10 @@ class DictRow(list):
# drop the crusty Py2 methods # drop the crusty Py2 methods
if _sys.version_info[0] > 2: if _sys.version_info[0] > 2:
items = iteritems; del iteritems items = iteritems
keys = iterkeys; del iterkeys keys = iterkeys
values = itervalues; del itervalues values = itervalues
del has_key del itervalues, has_key, iteritems, iterkeys
class RealDictConnection(_connection): class RealDictConnection(_connection):
@ -204,6 +206,7 @@ class RealDictConnection(_connection):
kwargs.setdefault('cursor_factory', RealDictCursor) kwargs.setdefault('cursor_factory', RealDictCursor)
return super(RealDictConnection, self).cursor(*args, **kwargs) return super(RealDictConnection, self).cursor(*args, **kwargs)
class RealDictCursor(DictCursorBase): class RealDictCursor(DictCursorBase):
"""A cursor that uses a real dict as the base type for rows. """A cursor that uses a real dict as the base type for rows.
@ -233,6 +236,7 @@ class RealDictCursor(DictCursorBase):
self.column_mapping.append(self.description[i][0]) self.column_mapping.append(self.description[i][0])
self._query_executed = 0 self._query_executed = 0
class RealDictRow(dict): class RealDictRow(dict):
"""A `!dict` subclass representing a data record.""" """A `!dict` subclass representing a data record."""
@ -265,6 +269,7 @@ class NamedTupleConnection(_connection):
kwargs.setdefault('cursor_factory', NamedTupleCursor) kwargs.setdefault('cursor_factory', NamedTupleCursor)
return super(NamedTupleConnection, self).cursor(*args, **kwargs) return super(NamedTupleConnection, self).cursor(*args, **kwargs)
class NamedTupleCursor(_cursor): class NamedTupleCursor(_cursor):
"""A cursor that generates results as `~collections.namedtuple`. """A cursor that generates results as `~collections.namedtuple`.
@ -369,11 +374,13 @@ class LoggingConnection(_connection):
def _logtofile(self, msg, curs): def _logtofile(self, msg, curs):
msg = self.filter(msg, curs) msg = self.filter(msg, curs)
if msg: self._logobj.write(msg + _os.linesep) if msg:
self._logobj.write(msg + _os.linesep)
def _logtologger(self, msg, curs): def _logtologger(self, msg, curs):
msg = self.filter(msg, curs) msg = self.filter(msg, curs)
if msg: self._logobj.debug(msg) if msg:
self._logobj.debug(msg)
def _check(self): def _check(self):
if not hasattr(self, '_logobj'): if not hasattr(self, '_logobj'):
@ -385,6 +392,7 @@ class LoggingConnection(_connection):
kwargs.setdefault('cursor_factory', LoggingCursor) kwargs.setdefault('cursor_factory', LoggingCursor)
return super(LoggingConnection, self).cursor(*args, **kwargs) return super(LoggingConnection, self).cursor(*args, **kwargs)
class LoggingCursor(_cursor): class LoggingCursor(_cursor):
"""A cursor that logs queries using its connection logging facilities.""" """A cursor that logs queries using its connection logging facilities."""
@ -425,6 +433,7 @@ class MinTimeLoggingConnection(LoggingConnection):
kwargs.setdefault('cursor_factory', MinTimeLoggingCursor) kwargs.setdefault('cursor_factory', MinTimeLoggingCursor)
return LoggingConnection.cursor(self, *args, **kwargs) return LoggingConnection.cursor(self, *args, **kwargs)
class MinTimeLoggingCursor(LoggingCursor): class MinTimeLoggingCursor(LoggingCursor):
"""The cursor sub-class companion to `MinTimeLoggingConnection`.""" """The cursor sub-class companion to `MinTimeLoggingConnection`."""
@ -459,6 +468,7 @@ class UUID_adapter(object):
def __str__(self): def __str__(self):
return "'%s'::uuid" % self._uuid return "'%s'::uuid" % self._uuid
def register_uuid(oids=None, conn_or_curs=None): def register_uuid(oids=None, conn_or_curs=None):
"""Create the UUID type and an uuid.UUID adapter. """Create the UUID type and an uuid.UUID adapter.
@ -480,7 +490,8 @@ def register_uuid(oids=None, conn_or_curs=None):
oid1 = oids oid1 = oids
oid2 = 2951 oid2 = 2951
_ext.UUID = _ext.new_type((oid1, ), "UUID", _ext.UUID = _ext.new_type(
(oid1, ), "UUID",
lambda data, cursor: data and uuid.UUID(data) or None) lambda data, cursor: data and uuid.UUID(data) or None)
_ext.UUIDARRAY = _ext.new_array_type((oid2,), "UUID[]", _ext.UUID) _ext.UUIDARRAY = _ext.new_array_type((oid2,), "UUID[]", _ext.UUID)
@ -523,6 +534,7 @@ class Inet(object):
def __str__(self): def __str__(self):
return str(self.addr) return str(self.addr)
def register_inet(oid=None, conn_or_curs=None): def register_inet(oid=None, conn_or_curs=None):
"""Create the INET type and an Inet adapter. """Create the INET type and an Inet adapter.
@ -736,12 +748,12 @@ WHERE typname = 'hstore';
rv1.append(oids[1]) rv1.append(oids[1])
# revert the status of the connection as before the command # revert the status of the connection as before the command
if (conn_status != _ext.STATUS_IN_TRANSACTION if (conn_status != _ext.STATUS_IN_TRANSACTION and not conn.autocommit):
and not conn.autocommit):
conn.rollback() conn.rollback()
return tuple(rv0), tuple(rv1) return tuple(rv0), tuple(rv1)
def register_hstore(conn_or_curs, globally=False, unicode=False, def register_hstore(conn_or_curs, globally=False, unicode=False,
oid=None, array_oid=None): oid=None, array_oid=None):
"""Register adapter and typecaster for `!dict`\-\ |hstore| conversions. """Register adapter and typecaster for `!dict`\-\ |hstore| conversions.
@ -822,8 +834,8 @@ class CompositeCaster(object):
self.oid = oid self.oid = oid
self.array_oid = array_oid self.array_oid = array_oid
self.attnames = [ a[0] for a in attrs ] self.attnames = [a[0] for a in attrs]
self.atttypes = [ a[1] for a in attrs ] self.atttypes = [a[1] for a in attrs]
self._create_type(name, self.attnames) self._create_type(name, self.attnames)
self.typecaster = _ext.new_type((oid,), name, self.parse) self.typecaster = _ext.new_type((oid,), name, self.parse)
if array_oid: if array_oid:
@ -842,8 +854,8 @@ class CompositeCaster(object):
"expecting %d components for the type %s, %d found instead" % "expecting %d components for the type %s, %d found instead" %
(len(self.atttypes), self.name, len(tokens))) (len(self.atttypes), self.name, len(tokens)))
values = [ curs.cast(oid, token) values = [curs.cast(oid, token)
for oid, token in zip(self.atttypes, tokens) ] for oid, token in zip(self.atttypes, tokens)]
return self.make(values) return self.make(values)
@ -927,8 +939,7 @@ ORDER BY attnum;
recs = curs.fetchall() recs = curs.fetchall()
# revert the status of the connection as before the command # revert the status of the connection as before the command
if (conn_status != _ext.STATUS_IN_TRANSACTION if (conn_status != _ext.STATUS_IN_TRANSACTION and not conn.autocommit):
and not conn.autocommit):
conn.rollback() conn.rollback()
if not recs: if not recs:
@ -937,11 +948,12 @@ ORDER BY attnum;
type_oid = recs[0][0] type_oid = recs[0][0]
array_oid = recs[0][1] array_oid = recs[0][1]
type_attrs = [ (r[2], r[3]) for r in recs ] type_attrs = [(r[2], r[3]) for r in recs]
return self(tname, type_oid, type_attrs, return self(tname, type_oid, type_attrs,
array_oid=array_oid, schema=schema) array_oid=array_oid, schema=schema)
def register_composite(name, conn_or_curs, globally=False, factory=None): def register_composite(name, conn_or_curs, globally=False, factory=None):
"""Register a typecaster to convert a composite type into a tuple. """Register a typecaster to convert a composite type into a tuple.

View File

@ -74,8 +74,10 @@ class AbstractConnectionPool(object):
def _getconn(self, key=None): def _getconn(self, key=None):
"""Get a free connection and assign it to 'key' if not None.""" """Get a free connection and assign it to 'key' if not None."""
if self.closed: raise PoolError("connection pool is closed") if self.closed:
if key is None: key = self._getkey() raise PoolError("connection pool is closed")
if key is None:
key = self._getkey()
if key in self._used: if key in self._used:
return self._used[key] return self._used[key]
@ -91,8 +93,10 @@ class AbstractConnectionPool(object):
def _putconn(self, conn, key=None, close=False): def _putconn(self, conn, key=None, close=False):
"""Put away a connection.""" """Put away a connection."""
if self.closed: raise PoolError("connection pool is closed") if self.closed:
if key is None: key = self._rused.get(id(conn)) raise PoolError("connection pool is closed")
if key is None:
key = self._rused.get(id(conn))
if not key: if not key:
raise PoolError("trying to put unkeyed connection") raise PoolError("trying to put unkeyed connection")
@ -129,7 +133,8 @@ class AbstractConnectionPool(object):
an already closed connection. If you call .closeall() make sure an already closed connection. If you call .closeall() make sure
your code can deal with it. your code can deal with it.
""" """
if self.closed: raise PoolError("connection pool is closed") if self.closed:
raise PoolError("connection pool is closed")
for conn in self._pool + list(self._used.values()): for conn in self._pool + list(self._used.values()):
try: try:
conn.close() conn.close()
@ -221,7 +226,8 @@ class PersistentConnectionPool(AbstractConnectionPool):
key = self.__thread.get_ident() key = self.__thread.get_ident()
self._lock.acquire() self._lock.acquire()
try: try:
if not conn: conn = self._used[key] if not conn:
conn = self._used[key]
self._putconn(conn, key, close) self._putconn(conn, key, close)
finally: finally:
self._lock.release() self._lock.release()

View File

@ -36,6 +36,7 @@ from psycopg2 import *
import psycopg2.extensions as _ext import psycopg2.extensions as _ext
_2connect = connect _2connect = connect
def connect(*args, **kwargs): def connect(*args, **kwargs):
"""connect(dsn, ...) -> new psycopg 1.1.x compatible connection object""" """connect(dsn, ...) -> new psycopg 1.1.x compatible connection object"""
kwargs['connection_factory'] = connection kwargs['connection_factory'] = connection
@ -43,6 +44,7 @@ def connect(*args, **kwargs):
conn.set_isolation_level(_ext.ISOLATION_LEVEL_READ_COMMITTED) conn.set_isolation_level(_ext.ISOLATION_LEVEL_READ_COMMITTED)
return conn return conn
class connection(_2connection): class connection(_2connection):
"""psycopg 1.1.x connection.""" """psycopg 1.1.x connection."""
@ -92,4 +94,3 @@ class cursor(_2cursor):
for row in rows: for row in rows:
res.append(self.__build_dict(row)) res.append(self.__build_dict(row))
return res return res

View File

@ -31,6 +31,7 @@ import time
ZERO = datetime.timedelta(0) ZERO = datetime.timedelta(0)
class FixedOffsetTimezone(datetime.tzinfo): class FixedOffsetTimezone(datetime.tzinfo):
"""Fixed offset in minutes east from UTC. """Fixed offset in minutes east from UTC.
@ -52,7 +53,7 @@ class FixedOffsetTimezone(datetime.tzinfo):
def __init__(self, offset=None, name=None): def __init__(self, offset=None, name=None):
if offset is not None: if offset is not None:
self._offset = datetime.timedelta(minutes = offset) self._offset = datetime.timedelta(minutes=offset)
if name is not None: if name is not None:
self._name = name self._name = name
@ -85,7 +86,7 @@ class FixedOffsetTimezone(datetime.tzinfo):
else: else:
seconds = self._offset.seconds + self._offset.days * 86400 seconds = self._offset.seconds + self._offset.days * 86400
hours, seconds = divmod(seconds, 3600) hours, seconds = divmod(seconds, 3600)
minutes = seconds/60 minutes = seconds / 60
if minutes: if minutes:
return "%+03d:%d" % (hours, minutes) return "%+03d:%d" % (hours, minutes)
else: else:
@ -95,13 +96,14 @@ class FixedOffsetTimezone(datetime.tzinfo):
return ZERO return ZERO
STDOFFSET = datetime.timedelta(seconds = -time.timezone) STDOFFSET = datetime.timedelta(seconds=-time.timezone)
if time.daylight: if time.daylight:
DSTOFFSET = datetime.timedelta(seconds = -time.altzone) DSTOFFSET = datetime.timedelta(seconds=-time.altzone)
else: else:
DSTOFFSET = STDOFFSET DSTOFFSET = STDOFFSET
DSTDIFF = DSTOFFSET - STDOFFSET DSTDIFF = DSTOFFSET - STDOFFSET
class LocalTimezone(datetime.tzinfo): class LocalTimezone(datetime.tzinfo):
"""Platform idea of local timezone. """Platform idea of local timezone.