Python source cleanup using flake8

This commit is contained in:
Daniele Varrazzo 2016-10-11 00:10:53 +01:00
parent 4458c9b4c9
commit 91d2158de7
35 changed files with 644 additions and 432 deletions

View File

@ -47,19 +47,20 @@ Homepage: http://initd.org/projects/psycopg2
# Import the DBAPI-2.0 stuff into top-level module.
from psycopg2._psycopg import BINARY, NUMBER, STRING, DATETIME, ROWID
from psycopg2._psycopg import ( # noqa
BINARY, NUMBER, STRING, DATETIME, ROWID,
from psycopg2._psycopg import Binary, Date, Time, Timestamp
from psycopg2._psycopg import DateFromTicks, TimeFromTicks, TimestampFromTicks
Binary, Date, Time, Timestamp,
DateFromTicks, TimeFromTicks, TimestampFromTicks,
from psycopg2._psycopg import Error, Warning, DataError, DatabaseError, ProgrammingError
from psycopg2._psycopg import IntegrityError, InterfaceError, InternalError
from psycopg2._psycopg import NotSupportedError, OperationalError
Error, Warning, DataError, DatabaseError, ProgrammingError, IntegrityError,
InterfaceError, InternalError, NotSupportedError, OperationalError,
from psycopg2._psycopg import _connect, apilevel, threadsafety, paramstyle
from psycopg2._psycopg import __version__, __libpq_version__
_connect, apilevel, threadsafety, paramstyle,
__version__, __libpq_version__,
)
from psycopg2 import tz
from psycopg2 import tz # noqa
# Register default adapters.
@ -82,7 +83,7 @@ else:
def connect(dsn=None, connection_factory=None, cursor_factory=None,
async=False, **kwargs):
async=False, **kwargs):
"""
Create a new database connection.

View File

@ -34,7 +34,7 @@ from psycopg2._psycopg import new_type, new_array_type, register_type
# import the best json implementation available
if sys.version_info[:2] >= (2,6):
if sys.version_info[:2] >= (2, 6):
import json
else:
try:
@ -51,6 +51,7 @@ JSONARRAY_OID = 199
JSONB_OID = 3802
JSONBARRAY_OID = 3807
class Json(object):
"""
An `~psycopg2.extensions.ISQLQuote` wrapper to adapt a Python object to
@ -106,7 +107,7 @@ class Json(object):
def register_json(conn_or_curs=None, globally=False, loads=None,
oid=None, array_oid=None, name='json'):
oid=None, array_oid=None, name='json'):
"""Create and register typecasters converting :sql:`json` type to Python objects.
:param conn_or_curs: a connection or cursor used to find the :sql:`json`
@ -143,6 +144,7 @@ def register_json(conn_or_curs=None, globally=False, loads=None,
return JSON, JSONARRAY
def register_default_json(conn_or_curs=None, globally=False, loads=None):
"""
Create and register :sql:`json` typecasters for PostgreSQL 9.2 and following.
@ -155,6 +157,7 @@ def register_default_json(conn_or_curs=None, globally=False, loads=None):
return register_json(conn_or_curs=conn_or_curs, globally=globally,
loads=loads, oid=JSON_OID, array_oid=JSONARRAY_OID)
def register_default_jsonb(conn_or_curs=None, globally=False, loads=None):
"""
Create and register :sql:`jsonb` typecasters for PostgreSQL 9.4 and following.
@ -167,6 +170,7 @@ def register_default_jsonb(conn_or_curs=None, globally=False, loads=None):
return register_json(conn_or_curs=conn_or_curs, globally=globally,
loads=loads, oid=JSONB_OID, array_oid=JSONBARRAY_OID, name='jsonb')
def _create_json_typecasters(oid, array_oid, loads=None, name='JSON'):
"""Create typecasters for json data type."""
if loads is None:
@ -188,6 +192,7 @@ def _create_json_typecasters(oid, array_oid, loads=None, name='JSON'):
return JSON, JSONARRAY
def _get_json_oids(conn_or_curs, name='json'):
# lazy imports
from psycopg2.extensions import STATUS_IN_TRANSACTION
@ -204,7 +209,7 @@ def _get_json_oids(conn_or_curs, name='json'):
# get the oid for the hstore
curs.execute(
"SELECT t.oid, %s FROM pg_type t WHERE t.typname = %%s;"
% typarray, (name,))
% typarray, (name,))
r = curs.fetchone()
# revert the status of the connection as before the command
@ -215,6 +220,3 @@ def _get_json_oids(conn_or_curs, name='json'):
raise conn.ProgrammingError("%s data type not found" % name)
return r

View File

@ -30,6 +30,7 @@ from psycopg2._psycopg import ProgrammingError, InterfaceError
from psycopg2.extensions import ISQLQuote, adapt, register_adapter
from psycopg2.extensions import new_type, new_array_type, register_type
class Range(object):
"""Python representation for a PostgreSQL |range|_ type.
@ -78,42 +79,50 @@ class Range(object):
@property
def lower_inf(self):
"""`!True` if the range doesn't have a lower bound."""
if self._bounds is None: return False
if self._bounds is None:
return False
return self._lower is None
@property
def upper_inf(self):
"""`!True` if the range doesn't have an upper bound."""
if self._bounds is None: return False
if self._bounds is None:
return False
return self._upper is None
@property
def lower_inc(self):
"""`!True` if the lower bound is included in the range."""
if self._bounds is None: return False
if self._lower is None: return False
if self._bounds is None or self._lower is None:
return False
return self._bounds[0] == '['
@property
def upper_inc(self):
"""`!True` if the upper bound is included in the range."""
if self._bounds is None: return False
if self._upper is None: return False
if self._bounds is None or self._upper is None:
return False
return self._bounds[1] == ']'
def __contains__(self, x):
if self._bounds is None: return False
if self._bounds is None:
return False
if self._lower is not None:
if self._bounds[0] == '[':
if x < self._lower: return False
if x < self._lower:
return False
else:
if x <= self._lower: return False
if x <= self._lower:
return False
if self._upper is not None:
if self._bounds[1] == ']':
if x > self._upper: return False
if x > self._upper:
return False
else:
if x >= self._upper: return False
if x >= self._upper:
return False
return True
@ -295,7 +304,8 @@ class RangeCaster(object):
self.adapter.name = pgrange
else:
try:
if issubclass(pgrange, RangeAdapter) and pgrange is not RangeAdapter:
if issubclass(pgrange, RangeAdapter) \
and pgrange is not RangeAdapter:
self.adapter = pgrange
except TypeError:
pass
@ -436,14 +446,17 @@ class NumericRange(Range):
"""
pass
class DateRange(Range):
"""Represents :sql:`daterange` values."""
pass
class DateTimeRange(Range):
"""Represents :sql:`tsrange` values."""
pass
class DateTimeTZRange(Range):
"""Represents :sql:`tstzrange` values."""
pass
@ -508,5 +521,3 @@ tsrange_caster._register()
tstzrange_caster = RangeCaster('tstzrange', DateTimeTZRange,
oid=3910, subtype_oid=1184, array_oid=3911)
tstzrange_caster._register()

View File

@ -29,6 +29,7 @@ This module contains symbolic names for all PostgreSQL error codes.
# http://www.postgresql.org/docs/current/static/errcodes-appendix.html
#
def lookup(code, _cache={}):
"""Lookup an error code or class code and return its symbolic name.

View File

@ -33,71 +33,69 @@ This module holds all the extensions to the DBAPI-2.0 provided by psycopg.
# License for more details.
import re as _re
import sys as _sys
from psycopg2._psycopg import UNICODE, INTEGER, LONGINTEGER, BOOLEAN, FLOAT
from psycopg2._psycopg import TIME, DATE, INTERVAL, DECIMAL
from psycopg2._psycopg import BINARYARRAY, BOOLEANARRAY, DATEARRAY, DATETIMEARRAY
from psycopg2._psycopg import DECIMALARRAY, FLOATARRAY, INTEGERARRAY, INTERVALARRAY
from psycopg2._psycopg import LONGINTEGERARRAY, ROWIDARRAY, STRINGARRAY, TIMEARRAY
from psycopg2._psycopg import UNICODEARRAY
from psycopg2._psycopg import ( # noqa
BINARYARRAY, BOOLEAN, BOOLEANARRAY, DATE, DATEARRAY, DATETIMEARRAY,
DECIMAL, DECIMALARRAY, FLOAT, FLOATARRAY, INTEGER, INTEGERARRAY,
INTERVAL, INTERVALARRAY, LONGINTEGER, LONGINTEGERARRAY, ROWIDARRAY,
STRINGARRAY, TIME, TIMEARRAY, UNICODE, UNICODEARRAY,
AsIs, Binary, Boolean, Float, Int, QuotedString, )
from psycopg2._psycopg import Binary, Boolean, Int, Float, QuotedString, AsIs
try:
from psycopg2._psycopg import MXDATE, MXDATETIME, MXINTERVAL, MXTIME
from psycopg2._psycopg import MXDATEARRAY, MXDATETIMEARRAY, MXINTERVALARRAY, MXTIMEARRAY
from psycopg2._psycopg import DateFromMx, TimeFromMx, TimestampFromMx
from psycopg2._psycopg import IntervalFromMx
from psycopg2._psycopg import ( # noqa
MXDATE, MXDATETIME, MXINTERVAL, MXTIME,
MXDATEARRAY, MXDATETIMEARRAY, MXINTERVALARRAY, MXTIMEARRAY,
DateFromMx, TimeFromMx, TimestampFromMx, IntervalFromMx, )
except ImportError:
pass
try:
from psycopg2._psycopg import PYDATE, PYDATETIME, PYINTERVAL, PYTIME
from psycopg2._psycopg import PYDATEARRAY, PYDATETIMEARRAY, PYINTERVALARRAY, PYTIMEARRAY
from psycopg2._psycopg import DateFromPy, TimeFromPy, TimestampFromPy
from psycopg2._psycopg import IntervalFromPy
from psycopg2._psycopg import ( # noqa
PYDATE, PYDATETIME, PYINTERVAL, PYTIME,
PYDATEARRAY, PYDATETIMEARRAY, PYINTERVALARRAY, PYTIMEARRAY,
DateFromPy, TimeFromPy, TimestampFromPy, IntervalFromPy, )
except ImportError:
pass
from psycopg2._psycopg import adapt, adapters, encodings, connection, cursor
from psycopg2._psycopg import lobject, Xid, libpq_version, parse_dsn, quote_ident
from psycopg2._psycopg import string_types, binary_types, new_type, new_array_type, register_type
from psycopg2._psycopg import ISQLQuote, Notify, Diagnostics, Column
from psycopg2._psycopg import ( # noqa
adapt, adapters, encodings, connection, cursor,
lobject, Xid, libpq_version, parse_dsn, quote_ident,
string_types, binary_types, new_type, new_array_type, register_type,
ISQLQuote, Notify, Diagnostics, Column,
QueryCanceledError, TransactionRollbackError,
set_wait_callback, get_wait_callback, )
from psycopg2._psycopg import QueryCanceledError, TransactionRollbackError
try:
from psycopg2._psycopg import set_wait_callback, get_wait_callback
except ImportError:
pass
"""Isolation level values."""
ISOLATION_LEVEL_AUTOCOMMIT = 0
ISOLATION_LEVEL_READ_UNCOMMITTED = 4
ISOLATION_LEVEL_READ_COMMITTED = 1
ISOLATION_LEVEL_REPEATABLE_READ = 2
ISOLATION_LEVEL_SERIALIZABLE = 3
ISOLATION_LEVEL_AUTOCOMMIT = 0
ISOLATION_LEVEL_READ_UNCOMMITTED = 4
ISOLATION_LEVEL_READ_COMMITTED = 1
ISOLATION_LEVEL_REPEATABLE_READ = 2
ISOLATION_LEVEL_SERIALIZABLE = 3
"""psycopg connection status values."""
STATUS_SETUP = 0
STATUS_READY = 1
STATUS_BEGIN = 2
STATUS_SYNC = 3 # currently unused
STATUS_ASYNC = 4 # currently unused
STATUS_SETUP = 0
STATUS_READY = 1
STATUS_BEGIN = 2
STATUS_SYNC = 3 # currently unused
STATUS_ASYNC = 4 # currently unused
STATUS_PREPARED = 5
# This is a useful mnemonic to check if the connection is in a transaction
STATUS_IN_TRANSACTION = STATUS_BEGIN
"""psycopg asynchronous connection polling values"""
POLL_OK = 0
POLL_READ = 1
POLL_OK = 0
POLL_READ = 1
POLL_WRITE = 2
POLL_ERROR = 3
"""Backend transaction status values."""
TRANSACTION_STATUS_IDLE = 0
TRANSACTION_STATUS_ACTIVE = 1
TRANSACTION_STATUS_IDLE = 0
TRANSACTION_STATUS_ACTIVE = 1
TRANSACTION_STATUS_INTRANS = 2
TRANSACTION_STATUS_INERROR = 3
TRANSACTION_STATUS_UNKNOWN = 4
@ -194,7 +192,7 @@ def _param_escape(s,
# Create default json typecasters for PostgreSQL 9.2 oids
from psycopg2._json import register_default_json, register_default_jsonb
from psycopg2._json import register_default_json, register_default_jsonb # noqa
try:
JSON, JSONARRAY = register_default_json()
@ -206,7 +204,7 @@ del register_default_json, register_default_jsonb
# Create default Range typecasters
from psycopg2. _range import Range
from psycopg2. _range import Range # noqa
del Range

View File

@ -40,10 +40,23 @@ from psycopg2 import extensions as _ext
from psycopg2.extensions import cursor as _cursor
from psycopg2.extensions import connection as _connection
from psycopg2.extensions import adapt as _A, quote_ident
from psycopg2._psycopg import REPLICATION_PHYSICAL, REPLICATION_LOGICAL
from psycopg2._psycopg import ReplicationConnection as _replicationConnection
from psycopg2._psycopg import ReplicationCursor as _replicationCursor
from psycopg2._psycopg import ReplicationMessage
from psycopg2._psycopg import ( # noqa
REPLICATION_PHYSICAL, REPLICATION_LOGICAL,
ReplicationConnection as _replicationConnection,
ReplicationCursor as _replicationCursor,
ReplicationMessage)
# expose the json adaptation stuff into the module
from psycopg2._json import ( # noqa
json, Json, register_json, register_default_json, register_default_jsonb)
# Expose range-related objects
from psycopg2._range import ( # noqa
Range, NumericRange, DateRange, DateTimeRange, DateTimeTZRange,
register_range, RangeAdapter, RangeCaster)
class DictCursorBase(_cursor):
@ -109,6 +122,7 @@ class DictConnection(_connection):
kwargs.setdefault('cursor_factory', DictCursor)
return super(DictConnection, self).cursor(*args, **kwargs)
class DictCursor(DictCursorBase):
"""A cursor that keeps a list of column name -> index mappings."""
@ -133,6 +147,7 @@ class DictCursor(DictCursorBase):
self.index[self.description[i][0]] = i
self._query_executed = 0
class DictRow(list):
"""A row object that allow by-column-name access to data."""
@ -195,10 +210,10 @@ class DictRow(list):
# drop the crusty Py2 methods
if _sys.version_info[0] > 2:
items = iteritems; del iteritems
keys = iterkeys; del iterkeys
values = itervalues; del itervalues
del has_key
items = iteritems # noqa
keys = iterkeys # noqa
values = itervalues # noqa
del iteritems, iterkeys, itervalues, has_key
class RealDictConnection(_connection):
@ -207,6 +222,7 @@ class RealDictConnection(_connection):
kwargs.setdefault('cursor_factory', RealDictCursor)
return super(RealDictConnection, self).cursor(*args, **kwargs)
class RealDictCursor(DictCursorBase):
"""A cursor that uses a real dict as the base type for rows.
@ -236,6 +252,7 @@ class RealDictCursor(DictCursorBase):
self.column_mapping.append(self.description[i][0])
self._query_executed = 0
class RealDictRow(dict):
"""A `!dict` subclass representing a data record."""
@ -268,6 +285,7 @@ class NamedTupleConnection(_connection):
kwargs.setdefault('cursor_factory', NamedTupleCursor)
return super(NamedTupleConnection, self).cursor(*args, **kwargs)
class NamedTupleCursor(_cursor):
"""A cursor that generates results as `~collections.namedtuple`.
@ -372,11 +390,13 @@ class LoggingConnection(_connection):
def _logtofile(self, msg, curs):
msg = self.filter(msg, curs)
if msg: self._logobj.write(msg + _os.linesep)
if msg:
self._logobj.write(msg + _os.linesep)
def _logtologger(self, msg, curs):
msg = self.filter(msg, curs)
if msg: self._logobj.debug(msg)
if msg:
self._logobj.debug(msg)
def _check(self):
if not hasattr(self, '_logobj'):
@ -388,6 +408,7 @@ class LoggingConnection(_connection):
kwargs.setdefault('cursor_factory', LoggingCursor)
return super(LoggingConnection, self).cursor(*args, **kwargs)
class LoggingCursor(_cursor):
"""A cursor that logs queries using its connection logging facilities."""
@ -428,6 +449,7 @@ class MinTimeLoggingConnection(LoggingConnection):
kwargs.setdefault('cursor_factory', MinTimeLoggingCursor)
return LoggingConnection.cursor(self, *args, **kwargs)
class MinTimeLoggingCursor(LoggingCursor):
"""The cursor sub-class companion to `MinTimeLoggingConnection`."""
@ -479,18 +501,23 @@ class ReplicationCursor(_replicationCursor):
if slot_type == REPLICATION_LOGICAL:
if output_plugin is None:
raise psycopg2.ProgrammingError("output plugin name is required to create logical replication slot")
raise psycopg2.ProgrammingError(
"output plugin name is required to create "
"logical replication slot")
command += "LOGICAL %s" % quote_ident(output_plugin, self)
elif slot_type == REPLICATION_PHYSICAL:
if output_plugin is not None:
raise psycopg2.ProgrammingError("cannot specify output plugin name when creating physical replication slot")
raise psycopg2.ProgrammingError(
"cannot specify output plugin name when creating "
"physical replication slot")
command += "PHYSICAL"
else:
raise psycopg2.ProgrammingError("unrecognized replication type: %s" % repr(slot_type))
raise psycopg2.ProgrammingError(
"unrecognized replication type: %s" % repr(slot_type))
self.execute(command)
@ -513,7 +540,8 @@ class ReplicationCursor(_replicationCursor):
if slot_name:
command += "SLOT %s " % quote_ident(slot_name, self)
else:
raise psycopg2.ProgrammingError("slot name is required for logical replication")
raise psycopg2.ProgrammingError(
"slot name is required for logical replication")
command += "LOGICAL "
@ -523,28 +551,32 @@ class ReplicationCursor(_replicationCursor):
# don't add "PHYSICAL", before 9.4 it was just START_REPLICATION XXX/XXX
else:
raise psycopg2.ProgrammingError("unrecognized replication type: %s" % repr(slot_type))
raise psycopg2.ProgrammingError(
"unrecognized replication type: %s" % repr(slot_type))
if type(start_lsn) is str:
lsn = start_lsn.split('/')
lsn = "%X/%08X" % (int(lsn[0], 16), int(lsn[1], 16))
else:
lsn = "%X/%08X" % ((start_lsn >> 32) & 0xFFFFFFFF, start_lsn & 0xFFFFFFFF)
lsn = "%X/%08X" % ((start_lsn >> 32) & 0xFFFFFFFF,
start_lsn & 0xFFFFFFFF)
command += lsn
if timeline != 0:
if slot_type == REPLICATION_LOGICAL:
raise psycopg2.ProgrammingError("cannot specify timeline for logical replication")
raise psycopg2.ProgrammingError(
"cannot specify timeline for logical replication")
command += " TIMELINE %d" % timeline
if options:
if slot_type == REPLICATION_PHYSICAL:
raise psycopg2.ProgrammingError("cannot specify output plugin options for physical replication")
raise psycopg2.ProgrammingError(
"cannot specify output plugin options for physical replication")
command += " ("
for k,v in options.iteritems():
for k, v in options.iteritems():
if not command.endswith('('):
command += ", "
command += "%s %s" % (quote_ident(k, self), _A(str(v)))
@ -579,6 +611,7 @@ class UUID_adapter(object):
def __str__(self):
return "'%s'::uuid" % self._uuid
def register_uuid(oids=None, conn_or_curs=None):
"""Create the UUID type and an uuid.UUID adapter.
@ -643,6 +676,7 @@ class Inet(object):
def __str__(self):
return str(self.addr)
def register_inet(oid=None, conn_or_curs=None):
"""Create the INET type and an Inet adapter.
@ -862,8 +896,9 @@ WHERE typname = 'hstore';
return tuple(rv0), tuple(rv1)
def register_hstore(conn_or_curs, globally=False, unicode=False,
oid=None, array_oid=None):
oid=None, array_oid=None):
"""Register adapter and typecaster for `!dict`\-\ |hstore| conversions.
:param conn_or_curs: a connection or cursor: the typecaster will be
@ -942,8 +977,8 @@ class CompositeCaster(object):
self.oid = oid
self.array_oid = array_oid
self.attnames = [ a[0] for a in attrs ]
self.atttypes = [ a[1] for a in attrs ]
self.attnames = [a[0] for a in attrs]
self.atttypes = [a[1] for a in attrs]
self._create_type(name, self.attnames)
self.typecaster = _ext.new_type((oid,), name, self.parse)
if array_oid:
@ -962,8 +997,8 @@ class CompositeCaster(object):
"expecting %d components for the type %s, %d found instead" %
(len(self.atttypes), self.name, len(tokens)))
values = [ curs.cast(oid, token)
for oid, token in zip(self.atttypes, tokens) ]
values = [curs.cast(oid, token)
for oid, token in zip(self.atttypes, tokens)]
return self.make(values)
@ -1057,11 +1092,12 @@ ORDER BY attnum;
type_oid = recs[0][0]
array_oid = recs[0][1]
type_attrs = [ (r[2], r[3]) for r in recs ]
type_attrs = [(r[2], r[3]) for r in recs]
return self(tname, type_oid, type_attrs,
array_oid=array_oid, schema=schema)
def register_composite(name, conn_or_curs, globally=False, factory=None):
"""Register a typecaster to convert a composite type into a tuple.
@ -1084,17 +1120,7 @@ def register_composite(name, conn_or_curs, globally=False, factory=None):
_ext.register_type(caster.typecaster, not globally and conn_or_curs or None)
if caster.array_typecaster is not None:
_ext.register_type(caster.array_typecaster, not globally and conn_or_curs or None)
_ext.register_type(
caster.array_typecaster, not globally and conn_or_curs or None)
return caster
# expose the json adaptation stuff into the module
from psycopg2._json import json, Json, register_json
from psycopg2._json import register_default_json, register_default_jsonb
# Expose range-related objects
from psycopg2._range import Range, NumericRange
from psycopg2._range import DateRange, DateTimeRange, DateTimeTZRange
from psycopg2._range import register_range, RangeAdapter, RangeCaster

View File

@ -40,18 +40,18 @@ class AbstractConnectionPool(object):
New 'minconn' connections are created immediately calling 'connfunc'
with given parameters. The connection pool will support a maximum of
about 'maxconn' connections.
about 'maxconn' connections.
"""
self.minconn = int(minconn)
self.maxconn = int(maxconn)
self.closed = False
self._args = args
self._kwargs = kwargs
self._pool = []
self._used = {}
self._rused = {} # id(conn) -> key map
self._rused = {} # id(conn) -> key map
self._keys = 0
for i in range(self.minconn):
@ -71,12 +71,14 @@ class AbstractConnectionPool(object):
"""Return a new unique key."""
self._keys += 1
return self._keys
def _getconn(self, key=None):
"""Get a free connection and assign it to 'key' if not None."""
if self.closed: raise PoolError("connection pool is closed")
if key is None: key = self._getkey()
if self.closed:
raise PoolError("connection pool is closed")
if key is None:
key = self._getkey()
if key in self._used:
return self._used[key]
@ -88,11 +90,13 @@ class AbstractConnectionPool(object):
if len(self._used) == self.maxconn:
raise PoolError("connection pool exhausted")
return self._connect(key)
def _putconn(self, conn, key=None, close=False):
"""Put away a connection."""
if self.closed: raise PoolError("connection pool is closed")
if key is None: key = self._rused.get(id(conn))
if self.closed:
raise PoolError("connection pool is closed")
if key is None:
key = self._rused.get(id(conn))
if not key:
raise PoolError("trying to put unkeyed connection")
@ -129,21 +133,22 @@ class AbstractConnectionPool(object):
an already closed connection. If you call .closeall() make sure
your code can deal with it.
"""
if self.closed: raise PoolError("connection pool is closed")
if self.closed:
raise PoolError("connection pool is closed")
for conn in self._pool + list(self._used.values()):
try:
conn.close()
except:
pass
self.closed = True
class SimpleConnectionPool(AbstractConnectionPool):
"""A connection pool that can't be shared across different threads."""
getconn = AbstractConnectionPool._getconn
putconn = AbstractConnectionPool._putconn
closeall = AbstractConnectionPool._closeall
closeall = AbstractConnectionPool._closeall
class ThreadedConnectionPool(AbstractConnectionPool):
@ -182,7 +187,7 @@ class ThreadedConnectionPool(AbstractConnectionPool):
class PersistentConnectionPool(AbstractConnectionPool):
"""A pool that assigns persistent connections to different threads.
"""A pool that assigns persistent connections to different threads.
Note that this connection pool generates by itself the required keys
using the current thread id. This means that until a thread puts away
@ -204,7 +209,7 @@ class PersistentConnectionPool(AbstractConnectionPool):
# we we'll need the thread module, to determine thread ids, so we
# import it here and copy it in an instance variable
import thread as _thread # work around for 2to3 bug - see ticket #348
import thread as _thread # work around for 2to3 bug - see ticket #348
self.__thread = _thread
def getconn(self):
@ -221,7 +226,8 @@ class PersistentConnectionPool(AbstractConnectionPool):
key = self.__thread.get_ident()
self._lock.acquire()
try:
if not conn: conn = self._used[key]
if not conn:
conn = self._used[key]
self._putconn(conn, key, close)
finally:
self._lock.release()

View File

@ -28,24 +28,26 @@ old code while porting to psycopg 2. Import it as follows::
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
import psycopg2._psycopg as _2psycopg
import psycopg2._psycopg as _2psycopg # noqa
from psycopg2.extensions import cursor as _2cursor
from psycopg2.extensions import connection as _2connection
from psycopg2 import *
from psycopg2 import * # noqa
import psycopg2.extensions as _ext
_2connect = connect
def connect(*args, **kwargs):
"""connect(dsn, ...) -> new psycopg 1.1.x compatible connection object"""
kwargs['connection_factory'] = connection
conn = _2connect(*args, **kwargs)
conn.set_isolation_level(_ext.ISOLATION_LEVEL_READ_COMMITTED)
return conn
class connection(_2connection):
"""psycopg 1.1.x connection."""
def cursor(self):
"""cursor() -> new psycopg 1.1.x compatible cursor object"""
return _2connection.cursor(self, cursor_factory=cursor)
@ -56,7 +58,7 @@ class connection(_2connection):
self.set_isolation_level(_ext.ISOLATION_LEVEL_AUTOCOMMIT)
else:
self.set_isolation_level(_ext.ISOLATION_LEVEL_READ_COMMITTED)
class cursor(_2cursor):
"""psycopg 1.1.x cursor.
@ -71,25 +73,24 @@ class cursor(_2cursor):
for i in range(len(self.description)):
res[self.description[i][0]] = row[i]
return res
def dictfetchone(self):
row = _2cursor.fetchone(self)
if row:
return self.__build_dict(row)
else:
return row
def dictfetchmany(self, size):
res = []
rows = _2cursor.fetchmany(self, size)
for row in rows:
res.append(self.__build_dict(row))
return res
def dictfetchall(self):
res = []
rows = _2cursor.fetchall(self)
for row in rows:
res.append(self.__build_dict(row))
return res

View File

@ -2,7 +2,7 @@
This module holds two different tzinfo implementations that can be used as
the 'tzinfo' argument to datetime constructors, directly passed to psycopg
functions or used to set the .tzinfo_factory attribute in cursors.
functions or used to set the .tzinfo_factory attribute in cursors.
"""
# psycopg/tz.py - tzinfo implementation
#
@ -31,6 +31,7 @@ import time
ZERO = datetime.timedelta(0)
class FixedOffsetTimezone(datetime.tzinfo):
"""Fixed offset in minutes east from UTC.
@ -52,7 +53,7 @@ class FixedOffsetTimezone(datetime.tzinfo):
def __init__(self, offset=None, name=None):
if offset is not None:
self._offset = datetime.timedelta(minutes = offset)
self._offset = datetime.timedelta(minutes=offset)
if name is not None:
self._name = name
@ -85,7 +86,7 @@ class FixedOffsetTimezone(datetime.tzinfo):
else:
seconds = self._offset.seconds + self._offset.days * 86400
hours, seconds = divmod(seconds, 3600)
minutes = seconds/60
minutes = seconds / 60
if minutes:
return "%+03d:%d" % (hours, minutes)
else:
@ -95,13 +96,14 @@ class FixedOffsetTimezone(datetime.tzinfo):
return ZERO
STDOFFSET = datetime.timedelta(seconds = -time.timezone)
STDOFFSET = datetime.timedelta(seconds=-time.timezone)
if time.daylight:
DSTOFFSET = datetime.timedelta(seconds = -time.altzone)
DSTOFFSET = datetime.timedelta(seconds=-time.altzone)
else:
DSTOFFSET = STDOFFSET
DSTDIFF = DSTOFFSET - STDOFFSET
class LocalTimezone(datetime.tzinfo):
"""Platform idea of local timezone.

View File

@ -19,8 +19,8 @@
# code defines the DBAPITypeObject fundamental types and warns for
# undefined types.
import sys, os, string, copy
from string import split, join, strip
import sys
from string import split, strip
# here is the list of the foundamental types we want to import from
@ -37,7 +37,7 @@ basic_types = (['NUMBER', ['INT8', 'INT4', 'INT2', 'FLOAT8', 'FLOAT4',
['STRING', ['NAME', 'CHAR', 'TEXT', 'BPCHAR',
'VARCHAR']],
['BOOLEAN', ['BOOL']],
['DATETIME', ['TIMESTAMP', 'TIMESTAMPTZ',
['DATETIME', ['TIMESTAMP', 'TIMESTAMPTZ',
'TINTERVAL', 'INTERVAL']],
['TIME', ['TIME', 'TIMETZ']],
['DATE', ['DATE']],
@ -73,8 +73,7 @@ FOOTER = """ {NULL, NULL, NULL, NULL}\n};\n"""
# useful error reporting function
def error(msg):
"""Report an error on stderr."""
sys.stderr.write(msg+'\n')
sys.stderr.write(msg + '\n')
# read couples from stdin and build list
read_types = []
@ -91,14 +90,14 @@ for t in basic_types:
for v in t[1]:
found = filter(lambda x, y=v: x[0] == y, read_types)
if len(found) == 0:
error(v+': value not found')
error(v + ': value not found')
elif len(found) > 1:
error(v+': too many values')
error(v + ': too many values')
else:
found_types[k].append(int(found[0][1]))
# now outputs to stdout the right C-style definitions
stypes = "" ; sstruct = ""
stypes = sstruct = ""
for t in basic_types:
k = t[0]
s = str(found_types[k])
@ -108,7 +107,7 @@ for t in basic_types:
% (k, k, k))
for t in array_types:
kt = t[0]
ka = t[0]+'ARRAY'
ka = t[0] + 'ARRAY'
s = str(t[1])
s = '{' + s[1:-1] + ', 0}'
stypes = stypes + ('static long int typecast_%s_types[] = %s;\n' % (ka, s))

View File

@ -23,6 +23,7 @@ from collections import defaultdict
from BeautifulSoup import BeautifulSoup as BS
def main():
if len(sys.argv) != 2:
print >>sys.stderr, "usage: %s /path/to/errorcodes.py" % sys.argv[0]
@ -41,6 +42,7 @@ def main():
for line in generate_module_data(classes, errors):
print >>f, line
def read_base_file(filename):
rv = []
for line in open(filename):
@ -50,6 +52,7 @@ def read_base_file(filename):
raise ValueError("can't find the separator. Is this the right file?")
def parse_errors_txt(url):
classes = {}
errors = defaultdict(dict)
@ -84,6 +87,7 @@ def parse_errors_txt(url):
return classes, errors
def parse_errors_sgml(url):
page = BS(urllib2.urlopen(url))
table = page('table')[1]('tbody')[0]
@ -92,7 +96,7 @@ def parse_errors_sgml(url):
errors = defaultdict(dict)
for tr in table('tr'):
if tr.td.get('colspan'): # it's a class
if tr.td.get('colspan'): # it's a class
label = ' '.join(' '.join(tr(text=True)).split()) \
.replace(u'\u2014', '-').encode('ascii')
assert label.startswith('Class')
@ -100,7 +104,7 @@ def parse_errors_sgml(url):
assert len(class_) == 2
classes[class_] = label
else: # it's an error
else: # it's an error
errcode = tr.tt.string.encode("ascii")
assert len(errcode) == 5
@ -124,11 +128,12 @@ def parse_errors_sgml(url):
return classes, errors
errors_sgml_url = \
"http://www.postgresql.org/docs/%s/static/errcodes-appendix.html"
"http://www.postgresql.org/docs/%s/static/errcodes-appendix.html"
errors_txt_url = \
"http://git.postgresql.org/gitweb/?p=postgresql.git;a=blob_plain;" \
"f=src/backend/utils/errcodes.txt;hb=REL%s_STABLE"
"http://git.postgresql.org/gitweb/?p=postgresql.git;a=blob_plain;" \
"f=src/backend/utils/errcodes.txt;hb=REL%s_STABLE"
def fetch_errors(versions):
classes = {}
@ -148,14 +153,15 @@ def fetch_errors(versions):
return classes, errors
def generate_module_data(classes, errors):
yield ""
yield "# Error classes"
for clscode, clslabel in sorted(classes.items()):
err = clslabel.split(" - ")[1].split("(")[0] \
.strip().replace(" ", "_").replace('/', "_").upper()
.strip().replace(" ", "_").replace('/', "_").upper()
yield "CLASS_%s = %r" % (err, clscode)
for clscode, clslabel in sorted(classes.items()):
yield ""
yield "# %s" % clslabel
@ -163,7 +169,6 @@ def generate_module_data(classes, errors):
for errcode, errlabel in sorted(errors[clscode].items()):
yield "%s = %r" % (errlabel, errcode)
if __name__ == '__main__':
sys.exit(main())

View File

@ -25,6 +25,7 @@ import unittest
from pprint import pprint
from collections import defaultdict
def main():
opt = parse_args()
@ -58,6 +59,7 @@ def main():
return rv
def parse_args():
import optparse
@ -83,7 +85,7 @@ def dump(i, opt):
c[type(o)] += 1
pprint(
sorted(((v,str(k)) for k,v in c.items()), reverse=True),
sorted(((v, str(k)) for k, v in c.items()), reverse=True),
stream=open("debug-%02d.txt" % i, "w"))
if opt.objs:
@ -95,7 +97,7 @@ def dump(i, opt):
# TODO: very incomplete
if t is dict:
co.sort(key = lambda d: d.items())
co.sort(key=lambda d: d.items())
else:
co.sort()
@ -104,4 +106,3 @@ def dump(i, opt):
if __name__ == '__main__':
sys.exit(main())

View File

@ -25,34 +25,9 @@ UPDATEs. psycopg2 also provide full asynchronous operations and support
for coroutine libraries.
"""
# note: if you are changing the list of supported Python version please fix
# the docs in install.rst and the /features/ page on the website.
classifiers = """\
Development Status :: 5 - Production/Stable
Intended Audience :: Developers
License :: OSI Approved :: GNU Library or Lesser General Public License (LGPL)
License :: OSI Approved :: Zope Public License
Programming Language :: Python
Programming Language :: Python :: 2.6
Programming Language :: Python :: 2.7
Programming Language :: Python :: 3
Programming Language :: Python :: 3.1
Programming Language :: Python :: 3.2
Programming Language :: Python :: 3.3
Programming Language :: Python :: 3.4
Programming Language :: Python :: 3.5
Programming Language :: C
Programming Language :: SQL
Topic :: Database
Topic :: Database :: Front-Ends
Topic :: Software Development
Topic :: Software Development :: Libraries :: Python Modules
Operating System :: Microsoft :: Windows
Operating System :: Unix
"""
# Note: The setup.py must be compatible with both Python 2 and 3
import os
import sys
import re
@ -87,7 +62,34 @@ except ImportError:
PSYCOPG_VERSION = '2.7.dev0'
version_flags = ['dt', 'dec']
# note: if you are changing the list of supported Python version please fix
# the docs in install.rst and the /features/ page on the website.
classifiers = """\
Development Status :: 5 - Production/Stable
Intended Audience :: Developers
License :: OSI Approved :: GNU Library or Lesser General Public License (LGPL)
License :: OSI Approved :: Zope Public License
Programming Language :: Python
Programming Language :: Python :: 2.6
Programming Language :: Python :: 2.7
Programming Language :: Python :: 3
Programming Language :: Python :: 3.1
Programming Language :: Python :: 3.2
Programming Language :: Python :: 3.3
Programming Language :: Python :: 3.4
Programming Language :: Python :: 3.5
Programming Language :: C
Programming Language :: SQL
Topic :: Database
Topic :: Database :: Front-Ends
Topic :: Software Development
Topic :: Software Development :: Libraries :: Python Modules
Operating System :: Microsoft :: Windows
Operating System :: Unix
"""
version_flags = ['dt', 'dec']
PLATFORM_IS_WINDOWS = sys.platform.lower().startswith('win')
@ -208,7 +210,7 @@ or with the pg_config option in 'setup.cfg'.
# Support unicode paths, if this version of Python provides the
# necessary infrastructure:
if sys.version_info[0] < 3 \
and hasattr(sys, 'getfilesystemencoding'):
and hasattr(sys, 'getfilesystemencoding'):
pg_config_path = pg_config_path.encode(
sys.getfilesystemencoding())
@ -230,7 +232,7 @@ class psycopg_build_ext(build_ext):
('use-pydatetime', None,
"Use Python datatime objects for date and time representation."),
('pg-config=', None,
"The name of the pg_config binary and/or full path to find it"),
"The name of the pg_config binary and/or full path to find it"),
('have-ssl', None,
"Compile with OpenSSL built PostgreSQL libraries (Windows only)."),
('static-libpq', None,
@ -388,7 +390,7 @@ class psycopg_build_ext(build_ext):
if not getattr(self, 'link_objects', None):
self.link_objects = []
self.link_objects.append(
os.path.join(pg_config_helper.query("libdir"), "libpq.a"))
os.path.join(pg_config_helper.query("libdir"), "libpq.a"))
else:
self.libraries.append("pq")
@ -417,7 +419,7 @@ class psycopg_build_ext(build_ext):
else:
sys.stderr.write(
"Error: could not determine PostgreSQL version from '%s'"
% pgversion)
% pgversion)
sys.exit(1)
define_macros.append(("PG_VERSION_NUM", "%d%02d%02d" %
@ -445,6 +447,7 @@ class psycopg_build_ext(build_ext):
if hasattr(self, "finalize_" + sys.platform):
getattr(self, "finalize_" + sys.platform)()
def is_py_64():
# sys.maxint not available since Py 3.1;
# sys.maxsize not available before Py 2.6;
@ -511,7 +514,7 @@ parser.read('setup.cfg')
# Choose a datetime module
have_pydatetime = True
have_mxdatetime = False
use_pydatetime = int(parser.get('build_ext', 'use_pydatetime'))
use_pydatetime = int(parser.get('build_ext', 'use_pydatetime'))
# check for mx package
if parser.has_option('build_ext', 'mx_include_dir'):
@ -547,8 +550,8 @@ you probably need to install its companion -dev or -devel package."""
sys.exit(1)
# generate a nice version string to avoid confusion when users report bugs
version_flags.append('pq3') # no more a choice
version_flags.append('ext') # no more a choice
version_flags.append('pq3') # no more a choice
version_flags.append('ext') # no more a choice
if version_flags:
PSYCOPG_VERSION_EX = PSYCOPG_VERSION + " (%s)" % ' '.join(version_flags)
@ -580,8 +583,8 @@ for define in parser.get('build_ext', 'define').split(','):
# build the extension
sources = [ os.path.join('psycopg', x) for x in sources]
depends = [ os.path.join('psycopg', x) for x in depends]
sources = [os.path.join('psycopg', x) for x in sources]
depends = [os.path.join('psycopg', x) for x in depends]
ext.append(Extension("psycopg2._psycopg", sources,
define_macros=define_macros,

View File

@ -52,6 +52,7 @@ if sys.version_info[:2] >= (2, 5):
else:
test_with = None
def test_suite():
# If connection to test db fails, bail out early.
import psycopg2

View File

@ -33,6 +33,7 @@ import StringIO
from testutils import ConnectingTestCase
class PollableStub(object):
"""A 'pollable' wrapper allowing analysis of the `poll()` calls."""
def __init__(self, pollable):
@ -68,6 +69,7 @@ class AsyncTests(ConnectingTestCase):
def test_connection_setup(self):
cur = self.conn.cursor()
sync_cur = self.sync_conn.cursor()
del cur, sync_cur
self.assert_(self.conn.async)
self.assert_(not self.sync_conn.async)
@ -77,7 +79,7 @@ class AsyncTests(ConnectingTestCase):
# check other properties to be found on the connection
self.assert_(self.conn.server_version)
self.assert_(self.conn.protocol_version in (2,3))
self.assert_(self.conn.protocol_version in (2, 3))
self.assert_(self.conn.encoding in psycopg2.extensions.encodings)
def test_async_named_cursor(self):
@ -108,6 +110,7 @@ class AsyncTests(ConnectingTestCase):
def test_async_after_async(self):
cur = self.conn.cursor()
cur2 = self.conn.cursor()
del cur2
cur.execute("insert into table1 values (1)")
@ -422,14 +425,14 @@ class AsyncTests(ConnectingTestCase):
def test_async_cursor_gone(self):
import gc
cur = self.conn.cursor()
cur.execute("select 42;");
cur.execute("select 42;")
del cur
gc.collect()
self.assertRaises(psycopg2.InterfaceError, self.wait, self.conn)
# The connection is still usable
cur = self.conn.cursor()
cur.execute("select 42;");
cur.execute("select 42;")
self.wait(self.conn)
self.assertEqual(cur.fetchone(), (42,))
@ -449,4 +452,3 @@ def test_suite():
if __name__ == "__main__":
unittest.main()

View File

@ -26,15 +26,17 @@ import psycopg2
import time
import unittest
class DateTimeAllocationBugTestCase(unittest.TestCase):
def test_date_time_allocation_bug(self):
d1 = psycopg2.Date(2002,12,25)
d2 = psycopg2.DateFromTicks(time.mktime((2002,12,25,0,0,0,0,0,0)))
t1 = psycopg2.Time(13,45,30)
t2 = psycopg2.TimeFromTicks(time.mktime((2001,1,1,13,45,30,0,0,0)))
t1 = psycopg2.Timestamp(2002,12,25,13,45,30)
d1 = psycopg2.Date(2002, 12, 25)
d2 = psycopg2.DateFromTicks(time.mktime((2002, 12, 25, 0, 0, 0, 0, 0, 0)))
t1 = psycopg2.Time(13, 45, 30)
t2 = psycopg2.TimeFromTicks(time.mktime((2001, 1, 1, 13, 45, 30, 0, 0, 0)))
t1 = psycopg2.Timestamp(2002, 12, 25, 13, 45, 30)
t2 = psycopg2.TimestampFromTicks(
time.mktime((2002,12,25,13,45,30,0,0,0)))
time.mktime((2002, 12, 25, 13, 45, 30, 0, 0, 0)))
del d1, d2, t1, t2
def test_suite():

View File

@ -29,6 +29,7 @@ import gc
from testutils import ConnectingTestCase, skip_if_no_uuid
class StolenReferenceTestCase(ConnectingTestCase):
@skip_if_no_uuid
def test_stolen_reference_bug(self):
@ -41,8 +42,10 @@ class StolenReferenceTestCase(ConnectingTestCase):
curs.execute("select 'b5219e01-19ab-4994-b71e-149225dc51e4'::uuid")
curs.fetchone()
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()

View File

@ -32,6 +32,7 @@ from psycopg2 import extras
from testconfig import dsn
from testutils import unittest, ConnectingTestCase, skip_before_postgres
class CancelTests(ConnectingTestCase):
def setUp(self):
@ -71,6 +72,7 @@ class CancelTests(ConnectingTestCase):
except Exception, e:
errors.append(e)
raise
del cur
thread1 = threading.Thread(target=neverending, args=(self.conn, ))
# wait a bit to make sure that the other thread is already in

View File

@ -27,17 +27,16 @@ import sys
import time
import threading
from operator import attrgetter
from StringIO import StringIO
import psycopg2
import psycopg2.errorcodes
import psycopg2.extensions
ext = psycopg2.extensions
from psycopg2 import extensions as ext
from testutils import (
unittest, decorate_all_tests, skip_if_no_superuser,
skip_before_postgres, skip_after_postgres, skip_before_libpq,
ConnectingTestCase, skip_if_tpc_disabled, skip_if_windows)
from testutils import unittest, decorate_all_tests, skip_if_no_superuser
from testutils import skip_before_postgres, skip_after_postgres, skip_before_libpq
from testutils import ConnectingTestCase, skip_if_tpc_disabled
from testutils import skip_if_windows
from testconfig import dsn, dbname
@ -112,8 +111,14 @@ class ConnectionTests(ConnectingTestCase):
cur = conn.cursor()
if self.conn.server_version >= 90300:
cur.execute("set client_min_messages=debug1")
cur.execute("create temp table table1 (id serial); create temp table table2 (id serial);")
cur.execute("create temp table table3 (id serial); create temp table table4 (id serial);")
cur.execute("""
create temp table table1 (id serial);
create temp table table2 (id serial);
""")
cur.execute("""
create temp table table3 (id serial);
create temp table table4 (id serial);
""")
self.assertEqual(4, len(conn.notices))
self.assert_('table1' in conn.notices[0])
self.assert_('table2' in conn.notices[1])
@ -126,7 +131,8 @@ class ConnectionTests(ConnectingTestCase):
if self.conn.server_version >= 90300:
cur.execute("set client_min_messages=debug1")
for i in range(0, 100, 10):
sql = " ".join(["create temp table table%d (id serial);" % j for j in range(i, i + 10)])
sql = " ".join(["create temp table table%d (id serial);" % j
for j in range(i, i + 10)])
cur.execute(sql)
self.assertEqual(50, len(conn.notices))
@ -141,8 +147,13 @@ class ConnectionTests(ConnectingTestCase):
if self.conn.server_version >= 90300:
cur.execute("set client_min_messages=debug1")
cur.execute("create temp table table1 (id serial); create temp table table2 (id serial);")
cur.execute("create temp table table3 (id serial); create temp table table4 (id serial);")
cur.execute("""
create temp table table1 (id serial);
create temp table table2 (id serial);
""")
cur.execute("""
create temp table table3 (id serial);
create temp table table4 (id serial);""")
self.assertEqual(len(conn.notices), 4)
self.assert_('table1' in conn.notices.popleft())
self.assert_('table2' in conn.notices.popleft())
@ -152,7 +163,8 @@ class ConnectionTests(ConnectingTestCase):
# not limited, but no error
for i in range(0, 100, 10):
sql = " ".join(["create temp table table2_%d (id serial);" % j for j in range(i, i + 10)])
sql = " ".join(["create temp table table2_%d (id serial);" % j
for j in range(i, i + 10)])
cur.execute(sql)
self.assertEqual(len([n for n in conn.notices if 'CREATE TABLE' in n]),
@ -315,16 +327,18 @@ class ParseDsnTestCase(ConnectingTestCase):
def test_parse_dsn(self):
from psycopg2 import ProgrammingError
self.assertEqual(ext.parse_dsn('dbname=test user=tester password=secret'),
dict(user='tester', password='secret', dbname='test'),
"simple DSN parsed")
self.assertEqual(
ext.parse_dsn('dbname=test user=tester password=secret'),
dict(user='tester', password='secret', dbname='test'),
"simple DSN parsed")
self.assertRaises(ProgrammingError, ext.parse_dsn,
"dbname=test 2 user=tester password=secret")
self.assertEqual(ext.parse_dsn("dbname='test 2' user=tester password=secret"),
dict(user='tester', password='secret', dbname='test 2'),
"DSN with quoting parsed")
self.assertEqual(
ext.parse_dsn("dbname='test 2' user=tester password=secret"),
dict(user='tester', password='secret', dbname='test 2'),
"DSN with quoting parsed")
# Can't really use assertRaisesRegexp() here since we need to
# make sure that secret is *not* exposed in the error messgage
@ -485,7 +499,8 @@ class IsolationLevelsTestCase(ConnectingTestCase):
levels = [
(None, psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT),
('read uncommitted', psycopg2.extensions.ISOLATION_LEVEL_READ_UNCOMMITTED),
('read uncommitted',
psycopg2.extensions.ISOLATION_LEVEL_READ_UNCOMMITTED),
('read committed', psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED),
('repeatable read', psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ),
('serializable', psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE),

View File

@ -39,7 +39,8 @@ from testconfig import dsn
if sys.version_info[0] < 3:
_base = object
else:
from io import TextIOBase as _base
from io import TextIOBase as _base
class MinimalRead(_base):
"""A file wrapper exposing the minimal interface to copy from."""
@ -52,6 +53,7 @@ class MinimalRead(_base):
def readline(self):
return self.f.readline()
class MinimalWrite(_base):
"""A file wrapper exposing the minimal interface to copy to."""
def __init__(self, f):
@ -78,7 +80,7 @@ class CopyTests(ConnectingTestCase):
def test_copy_from(self):
curs = self.conn.cursor()
try:
self._copy_from(curs, nrecs=1024, srec=10*1024, copykw={})
self._copy_from(curs, nrecs=1024, srec=10 * 1024, copykw={})
finally:
curs.close()
@ -86,8 +88,8 @@ class CopyTests(ConnectingTestCase):
# Trying to trigger a "would block" error
curs = self.conn.cursor()
try:
self._copy_from(curs, nrecs=10*1024, srec=10*1024,
copykw={'size': 20*1024*1024})
self._copy_from(curs, nrecs=10 * 1024, srec=10 * 1024,
copykw={'size': 20 * 1024 * 1024})
finally:
curs.close()
@ -110,6 +112,7 @@ class CopyTests(ConnectingTestCase):
f.write("%s\n" % (i,))
f.seek(0)
def cols():
raise ZeroDivisionError()
yield 'id'
@ -120,8 +123,8 @@ class CopyTests(ConnectingTestCase):
def test_copy_to(self):
curs = self.conn.cursor()
try:
self._copy_from(curs, nrecs=1024, srec=10*1024, copykw={})
self._copy_to(curs, srec=10*1024)
self._copy_from(curs, nrecs=1024, srec=10 * 1024, copykw={})
self._copy_to(curs, srec=10 * 1024)
finally:
curs.close()
@ -209,9 +212,11 @@ class CopyTests(ConnectingTestCase):
exp_size = 123
# hack here to leave file as is, only check size when reading
real_read = f.read
def read(_size, f=f, exp_size=exp_size):
self.assertEqual(_size, exp_size)
return real_read(_size)
f.read = read
curs.copy_expert('COPY tcopy (data) FROM STDIN', f, size=exp_size)
curs.execute("select data from tcopy;")
@ -221,7 +226,7 @@ class CopyTests(ConnectingTestCase):
f = StringIO()
for i, c in izip(xrange(nrecs), cycle(string.ascii_letters)):
l = c * srec
f.write("%s\t%s\n" % (i,l))
f.write("%s\t%s\n" % (i, l))
f.seek(0)
curs.copy_from(MinimalRead(f), "tcopy", **copykw)
@ -258,24 +263,24 @@ class CopyTests(ConnectingTestCase):
curs.copy_expert, 'COPY tcopy (data) FROM STDIN', f)
def test_copy_no_column_limit(self):
cols = [ "c%050d" % i for i in range(200) ]
cols = ["c%050d" % i for i in range(200)]
curs = self.conn.cursor()
curs.execute('CREATE TEMPORARY TABLE manycols (%s)' % ',\n'.join(
[ "%s int" % c for c in cols]))
["%s int" % c for c in cols]))
curs.execute("INSERT INTO manycols DEFAULT VALUES")
f = StringIO()
curs.copy_to(f, "manycols", columns = cols)
curs.copy_to(f, "manycols", columns=cols)
f.seek(0)
self.assertEqual(f.read().split(), ['\\N'] * len(cols))
f.seek(0)
curs.copy_from(f, "manycols", columns = cols)
curs.copy_from(f, "manycols", columns=cols)
curs.execute("select count(*) from manycols;")
self.assertEqual(curs.fetchone()[0], 2)
@skip_before_postgres(8, 2) # they don't send the count
@skip_before_postgres(8, 2) # they don't send the count
def test_copy_rowcount(self):
curs = self.conn.cursor()
@ -316,7 +321,7 @@ try:
except psycopg2.ProgrammingError:
pass
conn.close()
""" % { 'dsn': dsn,})
""" % {'dsn': dsn})
proc = Popen([sys.executable, '-c', script_to_py3(script)])
proc.communicate()
@ -334,7 +339,7 @@ try:
except psycopg2.ProgrammingError:
pass
conn.close()
""" % { 'dsn': dsn,})
""" % {'dsn': dsn})
proc = Popen([sys.executable, '-c', script_to_py3(script)], stdout=PIPE)
proc.communicate()
@ -343,10 +348,10 @@ conn.close()
def test_copy_from_propagate_error(self):
class BrokenRead(_base):
def read(self, size):
return 1/0
return 1 / 0
def readline(self):
return 1/0
return 1 / 0
curs = self.conn.cursor()
# It seems we cannot do this, but now at least we propagate the error
@ -360,7 +365,7 @@ conn.close()
def test_copy_to_propagate_error(self):
class BrokenWrite(_base):
def write(self, data):
return 1/0
return 1 / 0
curs = self.conn.cursor()
curs.execute("insert into tcopy values (10, 'hi')")

View File

@ -29,6 +29,7 @@ import psycopg2.extensions
from testutils import unittest, ConnectingTestCase, skip_before_postgres
from testutils import skip_if_no_namedtuple, skip_if_no_getrefcount
class CursorTests(ConnectingTestCase):
def test_close_idempotent(self):
@ -47,8 +48,10 @@ class CursorTests(ConnectingTestCase):
conn = self.conn
cur = conn.cursor()
cur.execute("create temp table test_exc (data int);")
def buggygen():
yield 1//0
yield 1 // 0
self.assertRaises(ZeroDivisionError,
cur.executemany, "insert into test_exc values (%s)", buggygen())
cur.close()
@ -102,8 +105,7 @@ class CursorTests(ConnectingTestCase):
# issue #81: reference leak when a parameter value is referenced
# more than once from a dict.
cur = self.conn.cursor()
i = lambda x: x
foo = i('foo') * 10
foo = (lambda x: x)('foo') * 10
import sys
nref1 = sys.getrefcount(foo)
cur.mogrify("select %(foo)s, %(foo)s, %(foo)s", {'foo': foo})
@ -135,7 +137,7 @@ class CursorTests(ConnectingTestCase):
self.assertEqual(Decimal('123.45'), curs.cast(1700, '123.45'))
from datetime import date
self.assertEqual(date(2011,1,2), curs.cast(1082, '2011-01-02'))
self.assertEqual(date(2011, 1, 2), curs.cast(1082, '2011-01-02'))
self.assertEqual("who am i?", curs.cast(705, 'who am i?')) # unknown
def test_cast_specificity(self):
@ -158,7 +160,8 @@ class CursorTests(ConnectingTestCase):
curs = self.conn.cursor()
w = ref(curs)
del curs
import gc; gc.collect()
import gc
gc.collect()
self.assert_(w() is None)
def test_null_name(self):
@ -168,7 +171,7 @@ class CursorTests(ConnectingTestCase):
def test_invalid_name(self):
curs = self.conn.cursor()
curs.execute("create temp table invname (data int);")
for i in (10,20,30):
for i in (10, 20, 30):
curs.execute("insert into invname values (%s)", (i,))
curs.close()
@ -193,16 +196,16 @@ class CursorTests(ConnectingTestCase):
self._create_withhold_table()
curs = self.conn.cursor("W")
self.assertEqual(curs.withhold, False);
self.assertEqual(curs.withhold, False)
curs.withhold = True
self.assertEqual(curs.withhold, True);
self.assertEqual(curs.withhold, True)
curs.execute("select data from withhold order by data")
self.conn.commit()
self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)])
curs.close()
curs = self.conn.cursor("W", withhold=True)
self.assertEqual(curs.withhold, True);
self.assertEqual(curs.withhold, True)
curs.execute("select data from withhold order by data")
self.conn.commit()
self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)])
@ -264,18 +267,18 @@ class CursorTests(ConnectingTestCase):
curs = self.conn.cursor()
curs.execute("create table scrollable (data int)")
curs.executemany("insert into scrollable values (%s)",
[ (i,) for i in range(100) ])
[(i,) for i in range(100)])
curs.close()
for t in range(2):
if not t:
curs = self.conn.cursor("S")
self.assertEqual(curs.scrollable, None);
self.assertEqual(curs.scrollable, None)
curs.scrollable = True
else:
curs = self.conn.cursor("S", scrollable=True)
self.assertEqual(curs.scrollable, True);
self.assertEqual(curs.scrollable, True)
curs.itersize = 10
# complex enough to make postgres cursors declare without
@ -303,7 +306,7 @@ class CursorTests(ConnectingTestCase):
curs = self.conn.cursor()
curs.execute("create table scrollable (data int)")
curs.executemany("insert into scrollable values (%s)",
[ (i,) for i in range(100) ])
[(i,) for i in range(100)])
curs.close()
curs = self.conn.cursor("S") # default scrollability
@ -340,18 +343,18 @@ class CursorTests(ConnectingTestCase):
def test_iter_named_cursor_default_itersize(self):
curs = self.conn.cursor('tmp')
curs.execute('select generate_series(1,50)')
rv = [ (r[0], curs.rownumber) for r in curs ]
rv = [(r[0], curs.rownumber) for r in curs]
# everything swallowed in one gulp
self.assertEqual(rv, [(i,i) for i in range(1,51)])
self.assertEqual(rv, [(i, i) for i in range(1, 51)])
@skip_before_postgres(8, 0)
def test_iter_named_cursor_itersize(self):
curs = self.conn.cursor('tmp')
curs.itersize = 30
curs.execute('select generate_series(1,50)')
rv = [ (r[0], curs.rownumber) for r in curs ]
rv = [(r[0], curs.rownumber) for r in curs]
# everything swallowed in two gulps
self.assertEqual(rv, [(i,((i - 1) % 30) + 1) for i in range(1,51)])
self.assertEqual(rv, [(i, ((i - 1) % 30) + 1) for i in range(1, 51)])
@skip_before_postgres(8, 0)
def test_iter_named_cursor_rownumber(self):

View File

@ -27,6 +27,7 @@ import psycopg2
from psycopg2.tz import FixedOffsetTimezone, ZERO
from testutils import unittest, ConnectingTestCase, skip_before_postgres
class CommonDatetimeTestsMixin:
def execute(self, *args):
@ -144,10 +145,10 @@ class DatetimeTests(ConnectingTestCase, CommonDatetimeTestsMixin):
# The Python datetime module does not support time zone
# offsets that are not a whole number of minutes.
# We round the offset to the nearest minute.
self.check_time_tz("+01:15:00", 60 * (60 + 15))
self.check_time_tz("+01:15:29", 60 * (60 + 15))
self.check_time_tz("+01:15:30", 60 * (60 + 16))
self.check_time_tz("+01:15:59", 60 * (60 + 16))
self.check_time_tz("+01:15:00", 60 * (60 + 15))
self.check_time_tz("+01:15:29", 60 * (60 + 15))
self.check_time_tz("+01:15:30", 60 * (60 + 16))
self.check_time_tz("+01:15:59", 60 * (60 + 16))
self.check_time_tz("-01:15:00", -60 * (60 + 15))
self.check_time_tz("-01:15:29", -60 * (60 + 15))
self.check_time_tz("-01:15:30", -60 * (60 + 16))
@ -180,10 +181,10 @@ class DatetimeTests(ConnectingTestCase, CommonDatetimeTestsMixin):
# The Python datetime module does not support time zone
# offsets that are not a whole number of minutes.
# We round the offset to the nearest minute.
self.check_datetime_tz("+01:15:00", 60 * (60 + 15))
self.check_datetime_tz("+01:15:29", 60 * (60 + 15))
self.check_datetime_tz("+01:15:30", 60 * (60 + 16))
self.check_datetime_tz("+01:15:59", 60 * (60 + 16))
self.check_datetime_tz("+01:15:00", 60 * (60 + 15))
self.check_datetime_tz("+01:15:29", 60 * (60 + 15))
self.check_datetime_tz("+01:15:30", 60 * (60 + 16))
self.check_datetime_tz("+01:15:59", 60 * (60 + 16))
self.check_datetime_tz("-01:15:00", -60 * (60 + 15))
self.check_datetime_tz("-01:15:29", -60 * (60 + 15))
self.check_datetime_tz("-01:15:30", -60 * (60 + 16))
@ -269,32 +270,32 @@ class DatetimeTests(ConnectingTestCase, CommonDatetimeTestsMixin):
def test_type_roundtrip_date(self):
from datetime import date
self._test_type_roundtrip(date(2010,5,3))
self._test_type_roundtrip(date(2010, 5, 3))
def test_type_roundtrip_datetime(self):
from datetime import datetime
dt = self._test_type_roundtrip(datetime(2010,5,3,10,20,30))
dt = self._test_type_roundtrip(datetime(2010, 5, 3, 10, 20, 30))
self.assertEqual(None, dt.tzinfo)
def test_type_roundtrip_datetimetz(self):
from datetime import datetime
import psycopg2.tz
tz = psycopg2.tz.FixedOffsetTimezone(8*60)
dt1 = datetime(2010,5,3,10,20,30, tzinfo=tz)
tz = psycopg2.tz.FixedOffsetTimezone(8 * 60)
dt1 = datetime(2010, 5, 3, 10, 20, 30, tzinfo=tz)
dt2 = self._test_type_roundtrip(dt1)
self.assertNotEqual(None, dt2.tzinfo)
self.assertEqual(dt1, dt2)
def test_type_roundtrip_time(self):
from datetime import time
tm = self._test_type_roundtrip(time(10,20,30))
tm = self._test_type_roundtrip(time(10, 20, 30))
self.assertEqual(None, tm.tzinfo)
def test_type_roundtrip_timetz(self):
from datetime import time
import psycopg2.tz
tz = psycopg2.tz.FixedOffsetTimezone(8*60)
tm1 = time(10,20,30, tzinfo=tz)
tz = psycopg2.tz.FixedOffsetTimezone(8 * 60)
tm1 = time(10, 20, 30, tzinfo=tz)
tm2 = self._test_type_roundtrip(tm1)
self.assertNotEqual(None, tm2.tzinfo)
self.assertEqual(tm1, tm2)
@ -305,15 +306,15 @@ class DatetimeTests(ConnectingTestCase, CommonDatetimeTestsMixin):
def test_type_roundtrip_date_array(self):
from datetime import date
self._test_type_roundtrip_array(date(2010,5,3))
self._test_type_roundtrip_array(date(2010, 5, 3))
def test_type_roundtrip_datetime_array(self):
from datetime import datetime
self._test_type_roundtrip_array(datetime(2010,5,3,10,20,30))
self._test_type_roundtrip_array(datetime(2010, 5, 3, 10, 20, 30))
def test_type_roundtrip_time_array(self):
from datetime import time
self._test_type_roundtrip_array(time(10,20,30))
self._test_type_roundtrip_array(time(10, 20, 30))
def test_type_roundtrip_interval_array(self):
from datetime import timedelta
@ -355,8 +356,10 @@ class mxDateTimeTests(ConnectingTestCase, CommonDatetimeTestsMixin):
psycopg2.extensions.register_type(self.INTERVAL, self.conn)
psycopg2.extensions.register_type(psycopg2.extensions.MXDATEARRAY, self.conn)
psycopg2.extensions.register_type(psycopg2.extensions.MXTIMEARRAY, self.conn)
psycopg2.extensions.register_type(psycopg2.extensions.MXDATETIMEARRAY, self.conn)
psycopg2.extensions.register_type(psycopg2.extensions.MXINTERVALARRAY, self.conn)
psycopg2.extensions.register_type(
psycopg2.extensions.MXDATETIMEARRAY, self.conn)
psycopg2.extensions.register_type(
psycopg2.extensions.MXINTERVALARRAY, self.conn)
def tearDown(self):
self.conn.close()
@ -479,15 +482,15 @@ class mxDateTimeTests(ConnectingTestCase, CommonDatetimeTestsMixin):
def test_type_roundtrip_date(self):
from mx.DateTime import Date
self._test_type_roundtrip(Date(2010,5,3))
self._test_type_roundtrip(Date(2010, 5, 3))
def test_type_roundtrip_datetime(self):
from mx.DateTime import DateTime
self._test_type_roundtrip(DateTime(2010,5,3,10,20,30))
self._test_type_roundtrip(DateTime(2010, 5, 3, 10, 20, 30))
def test_type_roundtrip_time(self):
from mx.DateTime import Time
self._test_type_roundtrip(Time(10,20,30))
self._test_type_roundtrip(Time(10, 20, 30))
def test_type_roundtrip_interval(self):
from mx.DateTime import DateTimeDeltaFrom
@ -495,15 +498,15 @@ class mxDateTimeTests(ConnectingTestCase, CommonDatetimeTestsMixin):
def test_type_roundtrip_date_array(self):
from mx.DateTime import Date
self._test_type_roundtrip_array(Date(2010,5,3))
self._test_type_roundtrip_array(Date(2010, 5, 3))
def test_type_roundtrip_datetime_array(self):
from mx.DateTime import DateTime
self._test_type_roundtrip_array(DateTime(2010,5,3,10,20,30))
self._test_type_roundtrip_array(DateTime(2010, 5, 3, 10, 20, 30))
def test_type_roundtrip_time_array(self):
from mx.DateTime import Time
self._test_type_roundtrip_array(Time(10,20,30))
self._test_type_roundtrip_array(Time(10, 20, 30))
def test_type_roundtrip_interval_array(self):
from mx.DateTime import DateTimeDeltaFrom
@ -549,22 +552,30 @@ class FixedOffsetTimezoneTests(unittest.TestCase):
def test_repr_with_positive_offset(self):
tzinfo = FixedOffsetTimezone(5 * 60)
self.assertEqual(repr(tzinfo), "psycopg2.tz.FixedOffsetTimezone(offset=300, name=None)")
self.assertEqual(repr(tzinfo),
"psycopg2.tz.FixedOffsetTimezone(offset=300, name=None)")
def test_repr_with_negative_offset(self):
tzinfo = FixedOffsetTimezone(-5 * 60)
self.assertEqual(repr(tzinfo), "psycopg2.tz.FixedOffsetTimezone(offset=-300, name=None)")
self.assertEqual(repr(tzinfo),
"psycopg2.tz.FixedOffsetTimezone(offset=-300, name=None)")
def test_repr_with_name(self):
tzinfo = FixedOffsetTimezone(name="FOO")
self.assertEqual(repr(tzinfo), "psycopg2.tz.FixedOffsetTimezone(offset=0, name='FOO')")
self.assertEqual(repr(tzinfo),
"psycopg2.tz.FixedOffsetTimezone(offset=0, name='FOO')")
def test_instance_caching(self):
self.assert_(FixedOffsetTimezone(name="FOO") is FixedOffsetTimezone(name="FOO"))
self.assert_(FixedOffsetTimezone(7 * 60) is FixedOffsetTimezone(7 * 60))
self.assert_(FixedOffsetTimezone(-9 * 60, 'FOO') is FixedOffsetTimezone(-9 * 60, 'FOO'))
self.assert_(FixedOffsetTimezone(9 * 60) is not FixedOffsetTimezone(9 * 60, 'FOO'))
self.assert_(FixedOffsetTimezone(name='FOO') is not FixedOffsetTimezone(9 * 60, 'FOO'))
self.assert_(FixedOffsetTimezone(name="FOO")
is FixedOffsetTimezone(name="FOO"))
self.assert_(FixedOffsetTimezone(7 * 60)
is FixedOffsetTimezone(7 * 60))
self.assert_(FixedOffsetTimezone(-9 * 60, 'FOO')
is FixedOffsetTimezone(-9 * 60, 'FOO'))
self.assert_(FixedOffsetTimezone(9 * 60)
is not FixedOffsetTimezone(9 * 60, 'FOO'))
self.assert_(FixedOffsetTimezone(name='FOO')
is not FixedOffsetTimezone(9 * 60, 'FOO'))
def test_pickle(self):
# ticket #135

View File

@ -32,6 +32,7 @@ except NameError:
from threading import Thread
from psycopg2 import errorcodes
class ErrocodeTests(ConnectingTestCase):
def test_lookup_threadsafe(self):
@ -39,6 +40,7 @@ class ErrocodeTests(ConnectingTestCase):
MAX_CYCLES = 2000
errs = []
def f(pg_code='40001'):
try:
errorcodes.lookup(pg_code)

View File

@ -39,7 +39,8 @@ class ExtrasDictCursorTests(ConnectingTestCase):
self.assert_(isinstance(cur, psycopg2.extras.DictCursor))
self.assertEqual(cur.name, None)
# overridable
cur = self.conn.cursor('foo', cursor_factory=psycopg2.extras.NamedTupleCursor)
cur = self.conn.cursor('foo',
cursor_factory=psycopg2.extras.NamedTupleCursor)
self.assertEqual(cur.name, 'foo')
self.assert_(isinstance(cur, psycopg2.extras.NamedTupleCursor))
@ -80,7 +81,6 @@ class ExtrasDictCursorTests(ConnectingTestCase):
self.failUnless(row[0] == 'bar')
return row
def testDictCursorWithPlainCursorRealFetchOne(self):
self._testWithPlainCursorReal(lambda curs: curs.fetchone())
@ -110,7 +110,6 @@ class ExtrasDictCursorTests(ConnectingTestCase):
row = getter(curs)
self.failUnless(row['foo'] == 'bar')
def testDictCursorWithNamedCursorFetchOne(self):
self._testWithNamedCursor(lambda curs: curs.fetchone())
@ -146,7 +145,6 @@ class ExtrasDictCursorTests(ConnectingTestCase):
self.failUnless(row['foo'] == 'bar')
self.failUnless(row[0] == 'bar')
def testDictCursorRealWithNamedCursorFetchOne(self):
self._testWithNamedCursorReal(lambda curs: curs.fetchone())
@ -176,12 +174,12 @@ class ExtrasDictCursorTests(ConnectingTestCase):
self._testIterRowNumber(curs)
def _testWithNamedCursorReal(self, getter):
curs = self.conn.cursor('aname', cursor_factory=psycopg2.extras.RealDictCursor)
curs = self.conn.cursor('aname',
cursor_factory=psycopg2.extras.RealDictCursor)
curs.execute("SELECT * FROM ExtrasDictCursorTests")
row = getter(curs)
self.failUnless(row['foo'] == 'bar')
def _testNamedCursorNotGreedy(self, curs):
curs.itersize = 2
curs.execute("""select clock_timestamp() as ts from generate_series(1,3)""")
@ -235,7 +233,7 @@ class NamedTupleCursorTest(ConnectingTestCase):
from psycopg2.extras import NamedTupleConnection
try:
from collections import namedtuple
from collections import namedtuple # noqa
except ImportError:
return
@ -346,7 +344,7 @@ class NamedTupleCursorTest(ConnectingTestCase):
def test_error_message(self):
try:
from collections import namedtuple
from collections import namedtuple # noqa
except ImportError:
# an import error somewhere
from psycopg2.extras import NamedTupleConnection
@ -390,6 +388,7 @@ class NamedTupleCursorTest(ConnectingTestCase):
from psycopg2.extras import NamedTupleCursor
f_orig = NamedTupleCursor._make_nt
calls = [0]
def f_patched(self_):
calls[0] += 1
return f_orig(self_)

View File

@ -29,6 +29,7 @@ import psycopg2.extras
from testutils import ConnectingTestCase
class ConnectionStub(object):
"""A `connection` wrapper allowing analysis of the `poll()` calls."""
def __init__(self, conn):
@ -43,6 +44,7 @@ class ConnectionStub(object):
self.polls.append(rv)
return rv
class GreenTestCase(ConnectingTestCase):
def setUp(self):
self._cb = psycopg2.extensions.get_wait_callback()
@ -89,7 +91,7 @@ class GreenTestCase(ConnectingTestCase):
curs.fetchone()
# now try to do something that will fail in the callback
psycopg2.extensions.set_wait_callback(lambda conn: 1//0)
psycopg2.extensions.set_wait_callback(lambda conn: 1 // 0)
self.assertRaises(ZeroDivisionError, curs.execute, "select 2")
self.assert_(conn.closed)

View File

@ -32,6 +32,7 @@ import psycopg2.extensions
from testutils import unittest, decorate_all_tests, skip_if_tpc_disabled
from testutils import ConnectingTestCase, skip_if_green
def skip_if_no_lo(f):
@wraps(f)
def skip_if_no_lo_(self):
@ -158,7 +159,7 @@ class LargeObjectTests(LargeObjectTestCase):
def test_read(self):
lo = self.conn.lobject()
length = lo.write(b"some data")
lo.write(b"some data")
lo.close()
lo = self.conn.lobject(lo.oid)
@ -169,7 +170,7 @@ class LargeObjectTests(LargeObjectTestCase):
def test_read_binary(self):
lo = self.conn.lobject()
length = lo.write(b"some data")
lo.write(b"some data")
lo.close()
lo = self.conn.lobject(lo.oid, "rb")
@ -181,7 +182,7 @@ class LargeObjectTests(LargeObjectTestCase):
def test_read_text(self):
lo = self.conn.lobject()
snowman = u"\u2603"
length = lo.write(u"some data " + snowman)
lo.write(u"some data " + snowman)
lo.close()
lo = self.conn.lobject(lo.oid, "rt")
@ -193,7 +194,7 @@ class LargeObjectTests(LargeObjectTestCase):
def test_read_large(self):
lo = self.conn.lobject()
data = "data" * 1000000
length = lo.write("some" + data)
lo.write("some" + data)
lo.close()
lo = self.conn.lobject(lo.oid)
@ -399,6 +400,7 @@ def skip_if_no_truncate(f):
return skip_if_no_truncate_
class LargeObjectTruncateTests(LargeObjectTestCase):
def test_truncate(self):
lo = self.conn.lobject()
@ -450,15 +452,19 @@ def _has_lo64(conn):
return (True, "this server and build support the lo64 API")
def skip_if_no_lo64(f):
@wraps(f)
def skip_if_no_lo64_(self):
lo64, msg = _has_lo64(self.conn)
if not lo64: return self.skipTest(msg)
else: return f(self)
if not lo64:
return self.skipTest(msg)
else:
return f(self)
return skip_if_no_lo64_
class LargeObject64Tests(LargeObjectTestCase):
def test_seek_tell_truncate_greater_than_2gb(self):
lo = self.conn.lobject()
@ -477,11 +483,14 @@ def skip_if_lo64(f):
@wraps(f)
def skip_if_lo64_(self):
lo64, msg = _has_lo64(self.conn)
if lo64: return self.skipTest(msg)
else: return f(self)
if lo64:
return self.skipTest(msg)
else:
return f(self)
return skip_if_lo64_
class LargeObjectNot64Tests(LargeObjectTestCase):
def test_seek_larger_than_2gb(self):
lo = self.conn.lobject()

View File

@ -67,8 +67,8 @@ curs.execute("NOTIFY " %(name)r %(payload)r)
curs.close()
conn.close()
""" % {
'module': psycopg2.__name__,
'dsn': dsn, 'sec': sec, 'name': name, 'payload': payload})
'module': psycopg2.__name__,
'dsn': dsn, 'sec': sec, 'name': name, 'payload': payload})
return Popen([sys.executable, '-c', script_to_py3(script)], stdout=PIPE)
@ -79,7 +79,7 @@ conn.close()
proc = self.notify('foo', 1)
t0 = time.time()
ready = select.select([self.conn], [], [], 5)
select.select([self.conn], [], [], 5)
t1 = time.time()
self.assert_(0.99 < t1 - t0 < 4, t1 - t0)
@ -107,7 +107,7 @@ conn.close()
names = dict.fromkeys(['foo', 'bar', 'baz'])
for (pid, name) in self.conn.notifies:
self.assertEqual(pids[name], pid)
names.pop(name) # raise if name found twice
names.pop(name) # raise if name found twice
def test_notifies_received_on_execute(self):
self.autocommit(self.conn)
@ -217,6 +217,6 @@ conn.close()
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()

View File

@ -30,12 +30,13 @@ import psycopg2
from testconfig import dsn
class Psycopg2Tests(dbapi20.DatabaseAPI20Test):
driver = psycopg2
connect_args = ()
connect_kw_args = {'dsn': dsn}
lower_func = 'lower' # For stored procedure test
lower_func = 'lower' # For stored procedure test
def test_setoutputsize(self):
# psycopg2's setoutputsize() is a no-op

View File

@ -24,8 +24,8 @@
import psycopg2
import psycopg2.extensions
from psycopg2.extras import PhysicalReplicationConnection, LogicalReplicationConnection
from psycopg2.extras import StopReplication
from psycopg2.extras import (
PhysicalReplicationConnection, LogicalReplicationConnection, StopReplication)
import testconfig
from testutils import unittest
@ -70,14 +70,16 @@ class ReplicationTestCase(ConnectingTestCase):
# generate some events for our replication stream
def make_replication_events(self):
conn = self.connect()
if conn is None: return
if conn is None:
return
cur = conn.cursor()
try:
cur.execute("DROP TABLE dummy1")
except psycopg2.ProgrammingError:
conn.rollback()
cur.execute("CREATE TABLE dummy1 AS SELECT * FROM generate_series(1, 5) AS id")
cur.execute(
"CREATE TABLE dummy1 AS SELECT * FROM generate_series(1, 5) AS id")
conn.commit()
@ -85,7 +87,8 @@ class ReplicationTest(ReplicationTestCase):
@skip_before_postgres(9, 0)
def test_physical_replication_connection(self):
conn = self.repl_connect(connection_factory=PhysicalReplicationConnection)
if conn is None: return
if conn is None:
return
cur = conn.cursor()
cur.execute("IDENTIFY_SYSTEM")
cur.fetchall()
@ -93,41 +96,49 @@ class ReplicationTest(ReplicationTestCase):
@skip_before_postgres(9, 4)
def test_logical_replication_connection(self):
conn = self.repl_connect(connection_factory=LogicalReplicationConnection)
if conn is None: return
if conn is None:
return
cur = conn.cursor()
cur.execute("IDENTIFY_SYSTEM")
cur.fetchall()
@skip_before_postgres(9, 4) # slots require 9.4
@skip_before_postgres(9, 4) # slots require 9.4
def test_create_replication_slot(self):
conn = self.repl_connect(connection_factory=PhysicalReplicationConnection)
if conn is None: return
if conn is None:
return
cur = conn.cursor()
self.create_replication_slot(cur)
self.assertRaises(psycopg2.ProgrammingError, self.create_replication_slot, cur)
self.assertRaises(
psycopg2.ProgrammingError, self.create_replication_slot, cur)
@skip_before_postgres(9, 4) # slots require 9.4
@skip_before_postgres(9, 4) # slots require 9.4
def test_start_on_missing_replication_slot(self):
conn = self.repl_connect(connection_factory=PhysicalReplicationConnection)
if conn is None: return
if conn is None:
return
cur = conn.cursor()
self.assertRaises(psycopg2.ProgrammingError, cur.start_replication, self.slot)
self.assertRaises(psycopg2.ProgrammingError,
cur.start_replication, self.slot)
self.create_replication_slot(cur)
cur.start_replication(self.slot)
@skip_before_postgres(9, 4) # slots require 9.4
@skip_before_postgres(9, 4) # slots require 9.4
def test_start_and_recover_from_error(self):
conn = self.repl_connect(connection_factory=LogicalReplicationConnection)
if conn is None: return
if conn is None:
return
cur = conn.cursor()
self.create_replication_slot(cur, output_plugin='test_decoding')
# try with invalid options
cur.start_replication(slot_name=self.slot, options={'invalid_param': 'value'})
cur.start_replication(
slot_name=self.slot, options={'invalid_param': 'value'})
def consume(msg):
pass
# we don't see the error from the server before we try to read the data
@ -136,10 +147,11 @@ class ReplicationTest(ReplicationTestCase):
# try with correct command
cur.start_replication(slot_name=self.slot)
@skip_before_postgres(9, 4) # slots require 9.4
@skip_before_postgres(9, 4) # slots require 9.4
def test_stop_replication(self):
conn = self.repl_connect(connection_factory=LogicalReplicationConnection)
if conn is None: return
if conn is None:
return
cur = conn.cursor()
self.create_replication_slot(cur, output_plugin='test_decoding')
@ -147,16 +159,19 @@ class ReplicationTest(ReplicationTestCase):
self.make_replication_events()
cur.start_replication(self.slot)
def consume(msg):
raise StopReplication()
self.assertRaises(StopReplication, cur.consume_stream, consume)
class AsyncReplicationTest(ReplicationTestCase):
@skip_before_postgres(9, 4) # slots require 9.4
@skip_before_postgres(9, 4) # slots require 9.4
def test_async_replication(self):
conn = self.repl_connect(connection_factory=LogicalReplicationConnection, async=1)
if conn is None: return
conn = self.repl_connect(
connection_factory=LogicalReplicationConnection, async=1)
if conn is None:
return
self.wait(conn)
cur = conn.cursor()
@ -169,9 +184,10 @@ class AsyncReplicationTest(ReplicationTestCase):
self.make_replication_events()
self.msg_count = 0
def consume(msg):
# just check the methods
log = "%s: %s" % (cur.io_timestamp, repr(msg))
"%s: %s" % (cur.io_timestamp, repr(msg))
self.msg_count += 1
if self.msg_count > 3:
@ -193,8 +209,10 @@ class AsyncReplicationTest(ReplicationTestCase):
select([cur], [], [])
self.assertRaises(StopReplication, process_stream)
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()

View File

@ -29,6 +29,7 @@ import psycopg2
from psycopg2.extensions import (
ISOLATION_LEVEL_SERIALIZABLE, STATUS_BEGIN, STATUS_READY)
class TransactionTests(ConnectingTestCase):
def setUp(self):
@ -147,6 +148,7 @@ class DeadlockSerializationTests(ConnectingTestCase):
self.thread1_error = exc
step1.set()
conn.close()
def task2():
try:
conn = self.connect()
@ -174,7 +176,7 @@ class DeadlockSerializationTests(ConnectingTestCase):
self.assertFalse(self.thread1_error and self.thread2_error)
error = self.thread1_error or self.thread2_error
self.assertTrue(isinstance(
error, psycopg2.extensions.TransactionRollbackError))
error, psycopg2.extensions.TransactionRollbackError))
def test_serialisation_failure(self):
self.thread1_error = self.thread2_error = None
@ -195,6 +197,7 @@ class DeadlockSerializationTests(ConnectingTestCase):
self.thread1_error = exc
step1.set()
conn.close()
def task2():
try:
conn = self.connect()
@ -221,7 +224,7 @@ class DeadlockSerializationTests(ConnectingTestCase):
self.assertFalse(self.thread1_error and self.thread2_error)
error = self.thread1_error or self.thread2_error
self.assertTrue(isinstance(
error, psycopg2.extensions.TransactionRollbackError))
error, psycopg2.extensions.TransactionRollbackError))
class QueryCancellationTests(ConnectingTestCase):

View File

@ -68,13 +68,16 @@ class TypesBasicTests(ConnectingTestCase):
"wrong decimal quoting: " + str(s))
s = self.execute("SELECT %s AS foo", (decimal.Decimal("NaN"),))
self.failUnless(str(s) == "NaN", "wrong decimal quoting: " + str(s))
self.failUnless(type(s) == decimal.Decimal, "wrong decimal conversion: " + repr(s))
self.failUnless(type(s) == decimal.Decimal,
"wrong decimal conversion: " + repr(s))
s = self.execute("SELECT %s AS foo", (decimal.Decimal("infinity"),))
self.failUnless(str(s) == "NaN", "wrong decimal quoting: " + str(s))
self.failUnless(type(s) == decimal.Decimal, "wrong decimal conversion: " + repr(s))
self.failUnless(type(s) == decimal.Decimal,
"wrong decimal conversion: " + repr(s))
s = self.execute("SELECT %s AS foo", (decimal.Decimal("-infinity"),))
self.failUnless(str(s) == "NaN", "wrong decimal quoting: " + str(s))
self.failUnless(type(s) == decimal.Decimal, "wrong decimal conversion: " + repr(s))
self.failUnless(type(s) == decimal.Decimal,
"wrong decimal conversion: " + repr(s))
def testFloatNan(self):
try:
@ -141,8 +144,8 @@ class TypesBasicTests(ConnectingTestCase):
self.assertEqual(s, buf2.tobytes())
def testArray(self):
s = self.execute("SELECT %s AS foo", ([[1,2],[3,4]],))
self.failUnlessEqual(s, [[1,2],[3,4]])
s = self.execute("SELECT %s AS foo", ([[1, 2], [3, 4]],))
self.failUnlessEqual(s, [[1, 2], [3, 4]])
s = self.execute("SELECT %s AS foo", (['one', 'two', 'three'],))
self.failUnlessEqual(s, ['one', 'two', 'three'])
@ -150,9 +153,12 @@ class TypesBasicTests(ConnectingTestCase):
# ticket #42
import datetime
curs = self.conn.cursor()
curs.execute("create table array_test (id integer, col timestamp without time zone[])")
curs.execute(
"create table array_test "
"(id integer, col timestamp without time zone[])")
curs.execute("insert into array_test values (%s, %s)", (1, [datetime.date(2011,2,14)]))
curs.execute("insert into array_test values (%s, %s)",
(1, [datetime.date(2011, 2, 14)]))
curs.execute("select col from array_test where id = 1")
self.assertEqual(curs.fetchone()[0], [datetime.datetime(2011, 2, 14, 0, 0)])
@ -208,9 +214,9 @@ class TypesBasicTests(ConnectingTestCase):
curs.execute("insert into na (texta) values (%s)", ([None],))
curs.execute("insert into na (texta) values (%s)", (['a', None],))
curs.execute("insert into na (texta) values (%s)", ([None, None],))
curs.execute("insert into na (inta) values (%s)", ([None],))
curs.execute("insert into na (inta) values (%s)", ([42, None],))
curs.execute("insert into na (inta) values (%s)", ([None, None],))
curs.execute("insert into na (inta) values (%s)", ([None],))
curs.execute("insert into na (inta) values (%s)", ([42, None],))
curs.execute("insert into na (inta) values (%s)", ([None, None],))
curs.execute("insert into na (boola) values (%s)", ([None],))
curs.execute("insert into na (boola) values (%s)", ([True, None],))
curs.execute("insert into na (boola) values (%s)", ([None, None],))
@ -220,7 +226,7 @@ class TypesBasicTests(ConnectingTestCase):
curs.execute("insert into na (textaa) values (%s)", ([['a', None]],))
# curs.execute("insert into na (textaa) values (%s)", ([[None, None]],))
# curs.execute("insert into na (intaa) values (%s)", ([[None]],))
curs.execute("insert into na (intaa) values (%s)", ([[42, None]],))
curs.execute("insert into na (intaa) values (%s)", ([[42, None]],))
# curs.execute("insert into na (intaa) values (%s)", ([[None, None]],))
# curs.execute("insert into na (boolaa) values (%s)", ([[None]],))
curs.execute("insert into na (boolaa) values (%s)", ([[True, None]],))
@ -323,30 +329,33 @@ class TypesBasicTests(ConnectingTestCase):
self.assertEqual(1, l1)
def testGenericArray(self):
a = self.execute("select '{1,2,3}'::int4[]")
self.assertEqual(a, [1,2,3])
a = self.execute("select array['a','b','''']::text[]")
self.assertEqual(a, ['a','b',"'"])
a = self.execute("select '{1, 2, 3}'::int4[]")
self.assertEqual(a, [1, 2, 3])
a = self.execute("select array['a', 'b', '''']::text[]")
self.assertEqual(a, ['a', 'b', "'"])
@testutils.skip_before_postgres(8, 2)
def testGenericArrayNull(self):
def caster(s, cur):
if s is None: return "nada"
if s is None:
return "nada"
return int(s) * 2
base = psycopg2.extensions.new_type((23,), "INT4", caster)
array = psycopg2.extensions.new_array_type((1007,), "INT4ARRAY", base)
psycopg2.extensions.register_type(array, self.conn)
a = self.execute("select '{1,2,3}'::int4[]")
self.assertEqual(a, [2,4,6])
a = self.execute("select '{1,2,NULL}'::int4[]")
self.assertEqual(a, [2,4,'nada'])
a = self.execute("select '{1, 2, 3}'::int4[]")
self.assertEqual(a, [2, 4, 6])
a = self.execute("select '{1, 2, NULL}'::int4[]")
self.assertEqual(a, [2, 4, 'nada'])
class AdaptSubclassTest(unittest.TestCase):
def test_adapt_subtype(self):
from psycopg2.extensions import adapt
class Sub(str): pass
class Sub(str):
pass
s1 = "hel'lo"
s2 = Sub(s1)
self.assertEqual(adapt(s1).getquoted(), adapt(s2).getquoted())
@ -354,9 +363,14 @@ class AdaptSubclassTest(unittest.TestCase):
def test_adapt_most_specific(self):
from psycopg2.extensions import adapt, register_adapter, AsIs
class A(object): pass
class B(A): pass
class C(B): pass
class A(object):
pass
class B(A):
pass
class C(B):
pass
register_adapter(A, lambda a: AsIs("a"))
register_adapter(B, lambda b: AsIs("b"))
@ -370,8 +384,11 @@ class AdaptSubclassTest(unittest.TestCase):
def test_no_mro_no_joy(self):
from psycopg2.extensions import adapt, register_adapter, AsIs
class A: pass
class B(A): pass
class A:
pass
class B(A):
pass
register_adapter(A, lambda a: AsIs("a"))
try:
@ -383,8 +400,11 @@ class AdaptSubclassTest(unittest.TestCase):
def test_adapt_subtype_3(self):
from psycopg2.extensions import adapt, register_adapter, AsIs
class A: pass
class B(A): pass
class A:
pass
class B(A):
pass
register_adapter(A, lambda a: AsIs("a"))
try:
@ -443,7 +463,8 @@ class ByteaParserTest(unittest.TestCase):
def test_full_hex(self, upper=False):
buf = ''.join(("%02x" % i) for i in range(256))
if upper: buf = buf.upper()
if upper:
buf = buf.upper()
buf = '\\x' + buf
rv = self.cast(buf.encode('utf8'))
if sys.version_info[0] < 3:

View File

@ -37,6 +37,7 @@ def filter_scs(conn, s):
else:
return s.replace(b"E'", b"'")
class TypesExtrasTests(ConnectingTestCase):
"""Test that all type conversions are working."""
@ -60,7 +61,8 @@ class TypesExtrasTests(ConnectingTestCase):
def testUUIDARRAY(self):
import uuid
psycopg2.extras.register_uuid()
u = [uuid.UUID('9c6d5a77-7256-457e-9461-347b4358e350'), uuid.UUID('9c6d5a77-7256-457e-9461-347b4358e352')]
u = [uuid.UUID('9c6d5a77-7256-457e-9461-347b4358e350'),
uuid.UUID('9c6d5a77-7256-457e-9461-347b4358e352')]
s = self.execute("SELECT %s AS foo", (u,))
self.failUnless(u == s)
# array with a NULL element
@ -110,7 +112,8 @@ class TypesExtrasTests(ConnectingTestCase):
a.getquoted())
def test_adapt_fail(self):
class Foo(object): pass
class Foo(object):
pass
self.assertRaises(psycopg2.ProgrammingError,
psycopg2.extensions.adapt, Foo(), ext.ISQLQuote, None)
try:
@ -130,6 +133,7 @@ def skip_if_no_hstore(f):
return skip_if_no_hstore_
class HstoreTestCase(ConnectingTestCase):
def test_adapt_8(self):
if self.conn.server_version >= 90000:
@ -155,7 +159,8 @@ class HstoreTestCase(ConnectingTestCase):
self.assertEqual(ii[2], filter_scs(self.conn, b"(E'c' => NULL)"))
if 'd' in o:
encc = u'\xe0'.encode(psycopg2.extensions.encodings[self.conn.encoding])
self.assertEqual(ii[3], filter_scs(self.conn, b"(E'd' => E'" + encc + b"')"))
self.assertEqual(ii[3],
filter_scs(self.conn, b"(E'd' => E'" + encc + b"')"))
def test_adapt_9(self):
if self.conn.server_version < 90000:
@ -199,7 +204,7 @@ class HstoreTestCase(ConnectingTestCase):
ok(None, None)
ok('', {})
ok('"a"=>"1", "b"=>"2"', {'a': '1', 'b': '2'})
ok('"a" => "1" ,"b" => "2"', {'a': '1', 'b': '2'})
ok('"a" => "1" , "b" => "2"', {'a': '1', 'b': '2'})
ok('"a"=>NULL, "b"=>"2"', {'a': None, 'b': '2'})
ok(r'"a"=>"\"", "\""=>"2"', {'a': '"', '"': '2'})
ok('"a"=>"\'", "\'"=>"2"', {'a': "'", "'": '2'})
@ -402,7 +407,9 @@ class HstoreTestCase(ConnectingTestCase):
from psycopg2.extras import register_hstore
register_hstore(None, globally=True, oid=oid, array_oid=aoid)
try:
cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore, '{a=>b}'::hstore[]")
cur.execute("""
select null::hstore, ''::hstore,
'a => b'::hstore, '{a=>b}'::hstore[]""")
t = cur.fetchone()
self.assert_(t[0] is None)
self.assertEqual(t[1], {})
@ -449,6 +456,7 @@ def skip_if_no_composite(f):
return skip_if_no_composite_
class AdaptTypeTestCase(ConnectingTestCase):
@skip_if_no_composite
def test_none_in_record(self):
@ -463,8 +471,11 @@ class AdaptTypeTestCase(ConnectingTestCase):
# the None adapter is not actually invoked in regular adaptation
class WonkyAdapter(object):
def __init__(self, obj): pass
def getquoted(self): return "NOPE!"
def __init__(self, obj):
pass
def getquoted(self):
return "NOPE!"
curs = self.conn.cursor()
@ -481,6 +492,7 @@ class AdaptTypeTestCase(ConnectingTestCase):
def test_tokenization(self):
from psycopg2.extras import CompositeCaster
def ok(s, v):
self.assertEqual(CompositeCaster.tokenize(s), v)
@ -519,26 +531,26 @@ class AdaptTypeTestCase(ConnectingTestCase):
self.assertEqual(t.oid, oid)
self.assert_(issubclass(t.type, tuple))
self.assertEqual(t.attnames, ['anint', 'astring', 'adate'])
self.assertEqual(t.atttypes, [23,25,1082])
self.assertEqual(t.atttypes, [23, 25, 1082])
curs = self.conn.cursor()
r = (10, 'hello', date(2011,1,2))
r = (10, 'hello', date(2011, 1, 2))
curs.execute("select %s::type_isd;", (r,))
v = curs.fetchone()[0]
self.assert_(isinstance(v, t.type))
self.assertEqual(v[0], 10)
self.assertEqual(v[1], "hello")
self.assertEqual(v[2], date(2011,1,2))
self.assertEqual(v[2], date(2011, 1, 2))
try:
from collections import namedtuple
from collections import namedtuple # noqa
except ImportError:
pass
else:
self.assert_(t.type is not tuple)
self.assertEqual(v.anint, 10)
self.assertEqual(v.astring, "hello")
self.assertEqual(v.adate, date(2011,1,2))
self.assertEqual(v.adate, date(2011, 1, 2))
@skip_if_no_composite
def test_empty_string(self):
@ -574,14 +586,14 @@ class AdaptTypeTestCase(ConnectingTestCase):
psycopg2.extras.register_composite("type_r_ft", self.conn)
curs = self.conn.cursor()
r = (0.25, (date(2011,1,2), (42, "hello")))
r = (0.25, (date(2011, 1, 2), (42, "hello")))
curs.execute("select %s::type_r_ft;", (r,))
v = curs.fetchone()[0]
self.assertEqual(r, v)
try:
from collections import namedtuple
from collections import namedtuple # noqa
except ImportError:
pass
else:
@ -595,7 +607,7 @@ class AdaptTypeTestCase(ConnectingTestCase):
curs2 = self.conn.cursor()
psycopg2.extras.register_composite("type_ii", curs1)
curs1.execute("select (1,2)::type_ii")
self.assertEqual(curs1.fetchone()[0], (1,2))
self.assertEqual(curs1.fetchone()[0], (1, 2))
curs2.execute("select (1,2)::type_ii")
self.assertEqual(curs2.fetchone()[0], "(1,2)")
@ -610,7 +622,7 @@ class AdaptTypeTestCase(ConnectingTestCase):
curs1 = conn1.cursor()
curs2 = conn2.cursor()
curs1.execute("select (1,2)::type_ii")
self.assertEqual(curs1.fetchone()[0], (1,2))
self.assertEqual(curs1.fetchone()[0], (1, 2))
curs2.execute("select (1,2)::type_ii")
self.assertEqual(curs2.fetchone()[0], "(1,2)")
finally:
@ -629,9 +641,9 @@ class AdaptTypeTestCase(ConnectingTestCase):
curs1 = conn1.cursor()
curs2 = conn2.cursor()
curs1.execute("select (1,2)::type_ii")
self.assertEqual(curs1.fetchone()[0], (1,2))
self.assertEqual(curs1.fetchone()[0], (1, 2))
curs2.execute("select (1,2)::type_ii")
self.assertEqual(curs2.fetchone()[0], (1,2))
self.assertEqual(curs2.fetchone()[0], (1, 2))
finally:
# drop the registered typecasters to help the refcounting
# script to return precise values.
@ -661,30 +673,30 @@ class AdaptTypeTestCase(ConnectingTestCase):
"typens.typens_ii", self.conn)
self.assertEqual(t.schema, 'typens')
curs.execute("select (4,8)::typens.typens_ii")
self.assertEqual(curs.fetchone()[0], (4,8))
self.assertEqual(curs.fetchone()[0], (4, 8))
@skip_if_no_composite
@skip_before_postgres(8, 4)
def test_composite_array(self):
oid = self._create_type("type_isd",
self._create_type("type_isd",
[('anint', 'integer'), ('astring', 'text'), ('adate', 'date')])
t = psycopg2.extras.register_composite("type_isd", self.conn)
curs = self.conn.cursor()
r1 = (10, 'hello', date(2011,1,2))
r2 = (20, 'world', date(2011,1,3))
r1 = (10, 'hello', date(2011, 1, 2))
r2 = (20, 'world', date(2011, 1, 3))
curs.execute("select %s::type_isd[];", ([r1, r2],))
v = curs.fetchone()[0]
self.assertEqual(len(v), 2)
self.assert_(isinstance(v[0], t.type))
self.assertEqual(v[0][0], 10)
self.assertEqual(v[0][1], "hello")
self.assertEqual(v[0][2], date(2011,1,2))
self.assertEqual(v[0][2], date(2011, 1, 2))
self.assert_(isinstance(v[1], t.type))
self.assertEqual(v[1][0], 20)
self.assertEqual(v[1][1], "world")
self.assertEqual(v[1][2], date(2011,1,3))
self.assertEqual(v[1][2], date(2011, 1, 3))
@skip_if_no_composite
def test_wrong_schema(self):
@ -752,7 +764,7 @@ class AdaptTypeTestCase(ConnectingTestCase):
register_composite('type_ii', conn)
curs = conn.cursor()
curs.execute("select '(1,2)'::type_ii as x")
self.assertEqual(curs.fetchone()['x'], (1,2))
self.assertEqual(curs.fetchone()['x'], (1, 2))
finally:
conn.close()
@ -761,7 +773,7 @@ class AdaptTypeTestCase(ConnectingTestCase):
curs = conn.cursor()
register_composite('type_ii', conn)
curs.execute("select '(1,2)'::type_ii as x")
self.assertEqual(curs.fetchone()['x'], (1,2))
self.assertEqual(curs.fetchone()['x'], (1, 2))
finally:
conn.close()
@ -782,13 +794,13 @@ class AdaptTypeTestCase(ConnectingTestCase):
self.assertEqual(t.oid, oid)
curs = self.conn.cursor()
r = (10, 'hello', date(2011,1,2))
r = (10, 'hello', date(2011, 1, 2))
curs.execute("select %s::type_isd;", (r,))
v = curs.fetchone()[0]
self.assert_(isinstance(v, dict))
self.assertEqual(v['anint'], 10)
self.assertEqual(v['astring'], "hello")
self.assertEqual(v['adate'], date(2011,1,2))
self.assertEqual(v['adate'], date(2011, 1, 2))
def _create_type(self, name, fields):
curs = self.conn.cursor()
@ -825,6 +837,7 @@ def skip_if_json_module(f):
return skip_if_json_module_
def skip_if_no_json_module(f):
"""Skip a test if no Python json module is available"""
@wraps(f)
@ -836,6 +849,7 @@ def skip_if_no_json_module(f):
return skip_if_no_json_module_
def skip_if_no_json_type(f):
"""Skip a test if PostgreSQL json type is not available"""
@wraps(f)
@ -849,6 +863,7 @@ def skip_if_no_json_type(f):
return skip_if_no_json_type_
class JsonTestCase(ConnectingTestCase):
@skip_if_json_module
def test_module_not_available(self):
@ -858,6 +873,7 @@ class JsonTestCase(ConnectingTestCase):
@skip_if_json_module
def test_customizable_with_module_not_available(self):
from psycopg2.extras import Json
class MyJson(Json):
def dumps(self, obj):
assert obj is None
@ -870,7 +886,7 @@ class JsonTestCase(ConnectingTestCase):
from psycopg2.extras import json, Json
objs = [None, "te'xt", 123, 123.45,
u'\xe0\u20ac', ['a', 100], {'a': 100} ]
u'\xe0\u20ac', ['a', 100], {'a': 100}]
curs = self.conn.cursor()
for obj in enumerate(objs):
@ -889,7 +905,9 @@ class JsonTestCase(ConnectingTestCase):
curs = self.conn.cursor()
obj = Decimal('123.45')
dumps = lambda obj: json.dumps(obj, cls=DecimalEncoder)
def dumps(obj):
return json.dumps(obj, cls=DecimalEncoder)
self.assertEqual(curs.mogrify("%s", (Json(obj, dumps=dumps),)),
b"'123.45'")
@ -921,8 +939,7 @@ class JsonTestCase(ConnectingTestCase):
obj = {'a': 123}
self.assertEqual(curs.mogrify("%s", (obj,)), b"""'{"a": 123}'""")
finally:
del psycopg2.extensions.adapters[dict, ext.ISQLQuote]
del psycopg2.extensions.adapters[dict, ext.ISQLQuote]
def test_type_not_available(self):
curs = self.conn.cursor()
@ -982,7 +999,9 @@ class JsonTestCase(ConnectingTestCase):
@skip_if_no_json_type
def test_loads(self):
json = psycopg2.extras.json
loads = lambda x: json.loads(x, parse_float=Decimal)
def loads(s):
return json.loads(s, parse_float=Decimal)
psycopg2.extras.register_json(self.conn, loads=loads)
curs = self.conn.cursor()
curs.execute("""select '{"a": 100.0, "b": null}'::json""")
@ -998,7 +1017,9 @@ class JsonTestCase(ConnectingTestCase):
old = psycopg2.extensions.string_types.get(114)
olda = psycopg2.extensions.string_types.get(199)
loads = lambda x: psycopg2.extras.json.loads(x, parse_float=Decimal)
def loads(s):
return psycopg2.extras.json.loads(s, parse_float=Decimal)
try:
new, newa = psycopg2.extras.register_json(
loads=loads, oid=oid, array_oid=array_oid)
@ -1020,7 +1041,8 @@ class JsonTestCase(ConnectingTestCase):
def test_register_default(self):
curs = self.conn.cursor()
loads = lambda x: psycopg2.extras.json.loads(x, parse_float=Decimal)
def loads(s):
return psycopg2.extras.json.loads(s, parse_float=Decimal)
psycopg2.extras.register_default_json(curs, loads=loads)
curs.execute("""select '{"a": 100.0, "b": null}'::json""")
@ -1070,6 +1092,7 @@ class JsonTestCase(ConnectingTestCase):
def skip_if_no_jsonb_type(f):
return skip_before_postgres(9, 4)(f)
class JsonbTestCase(ConnectingTestCase):
@staticmethod
def myloads(s):
@ -1118,7 +1141,10 @@ class JsonbTestCase(ConnectingTestCase):
def test_loads(self):
json = psycopg2.extras.json
loads = lambda x: json.loads(x, parse_float=Decimal)
def loads(s):
return json.loads(s, parse_float=Decimal)
psycopg2.extras.register_json(self.conn, loads=loads, name='jsonb')
curs = self.conn.cursor()
curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""")
@ -1134,7 +1160,9 @@ class JsonbTestCase(ConnectingTestCase):
def test_register_default(self):
curs = self.conn.cursor()
loads = lambda x: psycopg2.extras.json.loads(x, parse_float=Decimal)
def loads(s):
return psycopg2.extras.json.loads(s, parse_float=Decimal)
psycopg2.extras.register_default_jsonb(curs, loads=loads)
curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""")
@ -1200,7 +1228,7 @@ class RangeTestCase(unittest.TestCase):
('[)', True, False),
('(]', False, True),
('()', False, False),
('[]', True, True),]:
('[]', True, True)]:
r = Range(10, 20, bounds)
self.assertEqual(r.lower, 10)
self.assertEqual(r.upper, 20)
@ -1294,11 +1322,11 @@ class RangeTestCase(unittest.TestCase):
self.assert_(not Range(empty=True))
def test_eq_hash(self):
from psycopg2.extras import Range
def assert_equal(r1, r2):
self.assert_(r1 == r2)
self.assert_(hash(r1) == hash(r2))
from psycopg2.extras import Range
assert_equal(Range(empty=True), Range(empty=True))
assert_equal(Range(), Range())
assert_equal(Range(10, None), Range(10, None))
@ -1321,8 +1349,11 @@ class RangeTestCase(unittest.TestCase):
def test_eq_subclass(self):
from psycopg2.extras import Range, NumericRange
class IntRange(NumericRange): pass
class PositiveIntRange(IntRange): pass
class IntRange(NumericRange):
pass
class PositiveIntRange(IntRange):
pass
self.assertEqual(Range(10, 20), IntRange(10, 20))
self.assertEqual(PositiveIntRange(10, 20), IntRange(10, 20))
@ -1480,8 +1511,8 @@ class RangeCasterTestCase(ConnectingTestCase):
r = cur.fetchone()[0]
self.assert_(isinstance(r, DateRange))
self.assert_(not r.isempty)
self.assertEqual(r.lower, date(2000,1,2))
self.assertEqual(r.upper, date(2012,12,31))
self.assertEqual(r.lower, date(2000, 1, 2))
self.assertEqual(r.upper, date(2012, 12, 31))
self.assert_(not r.lower_inf)
self.assert_(not r.upper_inf)
self.assert_(r.lower_inc)
@ -1490,8 +1521,8 @@ class RangeCasterTestCase(ConnectingTestCase):
def test_cast_timestamp(self):
from psycopg2.extras import DateTimeRange
cur = self.conn.cursor()
ts1 = datetime(2000,1,1)
ts2 = datetime(2000,12,31,23,59,59,999)
ts1 = datetime(2000, 1, 1)
ts2 = datetime(2000, 12, 31, 23, 59, 59, 999)
cur.execute("select tsrange(%s, %s, '()')", (ts1, ts2))
r = cur.fetchone()[0]
self.assert_(isinstance(r, DateTimeRange))
@ -1507,8 +1538,9 @@ class RangeCasterTestCase(ConnectingTestCase):
from psycopg2.extras import DateTimeTZRange
from psycopg2.tz import FixedOffsetTimezone
cur = self.conn.cursor()
ts1 = datetime(2000,1,1, tzinfo=FixedOffsetTimezone(600))
ts2 = datetime(2000,12,31,23,59,59,999, tzinfo=FixedOffsetTimezone(600))
ts1 = datetime(2000, 1, 1, tzinfo=FixedOffsetTimezone(600))
ts2 = datetime(2000, 12, 31, 23, 59, 59, 999,
tzinfo=FixedOffsetTimezone(600))
cur.execute("select tstzrange(%s, %s, '[]')", (ts1, ts2))
r = cur.fetchone()[0]
self.assert_(isinstance(r, DateTimeTZRange))
@ -1598,8 +1630,9 @@ class RangeCasterTestCase(ConnectingTestCase):
self.assert_(isinstance(r1, DateTimeRange))
self.assert_(r1.isempty)
ts1 = datetime(2000,1,1, tzinfo=FixedOffsetTimezone(600))
ts2 = datetime(2000,12,31,23,59,59,999, tzinfo=FixedOffsetTimezone(600))
ts1 = datetime(2000, 1, 1, tzinfo=FixedOffsetTimezone(600))
ts2 = datetime(2000, 12, 31, 23, 59, 59, 999,
tzinfo=FixedOffsetTimezone(600))
r = DateTimeTZRange(ts1, ts2, '(]')
cur.execute("select %s", (r,))
r1 = cur.fetchone()[0]
@ -1627,7 +1660,7 @@ class RangeCasterTestCase(ConnectingTestCase):
self.assert_(not r1.lower_inc)
self.assert_(r1.upper_inc)
cur.execute("select %s", ([r,r,r],))
cur.execute("select %s", ([r, r, r],))
rs = cur.fetchone()[0]
self.assertEqual(len(rs), 3)
for r1 in rs:
@ -1651,12 +1684,12 @@ class RangeCasterTestCase(ConnectingTestCase):
id integer primary key,
range textrange)""")
bounds = [ '[)', '(]', '()', '[]' ]
ranges = [ TextRange(low, up, bounds[i % 4])
bounds = ['[)', '(]', '()', '[]']
ranges = [TextRange(low, up, bounds[i % 4])
for i, (low, up) in enumerate(zip(
[None] + map(chr, range(1, 128)),
map(chr, range(1,128)) + [None],
))]
map(chr, range(1, 128)) + [None],
))]
ranges.append(TextRange())
ranges.append(TextRange(empty=True))
@ -1736,6 +1769,6 @@ decorate_all_tests(RangeCasterTestCase, skip_if_no_range)
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()

View File

@ -30,6 +30,7 @@ import psycopg2.extensions as ext
from testutils import unittest, ConnectingTestCase
class WithTestCase(ConnectingTestCase):
def setUp(self):
ConnectingTestCase.setUp(self)
@ -93,7 +94,7 @@ class WithConnectionTestCase(WithTestCase):
with self.conn as conn:
curs = conn.cursor()
curs.execute("insert into test_with values (3)")
1/0
1 / 0
self.assertRaises(ZeroDivisionError, f)
self.assertEqual(self.conn.status, ext.STATUS_READY)
@ -113,6 +114,7 @@ class WithConnectionTestCase(WithTestCase):
def test_subclass_commit(self):
commits = []
class MyConn(ext.connection):
def commit(self):
commits.append(None)
@ -131,6 +133,7 @@ class WithConnectionTestCase(WithTestCase):
def test_subclass_rollback(self):
rollbacks = []
class MyConn(ext.connection):
def rollback(self):
rollbacks.append(None)
@ -140,7 +143,7 @@ class WithConnectionTestCase(WithTestCase):
with self.connect(connection_factory=MyConn) as conn:
curs = conn.cursor()
curs.execute("insert into test_with values (11)")
1/0
1 / 0
except ZeroDivisionError:
pass
else:
@ -175,7 +178,7 @@ class WithCursorTestCase(WithTestCase):
with self.conn as conn:
with conn.cursor() as curs:
curs.execute("insert into test_with values (5)")
1/0
1 / 0
except ZeroDivisionError:
pass
@ -189,6 +192,7 @@ class WithCursorTestCase(WithTestCase):
def test_subclass(self):
closes = []
class MyCurs(ext.cursor):
def close(self):
closes.append(None)

View File

@ -69,8 +69,8 @@ else:
# Silence warnings caused by the stubbornness of the Python unittest
# maintainers
# http://bugs.python.org/issue9424
if not hasattr(unittest.TestCase, 'assert_') \
or unittest.TestCase.assert_ is not unittest.TestCase.assertTrue:
if (not hasattr(unittest.TestCase, 'assert_')
or unittest.TestCase.assert_ is not unittest.TestCase.assertTrue):
# mavaff...
unittest.TestCase.assert_ = unittest.TestCase.assertTrue
unittest.TestCase.failUnless = unittest.TestCase.assertTrue
@ -175,7 +175,7 @@ def skip_if_no_uuid(f):
@wraps(f)
def skip_if_no_uuid_(self):
try:
import uuid
import uuid # noqa
except ImportError:
return self.skipTest("uuid not available in this Python version")
@ -223,7 +223,7 @@ def skip_if_no_namedtuple(f):
@wraps(f)
def skip_if_no_namedtuple_(self):
try:
from collections import namedtuple
from collections import namedtuple # noqa
except ImportError:
return self.skipTest("collections.namedtuple not available")
else:
@ -237,7 +237,7 @@ def skip_if_no_iobase(f):
@wraps(f)
def skip_if_no_iobase_(self):
try:
from io import TextIOBase
from io import TextIOBase # noqa
except ImportError:
return self.skipTest("io.TextIOBase not found.")
else:
@ -249,6 +249,7 @@ def skip_if_no_iobase(f):
def skip_before_postgres(*ver):
"""Skip a test on PostgreSQL before a certain version."""
ver = ver + (0,) * (3 - len(ver))
def skip_before_postgres_(f):
@wraps(f)
def skip_before_postgres__(self):
@ -261,9 +262,11 @@ def skip_before_postgres(*ver):
return skip_before_postgres__
return skip_before_postgres_
def skip_after_postgres(*ver):
"""Skip a test on PostgreSQL after (including) a certain version."""
ver = ver + (0,) * (3 - len(ver))
def skip_after_postgres_(f):
@wraps(f)
def skip_after_postgres__(self):
@ -276,6 +279,7 @@ def skip_after_postgres(*ver):
return skip_after_postgres__
return skip_after_postgres_
def libpq_version():
import psycopg2
v = psycopg2.__libpq_version__
@ -283,9 +287,11 @@ def libpq_version():
v = psycopg2.extensions.libpq_version()
return v
def skip_before_libpq(*ver):
"""Skip a test if libpq we're linked to is older than a certain version."""
ver = ver + (0,) * (3 - len(ver))
def skip_before_libpq_(f):
@wraps(f)
def skip_before_libpq__(self):
@ -298,9 +304,11 @@ def skip_before_libpq(*ver):
return skip_before_libpq__
return skip_before_libpq_
def skip_after_libpq(*ver):
"""Skip a test if libpq we're linked to is newer than a certain version."""
ver = ver + (0,) * (3 - len(ver))
def skip_after_libpq_(f):
@wraps(f)
def skip_after_libpq__(self):
@ -313,6 +321,7 @@ def skip_after_libpq(*ver):
return skip_after_libpq__
return skip_after_libpq_
def skip_before_python(*ver):
"""Skip a test on Python before a certain version."""
def skip_before_python_(f):
@ -327,6 +336,7 @@ def skip_before_python(*ver):
return skip_before_python__
return skip_before_python_
def skip_from_python(*ver):
"""Skip a test on Python after (including) a certain version."""
def skip_from_python_(f):
@ -341,6 +351,7 @@ def skip_from_python(*ver):
return skip_from_python__
return skip_from_python_
def skip_if_no_superuser(f):
"""Skip a test if the database user running the test is not a superuser"""
@wraps(f)
@ -357,6 +368,7 @@ def skip_if_no_superuser(f):
return skip_if_no_superuser_
def skip_if_green(reason):
def skip_if_green_(f):
@wraps(f)
@ -372,6 +384,7 @@ def skip_if_green(reason):
skip_copy_if_green = skip_if_green("copy in async mode currently not supported")
def skip_if_no_getrefcount(f):
@wraps(f)
def skip_if_no_getrefcount_(self):
@ -381,6 +394,7 @@ def skip_if_no_getrefcount(f):
return f(self)
return skip_if_no_getrefcount_
def skip_if_windows(f):
"""Skip a test if run on windows"""
@wraps(f)
@ -419,6 +433,7 @@ def script_to_py3(script):
f2.close()
os.remove(filename)
class py3_raises_typeerror(object):
def __enter__(self):

View File

@ -8,3 +8,8 @@ envlist = py26, py27
[testenv]
commands = make check
[flake8]
max-line-length = 85
ignore = E128, W503
exclude = build, doc, sandbox, examples, tests/dbapi20.py