Python source cleanup using flake8

This commit is contained in:
Daniele Varrazzo 2016-10-11 00:10:53 +01:00
parent 4458c9b4c9
commit 91d2158de7
35 changed files with 644 additions and 432 deletions

View File

@ -47,19 +47,20 @@ Homepage: http://initd.org/projects/psycopg2
# Import the DBAPI-2.0 stuff into top-level module. # Import the DBAPI-2.0 stuff into top-level module.
from psycopg2._psycopg import BINARY, NUMBER, STRING, DATETIME, ROWID from psycopg2._psycopg import ( # noqa
BINARY, NUMBER, STRING, DATETIME, ROWID,
from psycopg2._psycopg import Binary, Date, Time, Timestamp Binary, Date, Time, Timestamp,
from psycopg2._psycopg import DateFromTicks, TimeFromTicks, TimestampFromTicks DateFromTicks, TimeFromTicks, TimestampFromTicks,
from psycopg2._psycopg import Error, Warning, DataError, DatabaseError, ProgrammingError Error, Warning, DataError, DatabaseError, ProgrammingError, IntegrityError,
from psycopg2._psycopg import IntegrityError, InterfaceError, InternalError InterfaceError, InternalError, NotSupportedError, OperationalError,
from psycopg2._psycopg import NotSupportedError, OperationalError
from psycopg2._psycopg import _connect, apilevel, threadsafety, paramstyle _connect, apilevel, threadsafety, paramstyle,
from psycopg2._psycopg import __version__, __libpq_version__ __version__, __libpq_version__,
)
from psycopg2 import tz from psycopg2 import tz # noqa
# Register default adapters. # Register default adapters.

View File

@ -34,7 +34,7 @@ from psycopg2._psycopg import new_type, new_array_type, register_type
# import the best json implementation available # import the best json implementation available
if sys.version_info[:2] >= (2,6): if sys.version_info[:2] >= (2, 6):
import json import json
else: else:
try: try:
@ -51,6 +51,7 @@ JSONARRAY_OID = 199
JSONB_OID = 3802 JSONB_OID = 3802
JSONBARRAY_OID = 3807 JSONBARRAY_OID = 3807
class Json(object): class Json(object):
""" """
An `~psycopg2.extensions.ISQLQuote` wrapper to adapt a Python object to An `~psycopg2.extensions.ISQLQuote` wrapper to adapt a Python object to
@ -143,6 +144,7 @@ def register_json(conn_or_curs=None, globally=False, loads=None,
return JSON, JSONARRAY return JSON, JSONARRAY
def register_default_json(conn_or_curs=None, globally=False, loads=None): def register_default_json(conn_or_curs=None, globally=False, loads=None):
""" """
Create and register :sql:`json` typecasters for PostgreSQL 9.2 and following. Create and register :sql:`json` typecasters for PostgreSQL 9.2 and following.
@ -155,6 +157,7 @@ def register_default_json(conn_or_curs=None, globally=False, loads=None):
return register_json(conn_or_curs=conn_or_curs, globally=globally, return register_json(conn_or_curs=conn_or_curs, globally=globally,
loads=loads, oid=JSON_OID, array_oid=JSONARRAY_OID) loads=loads, oid=JSON_OID, array_oid=JSONARRAY_OID)
def register_default_jsonb(conn_or_curs=None, globally=False, loads=None): def register_default_jsonb(conn_or_curs=None, globally=False, loads=None):
""" """
Create and register :sql:`jsonb` typecasters for PostgreSQL 9.4 and following. Create and register :sql:`jsonb` typecasters for PostgreSQL 9.4 and following.
@ -167,6 +170,7 @@ def register_default_jsonb(conn_or_curs=None, globally=False, loads=None):
return register_json(conn_or_curs=conn_or_curs, globally=globally, return register_json(conn_or_curs=conn_or_curs, globally=globally,
loads=loads, oid=JSONB_OID, array_oid=JSONBARRAY_OID, name='jsonb') loads=loads, oid=JSONB_OID, array_oid=JSONBARRAY_OID, name='jsonb')
def _create_json_typecasters(oid, array_oid, loads=None, name='JSON'): def _create_json_typecasters(oid, array_oid, loads=None, name='JSON'):
"""Create typecasters for json data type.""" """Create typecasters for json data type."""
if loads is None: if loads is None:
@ -188,6 +192,7 @@ def _create_json_typecasters(oid, array_oid, loads=None, name='JSON'):
return JSON, JSONARRAY return JSON, JSONARRAY
def _get_json_oids(conn_or_curs, name='json'): def _get_json_oids(conn_or_curs, name='json'):
# lazy imports # lazy imports
from psycopg2.extensions import STATUS_IN_TRANSACTION from psycopg2.extensions import STATUS_IN_TRANSACTION
@ -215,6 +220,3 @@ def _get_json_oids(conn_or_curs, name='json'):
raise conn.ProgrammingError("%s data type not found" % name) raise conn.ProgrammingError("%s data type not found" % name)
return r return r

View File

@ -30,6 +30,7 @@ from psycopg2._psycopg import ProgrammingError, InterfaceError
from psycopg2.extensions import ISQLQuote, adapt, register_adapter from psycopg2.extensions import ISQLQuote, adapt, register_adapter
from psycopg2.extensions import new_type, new_array_type, register_type from psycopg2.extensions import new_type, new_array_type, register_type
class Range(object): class Range(object):
"""Python representation for a PostgreSQL |range|_ type. """Python representation for a PostgreSQL |range|_ type.
@ -78,42 +79,50 @@ class Range(object):
@property @property
def lower_inf(self): def lower_inf(self):
"""`!True` if the range doesn't have a lower bound.""" """`!True` if the range doesn't have a lower bound."""
if self._bounds is None: return False if self._bounds is None:
return False
return self._lower is None return self._lower is None
@property @property
def upper_inf(self): def upper_inf(self):
"""`!True` if the range doesn't have an upper bound.""" """`!True` if the range doesn't have an upper bound."""
if self._bounds is None: return False if self._bounds is None:
return False
return self._upper is None return self._upper is None
@property @property
def lower_inc(self): def lower_inc(self):
"""`!True` if the lower bound is included in the range.""" """`!True` if the lower bound is included in the range."""
if self._bounds is None: return False if self._bounds is None or self._lower is None:
if self._lower is None: return False return False
return self._bounds[0] == '[' return self._bounds[0] == '['
@property @property
def upper_inc(self): def upper_inc(self):
"""`!True` if the upper bound is included in the range.""" """`!True` if the upper bound is included in the range."""
if self._bounds is None: return False if self._bounds is None or self._upper is None:
if self._upper is None: return False return False
return self._bounds[1] == ']' return self._bounds[1] == ']'
def __contains__(self, x): def __contains__(self, x):
if self._bounds is None: return False if self._bounds is None:
return False
if self._lower is not None: if self._lower is not None:
if self._bounds[0] == '[': if self._bounds[0] == '[':
if x < self._lower: return False if x < self._lower:
return False
else: else:
if x <= self._lower: return False if x <= self._lower:
return False
if self._upper is not None: if self._upper is not None:
if self._bounds[1] == ']': if self._bounds[1] == ']':
if x > self._upper: return False if x > self._upper:
return False
else: else:
if x >= self._upper: return False if x >= self._upper:
return False
return True return True
@ -295,7 +304,8 @@ class RangeCaster(object):
self.adapter.name = pgrange self.adapter.name = pgrange
else: else:
try: try:
if issubclass(pgrange, RangeAdapter) and pgrange is not RangeAdapter: if issubclass(pgrange, RangeAdapter) \
and pgrange is not RangeAdapter:
self.adapter = pgrange self.adapter = pgrange
except TypeError: except TypeError:
pass pass
@ -436,14 +446,17 @@ class NumericRange(Range):
""" """
pass pass
class DateRange(Range): class DateRange(Range):
"""Represents :sql:`daterange` values.""" """Represents :sql:`daterange` values."""
pass pass
class DateTimeRange(Range): class DateTimeRange(Range):
"""Represents :sql:`tsrange` values.""" """Represents :sql:`tsrange` values."""
pass pass
class DateTimeTZRange(Range): class DateTimeTZRange(Range):
"""Represents :sql:`tstzrange` values.""" """Represents :sql:`tstzrange` values."""
pass pass
@ -508,5 +521,3 @@ tsrange_caster._register()
tstzrange_caster = RangeCaster('tstzrange', DateTimeTZRange, tstzrange_caster = RangeCaster('tstzrange', DateTimeTZRange,
oid=3910, subtype_oid=1184, array_oid=3911) oid=3910, subtype_oid=1184, array_oid=3911)
tstzrange_caster._register() tstzrange_caster._register()

View File

@ -29,6 +29,7 @@ This module contains symbolic names for all PostgreSQL error codes.
# http://www.postgresql.org/docs/current/static/errcodes-appendix.html # http://www.postgresql.org/docs/current/static/errcodes-appendix.html
# #
def lookup(code, _cache={}): def lookup(code, _cache={}):
"""Lookup an error code or class code and return its symbolic name. """Lookup an error code or class code and return its symbolic name.

View File

@ -33,43 +33,38 @@ This module holds all the extensions to the DBAPI-2.0 provided by psycopg.
# License for more details. # License for more details.
import re as _re import re as _re
import sys as _sys
from psycopg2._psycopg import UNICODE, INTEGER, LONGINTEGER, BOOLEAN, FLOAT from psycopg2._psycopg import ( # noqa
from psycopg2._psycopg import TIME, DATE, INTERVAL, DECIMAL BINARYARRAY, BOOLEAN, BOOLEANARRAY, DATE, DATEARRAY, DATETIMEARRAY,
from psycopg2._psycopg import BINARYARRAY, BOOLEANARRAY, DATEARRAY, DATETIMEARRAY DECIMAL, DECIMALARRAY, FLOAT, FLOATARRAY, INTEGER, INTEGERARRAY,
from psycopg2._psycopg import DECIMALARRAY, FLOATARRAY, INTEGERARRAY, INTERVALARRAY INTERVAL, INTERVALARRAY, LONGINTEGER, LONGINTEGERARRAY, ROWIDARRAY,
from psycopg2._psycopg import LONGINTEGERARRAY, ROWIDARRAY, STRINGARRAY, TIMEARRAY STRINGARRAY, TIME, TIMEARRAY, UNICODE, UNICODEARRAY,
from psycopg2._psycopg import UNICODEARRAY AsIs, Binary, Boolean, Float, Int, QuotedString, )
from psycopg2._psycopg import Binary, Boolean, Int, Float, QuotedString, AsIs
try: try:
from psycopg2._psycopg import MXDATE, MXDATETIME, MXINTERVAL, MXTIME from psycopg2._psycopg import ( # noqa
from psycopg2._psycopg import MXDATEARRAY, MXDATETIMEARRAY, MXINTERVALARRAY, MXTIMEARRAY MXDATE, MXDATETIME, MXINTERVAL, MXTIME,
from psycopg2._psycopg import DateFromMx, TimeFromMx, TimestampFromMx MXDATEARRAY, MXDATETIMEARRAY, MXINTERVALARRAY, MXTIMEARRAY,
from psycopg2._psycopg import IntervalFromMx DateFromMx, TimeFromMx, TimestampFromMx, IntervalFromMx, )
except ImportError: except ImportError:
pass pass
try: try:
from psycopg2._psycopg import PYDATE, PYDATETIME, PYINTERVAL, PYTIME from psycopg2._psycopg import ( # noqa
from psycopg2._psycopg import PYDATEARRAY, PYDATETIMEARRAY, PYINTERVALARRAY, PYTIMEARRAY PYDATE, PYDATETIME, PYINTERVAL, PYTIME,
from psycopg2._psycopg import DateFromPy, TimeFromPy, TimestampFromPy PYDATEARRAY, PYDATETIMEARRAY, PYINTERVALARRAY, PYTIMEARRAY,
from psycopg2._psycopg import IntervalFromPy DateFromPy, TimeFromPy, TimestampFromPy, IntervalFromPy, )
except ImportError: except ImportError:
pass pass
from psycopg2._psycopg import adapt, adapters, encodings, connection, cursor from psycopg2._psycopg import ( # noqa
from psycopg2._psycopg import lobject, Xid, libpq_version, parse_dsn, quote_ident adapt, adapters, encodings, connection, cursor,
from psycopg2._psycopg import string_types, binary_types, new_type, new_array_type, register_type lobject, Xid, libpq_version, parse_dsn, quote_ident,
from psycopg2._psycopg import ISQLQuote, Notify, Diagnostics, Column string_types, binary_types, new_type, new_array_type, register_type,
ISQLQuote, Notify, Diagnostics, Column,
QueryCanceledError, TransactionRollbackError,
set_wait_callback, get_wait_callback, )
from psycopg2._psycopg import QueryCanceledError, TransactionRollbackError
try:
from psycopg2._psycopg import set_wait_callback, get_wait_callback
except ImportError:
pass
"""Isolation level values.""" """Isolation level values."""
ISOLATION_LEVEL_AUTOCOMMIT = 0 ISOLATION_LEVEL_AUTOCOMMIT = 0
@ -78,6 +73,7 @@ ISOLATION_LEVEL_READ_COMMITTED = 1
ISOLATION_LEVEL_REPEATABLE_READ = 2 ISOLATION_LEVEL_REPEATABLE_READ = 2
ISOLATION_LEVEL_SERIALIZABLE = 3 ISOLATION_LEVEL_SERIALIZABLE = 3
"""psycopg connection status values.""" """psycopg connection status values."""
STATUS_SETUP = 0 STATUS_SETUP = 0
STATUS_READY = 1 STATUS_READY = 1
@ -89,12 +85,14 @@ STATUS_PREPARED = 5
# This is a useful mnemonic to check if the connection is in a transaction # This is a useful mnemonic to check if the connection is in a transaction
STATUS_IN_TRANSACTION = STATUS_BEGIN STATUS_IN_TRANSACTION = STATUS_BEGIN
"""psycopg asynchronous connection polling values""" """psycopg asynchronous connection polling values"""
POLL_OK = 0 POLL_OK = 0
POLL_READ = 1 POLL_READ = 1
POLL_WRITE = 2 POLL_WRITE = 2
POLL_ERROR = 3 POLL_ERROR = 3
"""Backend transaction status values.""" """Backend transaction status values."""
TRANSACTION_STATUS_IDLE = 0 TRANSACTION_STATUS_IDLE = 0
TRANSACTION_STATUS_ACTIVE = 1 TRANSACTION_STATUS_ACTIVE = 1
@ -194,7 +192,7 @@ def _param_escape(s,
# Create default json typecasters for PostgreSQL 9.2 oids # Create default json typecasters for PostgreSQL 9.2 oids
from psycopg2._json import register_default_json, register_default_jsonb from psycopg2._json import register_default_json, register_default_jsonb # noqa
try: try:
JSON, JSONARRAY = register_default_json() JSON, JSONARRAY = register_default_json()
@ -206,7 +204,7 @@ del register_default_json, register_default_jsonb
# Create default Range typecasters # Create default Range typecasters
from psycopg2. _range import Range from psycopg2. _range import Range # noqa
del Range del Range

View File

@ -40,10 +40,23 @@ from psycopg2 import extensions as _ext
from psycopg2.extensions import cursor as _cursor from psycopg2.extensions import cursor as _cursor
from psycopg2.extensions import connection as _connection from psycopg2.extensions import connection as _connection
from psycopg2.extensions import adapt as _A, quote_ident from psycopg2.extensions import adapt as _A, quote_ident
from psycopg2._psycopg import REPLICATION_PHYSICAL, REPLICATION_LOGICAL
from psycopg2._psycopg import ReplicationConnection as _replicationConnection from psycopg2._psycopg import ( # noqa
from psycopg2._psycopg import ReplicationCursor as _replicationCursor REPLICATION_PHYSICAL, REPLICATION_LOGICAL,
from psycopg2._psycopg import ReplicationMessage ReplicationConnection as _replicationConnection,
ReplicationCursor as _replicationCursor,
ReplicationMessage)
# expose the json adaptation stuff into the module
from psycopg2._json import ( # noqa
json, Json, register_json, register_default_json, register_default_jsonb)
# Expose range-related objects
from psycopg2._range import ( # noqa
Range, NumericRange, DateRange, DateTimeRange, DateTimeTZRange,
register_range, RangeAdapter, RangeCaster)
class DictCursorBase(_cursor): class DictCursorBase(_cursor):
@ -109,6 +122,7 @@ class DictConnection(_connection):
kwargs.setdefault('cursor_factory', DictCursor) kwargs.setdefault('cursor_factory', DictCursor)
return super(DictConnection, self).cursor(*args, **kwargs) return super(DictConnection, self).cursor(*args, **kwargs)
class DictCursor(DictCursorBase): class DictCursor(DictCursorBase):
"""A cursor that keeps a list of column name -> index mappings.""" """A cursor that keeps a list of column name -> index mappings."""
@ -133,6 +147,7 @@ class DictCursor(DictCursorBase):
self.index[self.description[i][0]] = i self.index[self.description[i][0]] = i
self._query_executed = 0 self._query_executed = 0
class DictRow(list): class DictRow(list):
"""A row object that allow by-column-name access to data.""" """A row object that allow by-column-name access to data."""
@ -195,10 +210,10 @@ class DictRow(list):
# drop the crusty Py2 methods # drop the crusty Py2 methods
if _sys.version_info[0] > 2: if _sys.version_info[0] > 2:
items = iteritems; del iteritems items = iteritems # noqa
keys = iterkeys; del iterkeys keys = iterkeys # noqa
values = itervalues; del itervalues values = itervalues # noqa
del has_key del iteritems, iterkeys, itervalues, has_key
class RealDictConnection(_connection): class RealDictConnection(_connection):
@ -207,6 +222,7 @@ class RealDictConnection(_connection):
kwargs.setdefault('cursor_factory', RealDictCursor) kwargs.setdefault('cursor_factory', RealDictCursor)
return super(RealDictConnection, self).cursor(*args, **kwargs) return super(RealDictConnection, self).cursor(*args, **kwargs)
class RealDictCursor(DictCursorBase): class RealDictCursor(DictCursorBase):
"""A cursor that uses a real dict as the base type for rows. """A cursor that uses a real dict as the base type for rows.
@ -236,6 +252,7 @@ class RealDictCursor(DictCursorBase):
self.column_mapping.append(self.description[i][0]) self.column_mapping.append(self.description[i][0])
self._query_executed = 0 self._query_executed = 0
class RealDictRow(dict): class RealDictRow(dict):
"""A `!dict` subclass representing a data record.""" """A `!dict` subclass representing a data record."""
@ -268,6 +285,7 @@ class NamedTupleConnection(_connection):
kwargs.setdefault('cursor_factory', NamedTupleCursor) kwargs.setdefault('cursor_factory', NamedTupleCursor)
return super(NamedTupleConnection, self).cursor(*args, **kwargs) return super(NamedTupleConnection, self).cursor(*args, **kwargs)
class NamedTupleCursor(_cursor): class NamedTupleCursor(_cursor):
"""A cursor that generates results as `~collections.namedtuple`. """A cursor that generates results as `~collections.namedtuple`.
@ -372,11 +390,13 @@ class LoggingConnection(_connection):
def _logtofile(self, msg, curs): def _logtofile(self, msg, curs):
msg = self.filter(msg, curs) msg = self.filter(msg, curs)
if msg: self._logobj.write(msg + _os.linesep) if msg:
self._logobj.write(msg + _os.linesep)
def _logtologger(self, msg, curs): def _logtologger(self, msg, curs):
msg = self.filter(msg, curs) msg = self.filter(msg, curs)
if msg: self._logobj.debug(msg) if msg:
self._logobj.debug(msg)
def _check(self): def _check(self):
if not hasattr(self, '_logobj'): if not hasattr(self, '_logobj'):
@ -388,6 +408,7 @@ class LoggingConnection(_connection):
kwargs.setdefault('cursor_factory', LoggingCursor) kwargs.setdefault('cursor_factory', LoggingCursor)
return super(LoggingConnection, self).cursor(*args, **kwargs) return super(LoggingConnection, self).cursor(*args, **kwargs)
class LoggingCursor(_cursor): class LoggingCursor(_cursor):
"""A cursor that logs queries using its connection logging facilities.""" """A cursor that logs queries using its connection logging facilities."""
@ -428,6 +449,7 @@ class MinTimeLoggingConnection(LoggingConnection):
kwargs.setdefault('cursor_factory', MinTimeLoggingCursor) kwargs.setdefault('cursor_factory', MinTimeLoggingCursor)
return LoggingConnection.cursor(self, *args, **kwargs) return LoggingConnection.cursor(self, *args, **kwargs)
class MinTimeLoggingCursor(LoggingCursor): class MinTimeLoggingCursor(LoggingCursor):
"""The cursor sub-class companion to `MinTimeLoggingConnection`.""" """The cursor sub-class companion to `MinTimeLoggingConnection`."""
@ -479,18 +501,23 @@ class ReplicationCursor(_replicationCursor):
if slot_type == REPLICATION_LOGICAL: if slot_type == REPLICATION_LOGICAL:
if output_plugin is None: if output_plugin is None:
raise psycopg2.ProgrammingError("output plugin name is required to create logical replication slot") raise psycopg2.ProgrammingError(
"output plugin name is required to create "
"logical replication slot")
command += "LOGICAL %s" % quote_ident(output_plugin, self) command += "LOGICAL %s" % quote_ident(output_plugin, self)
elif slot_type == REPLICATION_PHYSICAL: elif slot_type == REPLICATION_PHYSICAL:
if output_plugin is not None: if output_plugin is not None:
raise psycopg2.ProgrammingError("cannot specify output plugin name when creating physical replication slot") raise psycopg2.ProgrammingError(
"cannot specify output plugin name when creating "
"physical replication slot")
command += "PHYSICAL" command += "PHYSICAL"
else: else:
raise psycopg2.ProgrammingError("unrecognized replication type: %s" % repr(slot_type)) raise psycopg2.ProgrammingError(
"unrecognized replication type: %s" % repr(slot_type))
self.execute(command) self.execute(command)
@ -513,7 +540,8 @@ class ReplicationCursor(_replicationCursor):
if slot_name: if slot_name:
command += "SLOT %s " % quote_ident(slot_name, self) command += "SLOT %s " % quote_ident(slot_name, self)
else: else:
raise psycopg2.ProgrammingError("slot name is required for logical replication") raise psycopg2.ProgrammingError(
"slot name is required for logical replication")
command += "LOGICAL " command += "LOGICAL "
@ -523,28 +551,32 @@ class ReplicationCursor(_replicationCursor):
# don't add "PHYSICAL", before 9.4 it was just START_REPLICATION XXX/XXX # don't add "PHYSICAL", before 9.4 it was just START_REPLICATION XXX/XXX
else: else:
raise psycopg2.ProgrammingError("unrecognized replication type: %s" % repr(slot_type)) raise psycopg2.ProgrammingError(
"unrecognized replication type: %s" % repr(slot_type))
if type(start_lsn) is str: if type(start_lsn) is str:
lsn = start_lsn.split('/') lsn = start_lsn.split('/')
lsn = "%X/%08X" % (int(lsn[0], 16), int(lsn[1], 16)) lsn = "%X/%08X" % (int(lsn[0], 16), int(lsn[1], 16))
else: else:
lsn = "%X/%08X" % ((start_lsn >> 32) & 0xFFFFFFFF, start_lsn & 0xFFFFFFFF) lsn = "%X/%08X" % ((start_lsn >> 32) & 0xFFFFFFFF,
start_lsn & 0xFFFFFFFF)
command += lsn command += lsn
if timeline != 0: if timeline != 0:
if slot_type == REPLICATION_LOGICAL: if slot_type == REPLICATION_LOGICAL:
raise psycopg2.ProgrammingError("cannot specify timeline for logical replication") raise psycopg2.ProgrammingError(
"cannot specify timeline for logical replication")
command += " TIMELINE %d" % timeline command += " TIMELINE %d" % timeline
if options: if options:
if slot_type == REPLICATION_PHYSICAL: if slot_type == REPLICATION_PHYSICAL:
raise psycopg2.ProgrammingError("cannot specify output plugin options for physical replication") raise psycopg2.ProgrammingError(
"cannot specify output plugin options for physical replication")
command += " (" command += " ("
for k,v in options.iteritems(): for k, v in options.iteritems():
if not command.endswith('('): if not command.endswith('('):
command += ", " command += ", "
command += "%s %s" % (quote_ident(k, self), _A(str(v))) command += "%s %s" % (quote_ident(k, self), _A(str(v)))
@ -579,6 +611,7 @@ class UUID_adapter(object):
def __str__(self): def __str__(self):
return "'%s'::uuid" % self._uuid return "'%s'::uuid" % self._uuid
def register_uuid(oids=None, conn_or_curs=None): def register_uuid(oids=None, conn_or_curs=None):
"""Create the UUID type and an uuid.UUID adapter. """Create the UUID type and an uuid.UUID adapter.
@ -643,6 +676,7 @@ class Inet(object):
def __str__(self): def __str__(self):
return str(self.addr) return str(self.addr)
def register_inet(oid=None, conn_or_curs=None): def register_inet(oid=None, conn_or_curs=None):
"""Create the INET type and an Inet adapter. """Create the INET type and an Inet adapter.
@ -862,6 +896,7 @@ WHERE typname = 'hstore';
return tuple(rv0), tuple(rv1) return tuple(rv0), tuple(rv1)
def register_hstore(conn_or_curs, globally=False, unicode=False, def register_hstore(conn_or_curs, globally=False, unicode=False,
oid=None, array_oid=None): oid=None, array_oid=None):
"""Register adapter and typecaster for `!dict`\-\ |hstore| conversions. """Register adapter and typecaster for `!dict`\-\ |hstore| conversions.
@ -942,8 +977,8 @@ class CompositeCaster(object):
self.oid = oid self.oid = oid
self.array_oid = array_oid self.array_oid = array_oid
self.attnames = [ a[0] for a in attrs ] self.attnames = [a[0] for a in attrs]
self.atttypes = [ a[1] for a in attrs ] self.atttypes = [a[1] for a in attrs]
self._create_type(name, self.attnames) self._create_type(name, self.attnames)
self.typecaster = _ext.new_type((oid,), name, self.parse) self.typecaster = _ext.new_type((oid,), name, self.parse)
if array_oid: if array_oid:
@ -962,8 +997,8 @@ class CompositeCaster(object):
"expecting %d components for the type %s, %d found instead" % "expecting %d components for the type %s, %d found instead" %
(len(self.atttypes), self.name, len(tokens))) (len(self.atttypes), self.name, len(tokens)))
values = [ curs.cast(oid, token) values = [curs.cast(oid, token)
for oid, token in zip(self.atttypes, tokens) ] for oid, token in zip(self.atttypes, tokens)]
return self.make(values) return self.make(values)
@ -1057,11 +1092,12 @@ ORDER BY attnum;
type_oid = recs[0][0] type_oid = recs[0][0]
array_oid = recs[0][1] array_oid = recs[0][1]
type_attrs = [ (r[2], r[3]) for r in recs ] type_attrs = [(r[2], r[3]) for r in recs]
return self(tname, type_oid, type_attrs, return self(tname, type_oid, type_attrs,
array_oid=array_oid, schema=schema) array_oid=array_oid, schema=schema)
def register_composite(name, conn_or_curs, globally=False, factory=None): def register_composite(name, conn_or_curs, globally=False, factory=None):
"""Register a typecaster to convert a composite type into a tuple. """Register a typecaster to convert a composite type into a tuple.
@ -1084,17 +1120,7 @@ def register_composite(name, conn_or_curs, globally=False, factory=None):
_ext.register_type(caster.typecaster, not globally and conn_or_curs or None) _ext.register_type(caster.typecaster, not globally and conn_or_curs or None)
if caster.array_typecaster is not None: if caster.array_typecaster is not None:
_ext.register_type(caster.array_typecaster, not globally and conn_or_curs or None) _ext.register_type(
caster.array_typecaster, not globally and conn_or_curs or None)
return caster return caster
# expose the json adaptation stuff into the module
from psycopg2._json import json, Json, register_json
from psycopg2._json import register_default_json, register_default_jsonb
# Expose range-related objects
from psycopg2._range import Range, NumericRange
from psycopg2._range import DateRange, DateTimeRange, DateTimeTZRange
from psycopg2._range import register_range, RangeAdapter, RangeCaster

View File

@ -74,8 +74,10 @@ class AbstractConnectionPool(object):
def _getconn(self, key=None): def _getconn(self, key=None):
"""Get a free connection and assign it to 'key' if not None.""" """Get a free connection and assign it to 'key' if not None."""
if self.closed: raise PoolError("connection pool is closed") if self.closed:
if key is None: key = self._getkey() raise PoolError("connection pool is closed")
if key is None:
key = self._getkey()
if key in self._used: if key in self._used:
return self._used[key] return self._used[key]
@ -91,8 +93,10 @@ class AbstractConnectionPool(object):
def _putconn(self, conn, key=None, close=False): def _putconn(self, conn, key=None, close=False):
"""Put away a connection.""" """Put away a connection."""
if self.closed: raise PoolError("connection pool is closed") if self.closed:
if key is None: key = self._rused.get(id(conn)) raise PoolError("connection pool is closed")
if key is None:
key = self._rused.get(id(conn))
if not key: if not key:
raise PoolError("trying to put unkeyed connection") raise PoolError("trying to put unkeyed connection")
@ -129,7 +133,8 @@ class AbstractConnectionPool(object):
an already closed connection. If you call .closeall() make sure an already closed connection. If you call .closeall() make sure
your code can deal with it. your code can deal with it.
""" """
if self.closed: raise PoolError("connection pool is closed") if self.closed:
raise PoolError("connection pool is closed")
for conn in self._pool + list(self._used.values()): for conn in self._pool + list(self._used.values()):
try: try:
conn.close() conn.close()
@ -221,7 +226,8 @@ class PersistentConnectionPool(AbstractConnectionPool):
key = self.__thread.get_ident() key = self.__thread.get_ident()
self._lock.acquire() self._lock.acquire()
try: try:
if not conn: conn = self._used[key] if not conn:
conn = self._used[key]
self._putconn(conn, key, close) self._putconn(conn, key, close)
finally: finally:
self._lock.release() self._lock.release()

View File

@ -28,14 +28,15 @@ old code while porting to psycopg 2. Import it as follows::
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public # FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details. # License for more details.
import psycopg2._psycopg as _2psycopg import psycopg2._psycopg as _2psycopg # noqa
from psycopg2.extensions import cursor as _2cursor from psycopg2.extensions import cursor as _2cursor
from psycopg2.extensions import connection as _2connection from psycopg2.extensions import connection as _2connection
from psycopg2 import * from psycopg2 import * # noqa
import psycopg2.extensions as _ext import psycopg2.extensions as _ext
_2connect = connect _2connect = connect
def connect(*args, **kwargs): def connect(*args, **kwargs):
"""connect(dsn, ...) -> new psycopg 1.1.x compatible connection object""" """connect(dsn, ...) -> new psycopg 1.1.x compatible connection object"""
kwargs['connection_factory'] = connection kwargs['connection_factory'] = connection
@ -43,6 +44,7 @@ def connect(*args, **kwargs):
conn.set_isolation_level(_ext.ISOLATION_LEVEL_READ_COMMITTED) conn.set_isolation_level(_ext.ISOLATION_LEVEL_READ_COMMITTED)
return conn return conn
class connection(_2connection): class connection(_2connection):
"""psycopg 1.1.x connection.""" """psycopg 1.1.x connection."""
@ -92,4 +94,3 @@ class cursor(_2cursor):
for row in rows: for row in rows:
res.append(self.__build_dict(row)) res.append(self.__build_dict(row))
return res return res

View File

@ -31,6 +31,7 @@ import time
ZERO = datetime.timedelta(0) ZERO = datetime.timedelta(0)
class FixedOffsetTimezone(datetime.tzinfo): class FixedOffsetTimezone(datetime.tzinfo):
"""Fixed offset in minutes east from UTC. """Fixed offset in minutes east from UTC.
@ -52,7 +53,7 @@ class FixedOffsetTimezone(datetime.tzinfo):
def __init__(self, offset=None, name=None): def __init__(self, offset=None, name=None):
if offset is not None: if offset is not None:
self._offset = datetime.timedelta(minutes = offset) self._offset = datetime.timedelta(minutes=offset)
if name is not None: if name is not None:
self._name = name self._name = name
@ -85,7 +86,7 @@ class FixedOffsetTimezone(datetime.tzinfo):
else: else:
seconds = self._offset.seconds + self._offset.days * 86400 seconds = self._offset.seconds + self._offset.days * 86400
hours, seconds = divmod(seconds, 3600) hours, seconds = divmod(seconds, 3600)
minutes = seconds/60 minutes = seconds / 60
if minutes: if minutes:
return "%+03d:%d" % (hours, minutes) return "%+03d:%d" % (hours, minutes)
else: else:
@ -95,13 +96,14 @@ class FixedOffsetTimezone(datetime.tzinfo):
return ZERO return ZERO
STDOFFSET = datetime.timedelta(seconds = -time.timezone) STDOFFSET = datetime.timedelta(seconds=-time.timezone)
if time.daylight: if time.daylight:
DSTOFFSET = datetime.timedelta(seconds = -time.altzone) DSTOFFSET = datetime.timedelta(seconds=-time.altzone)
else: else:
DSTOFFSET = STDOFFSET DSTOFFSET = STDOFFSET
DSTDIFF = DSTOFFSET - STDOFFSET DSTDIFF = DSTOFFSET - STDOFFSET
class LocalTimezone(datetime.tzinfo): class LocalTimezone(datetime.tzinfo):
"""Platform idea of local timezone. """Platform idea of local timezone.

View File

@ -19,8 +19,8 @@
# code defines the DBAPITypeObject fundamental types and warns for # code defines the DBAPITypeObject fundamental types and warns for
# undefined types. # undefined types.
import sys, os, string, copy import sys
from string import split, join, strip from string import split, strip
# here is the list of the foundamental types we want to import from # here is the list of the foundamental types we want to import from
@ -73,8 +73,7 @@ FOOTER = """ {NULL, NULL, NULL, NULL}\n};\n"""
# useful error reporting function # useful error reporting function
def error(msg): def error(msg):
"""Report an error on stderr.""" """Report an error on stderr."""
sys.stderr.write(msg+'\n') sys.stderr.write(msg + '\n')
# read couples from stdin and build list # read couples from stdin and build list
read_types = [] read_types = []
@ -91,14 +90,14 @@ for t in basic_types:
for v in t[1]: for v in t[1]:
found = filter(lambda x, y=v: x[0] == y, read_types) found = filter(lambda x, y=v: x[0] == y, read_types)
if len(found) == 0: if len(found) == 0:
error(v+': value not found') error(v + ': value not found')
elif len(found) > 1: elif len(found) > 1:
error(v+': too many values') error(v + ': too many values')
else: else:
found_types[k].append(int(found[0][1])) found_types[k].append(int(found[0][1]))
# now outputs to stdout the right C-style definitions # now outputs to stdout the right C-style definitions
stypes = "" ; sstruct = "" stypes = sstruct = ""
for t in basic_types: for t in basic_types:
k = t[0] k = t[0]
s = str(found_types[k]) s = str(found_types[k])
@ -108,7 +107,7 @@ for t in basic_types:
% (k, k, k)) % (k, k, k))
for t in array_types: for t in array_types:
kt = t[0] kt = t[0]
ka = t[0]+'ARRAY' ka = t[0] + 'ARRAY'
s = str(t[1]) s = str(t[1])
s = '{' + s[1:-1] + ', 0}' s = '{' + s[1:-1] + ', 0}'
stypes = stypes + ('static long int typecast_%s_types[] = %s;\n' % (ka, s)) stypes = stypes + ('static long int typecast_%s_types[] = %s;\n' % (ka, s))

View File

@ -23,6 +23,7 @@ from collections import defaultdict
from BeautifulSoup import BeautifulSoup as BS from BeautifulSoup import BeautifulSoup as BS
def main(): def main():
if len(sys.argv) != 2: if len(sys.argv) != 2:
print >>sys.stderr, "usage: %s /path/to/errorcodes.py" % sys.argv[0] print >>sys.stderr, "usage: %s /path/to/errorcodes.py" % sys.argv[0]
@ -41,6 +42,7 @@ def main():
for line in generate_module_data(classes, errors): for line in generate_module_data(classes, errors):
print >>f, line print >>f, line
def read_base_file(filename): def read_base_file(filename):
rv = [] rv = []
for line in open(filename): for line in open(filename):
@ -50,6 +52,7 @@ def read_base_file(filename):
raise ValueError("can't find the separator. Is this the right file?") raise ValueError("can't find the separator. Is this the right file?")
def parse_errors_txt(url): def parse_errors_txt(url):
classes = {} classes = {}
errors = defaultdict(dict) errors = defaultdict(dict)
@ -84,6 +87,7 @@ def parse_errors_txt(url):
return classes, errors return classes, errors
def parse_errors_sgml(url): def parse_errors_sgml(url):
page = BS(urllib2.urlopen(url)) page = BS(urllib2.urlopen(url))
table = page('table')[1]('tbody')[0] table = page('table')[1]('tbody')[0]
@ -130,6 +134,7 @@ errors_txt_url = \
"http://git.postgresql.org/gitweb/?p=postgresql.git;a=blob_plain;" \ "http://git.postgresql.org/gitweb/?p=postgresql.git;a=blob_plain;" \
"f=src/backend/utils/errcodes.txt;hb=REL%s_STABLE" "f=src/backend/utils/errcodes.txt;hb=REL%s_STABLE"
def fetch_errors(versions): def fetch_errors(versions):
classes = {} classes = {}
errors = defaultdict(dict) errors = defaultdict(dict)
@ -148,6 +153,7 @@ def fetch_errors(versions):
return classes, errors return classes, errors
def generate_module_data(classes, errors): def generate_module_data(classes, errors):
yield "" yield ""
yield "# Error classes" yield "# Error classes"
@ -163,7 +169,6 @@ def generate_module_data(classes, errors):
for errcode, errlabel in sorted(errors[clscode].items()): for errcode, errlabel in sorted(errors[clscode].items()):
yield "%s = %r" % (errlabel, errcode) yield "%s = %r" % (errlabel, errcode)
if __name__ == '__main__': if __name__ == '__main__':
sys.exit(main()) sys.exit(main())

View File

@ -25,6 +25,7 @@ import unittest
from pprint import pprint from pprint import pprint
from collections import defaultdict from collections import defaultdict
def main(): def main():
opt = parse_args() opt = parse_args()
@ -58,6 +59,7 @@ def main():
return rv return rv
def parse_args(): def parse_args():
import optparse import optparse
@ -83,7 +85,7 @@ def dump(i, opt):
c[type(o)] += 1 c[type(o)] += 1
pprint( pprint(
sorted(((v,str(k)) for k,v in c.items()), reverse=True), sorted(((v, str(k)) for k, v in c.items()), reverse=True),
stream=open("debug-%02d.txt" % i, "w")) stream=open("debug-%02d.txt" % i, "w"))
if opt.objs: if opt.objs:
@ -95,7 +97,7 @@ def dump(i, opt):
# TODO: very incomplete # TODO: very incomplete
if t is dict: if t is dict:
co.sort(key = lambda d: d.items()) co.sort(key=lambda d: d.items())
else: else:
co.sort() co.sort()
@ -104,4 +106,3 @@ def dump(i, opt):
if __name__ == '__main__': if __name__ == '__main__':
sys.exit(main()) sys.exit(main())

View File

@ -25,34 +25,9 @@ UPDATEs. psycopg2 also provide full asynchronous operations and support
for coroutine libraries. for coroutine libraries.
""" """
# note: if you are changing the list of supported Python version please fix
# the docs in install.rst and the /features/ page on the website.
classifiers = """\
Development Status :: 5 - Production/Stable
Intended Audience :: Developers
License :: OSI Approved :: GNU Library or Lesser General Public License (LGPL)
License :: OSI Approved :: Zope Public License
Programming Language :: Python
Programming Language :: Python :: 2.6
Programming Language :: Python :: 2.7
Programming Language :: Python :: 3
Programming Language :: Python :: 3.1
Programming Language :: Python :: 3.2
Programming Language :: Python :: 3.3
Programming Language :: Python :: 3.4
Programming Language :: Python :: 3.5
Programming Language :: C
Programming Language :: SQL
Topic :: Database
Topic :: Database :: Front-Ends
Topic :: Software Development
Topic :: Software Development :: Libraries :: Python Modules
Operating System :: Microsoft :: Windows
Operating System :: Unix
"""
# Note: The setup.py must be compatible with both Python 2 and 3 # Note: The setup.py must be compatible with both Python 2 and 3
import os import os
import sys import sys
import re import re
@ -87,6 +62,33 @@ except ImportError:
PSYCOPG_VERSION = '2.7.dev0' PSYCOPG_VERSION = '2.7.dev0'
# note: if you are changing the list of supported Python version please fix
# the docs in install.rst and the /features/ page on the website.
classifiers = """\
Development Status :: 5 - Production/Stable
Intended Audience :: Developers
License :: OSI Approved :: GNU Library or Lesser General Public License (LGPL)
License :: OSI Approved :: Zope Public License
Programming Language :: Python
Programming Language :: Python :: 2.6
Programming Language :: Python :: 2.7
Programming Language :: Python :: 3
Programming Language :: Python :: 3.1
Programming Language :: Python :: 3.2
Programming Language :: Python :: 3.3
Programming Language :: Python :: 3.4
Programming Language :: Python :: 3.5
Programming Language :: C
Programming Language :: SQL
Topic :: Database
Topic :: Database :: Front-Ends
Topic :: Software Development
Topic :: Software Development :: Libraries :: Python Modules
Operating System :: Microsoft :: Windows
Operating System :: Unix
"""
version_flags = ['dt', 'dec'] version_flags = ['dt', 'dec']
PLATFORM_IS_WINDOWS = sys.platform.lower().startswith('win') PLATFORM_IS_WINDOWS = sys.platform.lower().startswith('win')
@ -445,6 +447,7 @@ class psycopg_build_ext(build_ext):
if hasattr(self, "finalize_" + sys.platform): if hasattr(self, "finalize_" + sys.platform):
getattr(self, "finalize_" + sys.platform)() getattr(self, "finalize_" + sys.platform)()
def is_py_64(): def is_py_64():
# sys.maxint not available since Py 3.1; # sys.maxint not available since Py 3.1;
# sys.maxsize not available before Py 2.6; # sys.maxsize not available before Py 2.6;
@ -580,8 +583,8 @@ for define in parser.get('build_ext', 'define').split(','):
# build the extension # build the extension
sources = [ os.path.join('psycopg', x) for x in sources] sources = [os.path.join('psycopg', x) for x in sources]
depends = [ os.path.join('psycopg', x) for x in depends] depends = [os.path.join('psycopg', x) for x in depends]
ext.append(Extension("psycopg2._psycopg", sources, ext.append(Extension("psycopg2._psycopg", sources,
define_macros=define_macros, define_macros=define_macros,

View File

@ -52,6 +52,7 @@ if sys.version_info[:2] >= (2, 5):
else: else:
test_with = None test_with = None
def test_suite(): def test_suite():
# If connection to test db fails, bail out early. # If connection to test db fails, bail out early.
import psycopg2 import psycopg2

View File

@ -33,6 +33,7 @@ import StringIO
from testutils import ConnectingTestCase from testutils import ConnectingTestCase
class PollableStub(object): class PollableStub(object):
"""A 'pollable' wrapper allowing analysis of the `poll()` calls.""" """A 'pollable' wrapper allowing analysis of the `poll()` calls."""
def __init__(self, pollable): def __init__(self, pollable):
@ -68,6 +69,7 @@ class AsyncTests(ConnectingTestCase):
def test_connection_setup(self): def test_connection_setup(self):
cur = self.conn.cursor() cur = self.conn.cursor()
sync_cur = self.sync_conn.cursor() sync_cur = self.sync_conn.cursor()
del cur, sync_cur
self.assert_(self.conn.async) self.assert_(self.conn.async)
self.assert_(not self.sync_conn.async) self.assert_(not self.sync_conn.async)
@ -77,7 +79,7 @@ class AsyncTests(ConnectingTestCase):
# check other properties to be found on the connection # check other properties to be found on the connection
self.assert_(self.conn.server_version) self.assert_(self.conn.server_version)
self.assert_(self.conn.protocol_version in (2,3)) self.assert_(self.conn.protocol_version in (2, 3))
self.assert_(self.conn.encoding in psycopg2.extensions.encodings) self.assert_(self.conn.encoding in psycopg2.extensions.encodings)
def test_async_named_cursor(self): def test_async_named_cursor(self):
@ -108,6 +110,7 @@ class AsyncTests(ConnectingTestCase):
def test_async_after_async(self): def test_async_after_async(self):
cur = self.conn.cursor() cur = self.conn.cursor()
cur2 = self.conn.cursor() cur2 = self.conn.cursor()
del cur2
cur.execute("insert into table1 values (1)") cur.execute("insert into table1 values (1)")
@ -422,14 +425,14 @@ class AsyncTests(ConnectingTestCase):
def test_async_cursor_gone(self): def test_async_cursor_gone(self):
import gc import gc
cur = self.conn.cursor() cur = self.conn.cursor()
cur.execute("select 42;"); cur.execute("select 42;")
del cur del cur
gc.collect() gc.collect()
self.assertRaises(psycopg2.InterfaceError, self.wait, self.conn) self.assertRaises(psycopg2.InterfaceError, self.wait, self.conn)
# The connection is still usable # The connection is still usable
cur = self.conn.cursor() cur = self.conn.cursor()
cur.execute("select 42;"); cur.execute("select 42;")
self.wait(self.conn) self.wait(self.conn)
self.assertEqual(cur.fetchone(), (42,)) self.assertEqual(cur.fetchone(), (42,))
@ -449,4 +452,3 @@ def test_suite():
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -26,15 +26,17 @@ import psycopg2
import time import time
import unittest import unittest
class DateTimeAllocationBugTestCase(unittest.TestCase): class DateTimeAllocationBugTestCase(unittest.TestCase):
def test_date_time_allocation_bug(self): def test_date_time_allocation_bug(self):
d1 = psycopg2.Date(2002,12,25) d1 = psycopg2.Date(2002, 12, 25)
d2 = psycopg2.DateFromTicks(time.mktime((2002,12,25,0,0,0,0,0,0))) d2 = psycopg2.DateFromTicks(time.mktime((2002, 12, 25, 0, 0, 0, 0, 0, 0)))
t1 = psycopg2.Time(13,45,30) t1 = psycopg2.Time(13, 45, 30)
t2 = psycopg2.TimeFromTicks(time.mktime((2001,1,1,13,45,30,0,0,0))) t2 = psycopg2.TimeFromTicks(time.mktime((2001, 1, 1, 13, 45, 30, 0, 0, 0)))
t1 = psycopg2.Timestamp(2002,12,25,13,45,30) t1 = psycopg2.Timestamp(2002, 12, 25, 13, 45, 30)
t2 = psycopg2.TimestampFromTicks( t2 = psycopg2.TimestampFromTicks(
time.mktime((2002,12,25,13,45,30,0,0,0))) time.mktime((2002, 12, 25, 13, 45, 30, 0, 0, 0)))
del d1, d2, t1, t2
def test_suite(): def test_suite():

View File

@ -29,6 +29,7 @@ import gc
from testutils import ConnectingTestCase, skip_if_no_uuid from testutils import ConnectingTestCase, skip_if_no_uuid
class StolenReferenceTestCase(ConnectingTestCase): class StolenReferenceTestCase(ConnectingTestCase):
@skip_if_no_uuid @skip_if_no_uuid
def test_stolen_reference_bug(self): def test_stolen_reference_bug(self):
@ -41,8 +42,10 @@ class StolenReferenceTestCase(ConnectingTestCase):
curs.execute("select 'b5219e01-19ab-4994-b71e-149225dc51e4'::uuid") curs.execute("select 'b5219e01-19ab-4994-b71e-149225dc51e4'::uuid")
curs.fetchone() curs.fetchone()
def test_suite(): def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__) return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -32,6 +32,7 @@ from psycopg2 import extras
from testconfig import dsn from testconfig import dsn
from testutils import unittest, ConnectingTestCase, skip_before_postgres from testutils import unittest, ConnectingTestCase, skip_before_postgres
class CancelTests(ConnectingTestCase): class CancelTests(ConnectingTestCase):
def setUp(self): def setUp(self):
@ -71,6 +72,7 @@ class CancelTests(ConnectingTestCase):
except Exception, e: except Exception, e:
errors.append(e) errors.append(e)
raise raise
del cur
thread1 = threading.Thread(target=neverending, args=(self.conn, )) thread1 = threading.Thread(target=neverending, args=(self.conn, ))
# wait a bit to make sure that the other thread is already in # wait a bit to make sure that the other thread is already in

View File

@ -27,17 +27,16 @@ import sys
import time import time
import threading import threading
from operator import attrgetter from operator import attrgetter
from StringIO import StringIO
import psycopg2 import psycopg2
import psycopg2.errorcodes import psycopg2.errorcodes
import psycopg2.extensions from psycopg2 import extensions as ext
ext = psycopg2.extensions
from testutils import (
unittest, decorate_all_tests, skip_if_no_superuser,
skip_before_postgres, skip_after_postgres, skip_before_libpq,
ConnectingTestCase, skip_if_tpc_disabled, skip_if_windows)
from testutils import unittest, decorate_all_tests, skip_if_no_superuser
from testutils import skip_before_postgres, skip_after_postgres, skip_before_libpq
from testutils import ConnectingTestCase, skip_if_tpc_disabled
from testutils import skip_if_windows
from testconfig import dsn, dbname from testconfig import dsn, dbname
@ -112,8 +111,14 @@ class ConnectionTests(ConnectingTestCase):
cur = conn.cursor() cur = conn.cursor()
if self.conn.server_version >= 90300: if self.conn.server_version >= 90300:
cur.execute("set client_min_messages=debug1") cur.execute("set client_min_messages=debug1")
cur.execute("create temp table table1 (id serial); create temp table table2 (id serial);") cur.execute("""
cur.execute("create temp table table3 (id serial); create temp table table4 (id serial);") create temp table table1 (id serial);
create temp table table2 (id serial);
""")
cur.execute("""
create temp table table3 (id serial);
create temp table table4 (id serial);
""")
self.assertEqual(4, len(conn.notices)) self.assertEqual(4, len(conn.notices))
self.assert_('table1' in conn.notices[0]) self.assert_('table1' in conn.notices[0])
self.assert_('table2' in conn.notices[1]) self.assert_('table2' in conn.notices[1])
@ -126,7 +131,8 @@ class ConnectionTests(ConnectingTestCase):
if self.conn.server_version >= 90300: if self.conn.server_version >= 90300:
cur.execute("set client_min_messages=debug1") cur.execute("set client_min_messages=debug1")
for i in range(0, 100, 10): for i in range(0, 100, 10):
sql = " ".join(["create temp table table%d (id serial);" % j for j in range(i, i + 10)]) sql = " ".join(["create temp table table%d (id serial);" % j
for j in range(i, i + 10)])
cur.execute(sql) cur.execute(sql)
self.assertEqual(50, len(conn.notices)) self.assertEqual(50, len(conn.notices))
@ -141,8 +147,13 @@ class ConnectionTests(ConnectingTestCase):
if self.conn.server_version >= 90300: if self.conn.server_version >= 90300:
cur.execute("set client_min_messages=debug1") cur.execute("set client_min_messages=debug1")
cur.execute("create temp table table1 (id serial); create temp table table2 (id serial);") cur.execute("""
cur.execute("create temp table table3 (id serial); create temp table table4 (id serial);") create temp table table1 (id serial);
create temp table table2 (id serial);
""")
cur.execute("""
create temp table table3 (id serial);
create temp table table4 (id serial);""")
self.assertEqual(len(conn.notices), 4) self.assertEqual(len(conn.notices), 4)
self.assert_('table1' in conn.notices.popleft()) self.assert_('table1' in conn.notices.popleft())
self.assert_('table2' in conn.notices.popleft()) self.assert_('table2' in conn.notices.popleft())
@ -152,7 +163,8 @@ class ConnectionTests(ConnectingTestCase):
# not limited, but no error # not limited, but no error
for i in range(0, 100, 10): for i in range(0, 100, 10):
sql = " ".join(["create temp table table2_%d (id serial);" % j for j in range(i, i + 10)]) sql = " ".join(["create temp table table2_%d (id serial);" % j
for j in range(i, i + 10)])
cur.execute(sql) cur.execute(sql)
self.assertEqual(len([n for n in conn.notices if 'CREATE TABLE' in n]), self.assertEqual(len([n for n in conn.notices if 'CREATE TABLE' in n]),
@ -315,14 +327,16 @@ class ParseDsnTestCase(ConnectingTestCase):
def test_parse_dsn(self): def test_parse_dsn(self):
from psycopg2 import ProgrammingError from psycopg2 import ProgrammingError
self.assertEqual(ext.parse_dsn('dbname=test user=tester password=secret'), self.assertEqual(
ext.parse_dsn('dbname=test user=tester password=secret'),
dict(user='tester', password='secret', dbname='test'), dict(user='tester', password='secret', dbname='test'),
"simple DSN parsed") "simple DSN parsed")
self.assertRaises(ProgrammingError, ext.parse_dsn, self.assertRaises(ProgrammingError, ext.parse_dsn,
"dbname=test 2 user=tester password=secret") "dbname=test 2 user=tester password=secret")
self.assertEqual(ext.parse_dsn("dbname='test 2' user=tester password=secret"), self.assertEqual(
ext.parse_dsn("dbname='test 2' user=tester password=secret"),
dict(user='tester', password='secret', dbname='test 2'), dict(user='tester', password='secret', dbname='test 2'),
"DSN with quoting parsed") "DSN with quoting parsed")
@ -485,7 +499,8 @@ class IsolationLevelsTestCase(ConnectingTestCase):
levels = [ levels = [
(None, psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT), (None, psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT),
('read uncommitted', psycopg2.extensions.ISOLATION_LEVEL_READ_UNCOMMITTED), ('read uncommitted',
psycopg2.extensions.ISOLATION_LEVEL_READ_UNCOMMITTED),
('read committed', psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED), ('read committed', psycopg2.extensions.ISOLATION_LEVEL_READ_COMMITTED),
('repeatable read', psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ), ('repeatable read', psycopg2.extensions.ISOLATION_LEVEL_REPEATABLE_READ),
('serializable', psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE), ('serializable', psycopg2.extensions.ISOLATION_LEVEL_SERIALIZABLE),

View File

@ -41,6 +41,7 @@ if sys.version_info[0] < 3:
else: else:
from io import TextIOBase as _base from io import TextIOBase as _base
class MinimalRead(_base): class MinimalRead(_base):
"""A file wrapper exposing the minimal interface to copy from.""" """A file wrapper exposing the minimal interface to copy from."""
def __init__(self, f): def __init__(self, f):
@ -52,6 +53,7 @@ class MinimalRead(_base):
def readline(self): def readline(self):
return self.f.readline() return self.f.readline()
class MinimalWrite(_base): class MinimalWrite(_base):
"""A file wrapper exposing the minimal interface to copy to.""" """A file wrapper exposing the minimal interface to copy to."""
def __init__(self, f): def __init__(self, f):
@ -78,7 +80,7 @@ class CopyTests(ConnectingTestCase):
def test_copy_from(self): def test_copy_from(self):
curs = self.conn.cursor() curs = self.conn.cursor()
try: try:
self._copy_from(curs, nrecs=1024, srec=10*1024, copykw={}) self._copy_from(curs, nrecs=1024, srec=10 * 1024, copykw={})
finally: finally:
curs.close() curs.close()
@ -86,8 +88,8 @@ class CopyTests(ConnectingTestCase):
# Trying to trigger a "would block" error # Trying to trigger a "would block" error
curs = self.conn.cursor() curs = self.conn.cursor()
try: try:
self._copy_from(curs, nrecs=10*1024, srec=10*1024, self._copy_from(curs, nrecs=10 * 1024, srec=10 * 1024,
copykw={'size': 20*1024*1024}) copykw={'size': 20 * 1024 * 1024})
finally: finally:
curs.close() curs.close()
@ -110,6 +112,7 @@ class CopyTests(ConnectingTestCase):
f.write("%s\n" % (i,)) f.write("%s\n" % (i,))
f.seek(0) f.seek(0)
def cols(): def cols():
raise ZeroDivisionError() raise ZeroDivisionError()
yield 'id' yield 'id'
@ -120,8 +123,8 @@ class CopyTests(ConnectingTestCase):
def test_copy_to(self): def test_copy_to(self):
curs = self.conn.cursor() curs = self.conn.cursor()
try: try:
self._copy_from(curs, nrecs=1024, srec=10*1024, copykw={}) self._copy_from(curs, nrecs=1024, srec=10 * 1024, copykw={})
self._copy_to(curs, srec=10*1024) self._copy_to(curs, srec=10 * 1024)
finally: finally:
curs.close() curs.close()
@ -209,9 +212,11 @@ class CopyTests(ConnectingTestCase):
exp_size = 123 exp_size = 123
# hack here to leave file as is, only check size when reading # hack here to leave file as is, only check size when reading
real_read = f.read real_read = f.read
def read(_size, f=f, exp_size=exp_size): def read(_size, f=f, exp_size=exp_size):
self.assertEqual(_size, exp_size) self.assertEqual(_size, exp_size)
return real_read(_size) return real_read(_size)
f.read = read f.read = read
curs.copy_expert('COPY tcopy (data) FROM STDIN', f, size=exp_size) curs.copy_expert('COPY tcopy (data) FROM STDIN', f, size=exp_size)
curs.execute("select data from tcopy;") curs.execute("select data from tcopy;")
@ -221,7 +226,7 @@ class CopyTests(ConnectingTestCase):
f = StringIO() f = StringIO()
for i, c in izip(xrange(nrecs), cycle(string.ascii_letters)): for i, c in izip(xrange(nrecs), cycle(string.ascii_letters)):
l = c * srec l = c * srec
f.write("%s\t%s\n" % (i,l)) f.write("%s\t%s\n" % (i, l))
f.seek(0) f.seek(0)
curs.copy_from(MinimalRead(f), "tcopy", **copykw) curs.copy_from(MinimalRead(f), "tcopy", **copykw)
@ -258,20 +263,20 @@ class CopyTests(ConnectingTestCase):
curs.copy_expert, 'COPY tcopy (data) FROM STDIN', f) curs.copy_expert, 'COPY tcopy (data) FROM STDIN', f)
def test_copy_no_column_limit(self): def test_copy_no_column_limit(self):
cols = [ "c%050d" % i for i in range(200) ] cols = ["c%050d" % i for i in range(200)]
curs = self.conn.cursor() curs = self.conn.cursor()
curs.execute('CREATE TEMPORARY TABLE manycols (%s)' % ',\n'.join( curs.execute('CREATE TEMPORARY TABLE manycols (%s)' % ',\n'.join(
[ "%s int" % c for c in cols])) ["%s int" % c for c in cols]))
curs.execute("INSERT INTO manycols DEFAULT VALUES") curs.execute("INSERT INTO manycols DEFAULT VALUES")
f = StringIO() f = StringIO()
curs.copy_to(f, "manycols", columns = cols) curs.copy_to(f, "manycols", columns=cols)
f.seek(0) f.seek(0)
self.assertEqual(f.read().split(), ['\\N'] * len(cols)) self.assertEqual(f.read().split(), ['\\N'] * len(cols))
f.seek(0) f.seek(0)
curs.copy_from(f, "manycols", columns = cols) curs.copy_from(f, "manycols", columns=cols)
curs.execute("select count(*) from manycols;") curs.execute("select count(*) from manycols;")
self.assertEqual(curs.fetchone()[0], 2) self.assertEqual(curs.fetchone()[0], 2)
@ -316,7 +321,7 @@ try:
except psycopg2.ProgrammingError: except psycopg2.ProgrammingError:
pass pass
conn.close() conn.close()
""" % { 'dsn': dsn,}) """ % {'dsn': dsn})
proc = Popen([sys.executable, '-c', script_to_py3(script)]) proc = Popen([sys.executable, '-c', script_to_py3(script)])
proc.communicate() proc.communicate()
@ -334,7 +339,7 @@ try:
except psycopg2.ProgrammingError: except psycopg2.ProgrammingError:
pass pass
conn.close() conn.close()
""" % { 'dsn': dsn,}) """ % {'dsn': dsn})
proc = Popen([sys.executable, '-c', script_to_py3(script)], stdout=PIPE) proc = Popen([sys.executable, '-c', script_to_py3(script)], stdout=PIPE)
proc.communicate() proc.communicate()
@ -343,10 +348,10 @@ conn.close()
def test_copy_from_propagate_error(self): def test_copy_from_propagate_error(self):
class BrokenRead(_base): class BrokenRead(_base):
def read(self, size): def read(self, size):
return 1/0 return 1 / 0
def readline(self): def readline(self):
return 1/0 return 1 / 0
curs = self.conn.cursor() curs = self.conn.cursor()
# It seems we cannot do this, but now at least we propagate the error # It seems we cannot do this, but now at least we propagate the error
@ -360,7 +365,7 @@ conn.close()
def test_copy_to_propagate_error(self): def test_copy_to_propagate_error(self):
class BrokenWrite(_base): class BrokenWrite(_base):
def write(self, data): def write(self, data):
return 1/0 return 1 / 0
curs = self.conn.cursor() curs = self.conn.cursor()
curs.execute("insert into tcopy values (10, 'hi')") curs.execute("insert into tcopy values (10, 'hi')")

View File

@ -29,6 +29,7 @@ import psycopg2.extensions
from testutils import unittest, ConnectingTestCase, skip_before_postgres from testutils import unittest, ConnectingTestCase, skip_before_postgres
from testutils import skip_if_no_namedtuple, skip_if_no_getrefcount from testutils import skip_if_no_namedtuple, skip_if_no_getrefcount
class CursorTests(ConnectingTestCase): class CursorTests(ConnectingTestCase):
def test_close_idempotent(self): def test_close_idempotent(self):
@ -47,8 +48,10 @@ class CursorTests(ConnectingTestCase):
conn = self.conn conn = self.conn
cur = conn.cursor() cur = conn.cursor()
cur.execute("create temp table test_exc (data int);") cur.execute("create temp table test_exc (data int);")
def buggygen(): def buggygen():
yield 1//0 yield 1 // 0
self.assertRaises(ZeroDivisionError, self.assertRaises(ZeroDivisionError,
cur.executemany, "insert into test_exc values (%s)", buggygen()) cur.executemany, "insert into test_exc values (%s)", buggygen())
cur.close() cur.close()
@ -102,8 +105,7 @@ class CursorTests(ConnectingTestCase):
# issue #81: reference leak when a parameter value is referenced # issue #81: reference leak when a parameter value is referenced
# more than once from a dict. # more than once from a dict.
cur = self.conn.cursor() cur = self.conn.cursor()
i = lambda x: x foo = (lambda x: x)('foo') * 10
foo = i('foo') * 10
import sys import sys
nref1 = sys.getrefcount(foo) nref1 = sys.getrefcount(foo)
cur.mogrify("select %(foo)s, %(foo)s, %(foo)s", {'foo': foo}) cur.mogrify("select %(foo)s, %(foo)s, %(foo)s", {'foo': foo})
@ -135,7 +137,7 @@ class CursorTests(ConnectingTestCase):
self.assertEqual(Decimal('123.45'), curs.cast(1700, '123.45')) self.assertEqual(Decimal('123.45'), curs.cast(1700, '123.45'))
from datetime import date from datetime import date
self.assertEqual(date(2011,1,2), curs.cast(1082, '2011-01-02')) self.assertEqual(date(2011, 1, 2), curs.cast(1082, '2011-01-02'))
self.assertEqual("who am i?", curs.cast(705, 'who am i?')) # unknown self.assertEqual("who am i?", curs.cast(705, 'who am i?')) # unknown
def test_cast_specificity(self): def test_cast_specificity(self):
@ -158,7 +160,8 @@ class CursorTests(ConnectingTestCase):
curs = self.conn.cursor() curs = self.conn.cursor()
w = ref(curs) w = ref(curs)
del curs del curs
import gc; gc.collect() import gc
gc.collect()
self.assert_(w() is None) self.assert_(w() is None)
def test_null_name(self): def test_null_name(self):
@ -168,7 +171,7 @@ class CursorTests(ConnectingTestCase):
def test_invalid_name(self): def test_invalid_name(self):
curs = self.conn.cursor() curs = self.conn.cursor()
curs.execute("create temp table invname (data int);") curs.execute("create temp table invname (data int);")
for i in (10,20,30): for i in (10, 20, 30):
curs.execute("insert into invname values (%s)", (i,)) curs.execute("insert into invname values (%s)", (i,))
curs.close() curs.close()
@ -193,16 +196,16 @@ class CursorTests(ConnectingTestCase):
self._create_withhold_table() self._create_withhold_table()
curs = self.conn.cursor("W") curs = self.conn.cursor("W")
self.assertEqual(curs.withhold, False); self.assertEqual(curs.withhold, False)
curs.withhold = True curs.withhold = True
self.assertEqual(curs.withhold, True); self.assertEqual(curs.withhold, True)
curs.execute("select data from withhold order by data") curs.execute("select data from withhold order by data")
self.conn.commit() self.conn.commit()
self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)]) self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)])
curs.close() curs.close()
curs = self.conn.cursor("W", withhold=True) curs = self.conn.cursor("W", withhold=True)
self.assertEqual(curs.withhold, True); self.assertEqual(curs.withhold, True)
curs.execute("select data from withhold order by data") curs.execute("select data from withhold order by data")
self.conn.commit() self.conn.commit()
self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)]) self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)])
@ -264,18 +267,18 @@ class CursorTests(ConnectingTestCase):
curs = self.conn.cursor() curs = self.conn.cursor()
curs.execute("create table scrollable (data int)") curs.execute("create table scrollable (data int)")
curs.executemany("insert into scrollable values (%s)", curs.executemany("insert into scrollable values (%s)",
[ (i,) for i in range(100) ]) [(i,) for i in range(100)])
curs.close() curs.close()
for t in range(2): for t in range(2):
if not t: if not t:
curs = self.conn.cursor("S") curs = self.conn.cursor("S")
self.assertEqual(curs.scrollable, None); self.assertEqual(curs.scrollable, None)
curs.scrollable = True curs.scrollable = True
else: else:
curs = self.conn.cursor("S", scrollable=True) curs = self.conn.cursor("S", scrollable=True)
self.assertEqual(curs.scrollable, True); self.assertEqual(curs.scrollable, True)
curs.itersize = 10 curs.itersize = 10
# complex enough to make postgres cursors declare without # complex enough to make postgres cursors declare without
@ -303,7 +306,7 @@ class CursorTests(ConnectingTestCase):
curs = self.conn.cursor() curs = self.conn.cursor()
curs.execute("create table scrollable (data int)") curs.execute("create table scrollable (data int)")
curs.executemany("insert into scrollable values (%s)", curs.executemany("insert into scrollable values (%s)",
[ (i,) for i in range(100) ]) [(i,) for i in range(100)])
curs.close() curs.close()
curs = self.conn.cursor("S") # default scrollability curs = self.conn.cursor("S") # default scrollability
@ -340,18 +343,18 @@ class CursorTests(ConnectingTestCase):
def test_iter_named_cursor_default_itersize(self): def test_iter_named_cursor_default_itersize(self):
curs = self.conn.cursor('tmp') curs = self.conn.cursor('tmp')
curs.execute('select generate_series(1,50)') curs.execute('select generate_series(1,50)')
rv = [ (r[0], curs.rownumber) for r in curs ] rv = [(r[0], curs.rownumber) for r in curs]
# everything swallowed in one gulp # everything swallowed in one gulp
self.assertEqual(rv, [(i,i) for i in range(1,51)]) self.assertEqual(rv, [(i, i) for i in range(1, 51)])
@skip_before_postgres(8, 0) @skip_before_postgres(8, 0)
def test_iter_named_cursor_itersize(self): def test_iter_named_cursor_itersize(self):
curs = self.conn.cursor('tmp') curs = self.conn.cursor('tmp')
curs.itersize = 30 curs.itersize = 30
curs.execute('select generate_series(1,50)') curs.execute('select generate_series(1,50)')
rv = [ (r[0], curs.rownumber) for r in curs ] rv = [(r[0], curs.rownumber) for r in curs]
# everything swallowed in two gulps # everything swallowed in two gulps
self.assertEqual(rv, [(i,((i - 1) % 30) + 1) for i in range(1,51)]) self.assertEqual(rv, [(i, ((i - 1) % 30) + 1) for i in range(1, 51)])
@skip_before_postgres(8, 0) @skip_before_postgres(8, 0)
def test_iter_named_cursor_rownumber(self): def test_iter_named_cursor_rownumber(self):

View File

@ -27,6 +27,7 @@ import psycopg2
from psycopg2.tz import FixedOffsetTimezone, ZERO from psycopg2.tz import FixedOffsetTimezone, ZERO
from testutils import unittest, ConnectingTestCase, skip_before_postgres from testutils import unittest, ConnectingTestCase, skip_before_postgres
class CommonDatetimeTestsMixin: class CommonDatetimeTestsMixin:
def execute(self, *args): def execute(self, *args):
@ -269,32 +270,32 @@ class DatetimeTests(ConnectingTestCase, CommonDatetimeTestsMixin):
def test_type_roundtrip_date(self): def test_type_roundtrip_date(self):
from datetime import date from datetime import date
self._test_type_roundtrip(date(2010,5,3)) self._test_type_roundtrip(date(2010, 5, 3))
def test_type_roundtrip_datetime(self): def test_type_roundtrip_datetime(self):
from datetime import datetime from datetime import datetime
dt = self._test_type_roundtrip(datetime(2010,5,3,10,20,30)) dt = self._test_type_roundtrip(datetime(2010, 5, 3, 10, 20, 30))
self.assertEqual(None, dt.tzinfo) self.assertEqual(None, dt.tzinfo)
def test_type_roundtrip_datetimetz(self): def test_type_roundtrip_datetimetz(self):
from datetime import datetime from datetime import datetime
import psycopg2.tz import psycopg2.tz
tz = psycopg2.tz.FixedOffsetTimezone(8*60) tz = psycopg2.tz.FixedOffsetTimezone(8 * 60)
dt1 = datetime(2010,5,3,10,20,30, tzinfo=tz) dt1 = datetime(2010, 5, 3, 10, 20, 30, tzinfo=tz)
dt2 = self._test_type_roundtrip(dt1) dt2 = self._test_type_roundtrip(dt1)
self.assertNotEqual(None, dt2.tzinfo) self.assertNotEqual(None, dt2.tzinfo)
self.assertEqual(dt1, dt2) self.assertEqual(dt1, dt2)
def test_type_roundtrip_time(self): def test_type_roundtrip_time(self):
from datetime import time from datetime import time
tm = self._test_type_roundtrip(time(10,20,30)) tm = self._test_type_roundtrip(time(10, 20, 30))
self.assertEqual(None, tm.tzinfo) self.assertEqual(None, tm.tzinfo)
def test_type_roundtrip_timetz(self): def test_type_roundtrip_timetz(self):
from datetime import time from datetime import time
import psycopg2.tz import psycopg2.tz
tz = psycopg2.tz.FixedOffsetTimezone(8*60) tz = psycopg2.tz.FixedOffsetTimezone(8 * 60)
tm1 = time(10,20,30, tzinfo=tz) tm1 = time(10, 20, 30, tzinfo=tz)
tm2 = self._test_type_roundtrip(tm1) tm2 = self._test_type_roundtrip(tm1)
self.assertNotEqual(None, tm2.tzinfo) self.assertNotEqual(None, tm2.tzinfo)
self.assertEqual(tm1, tm2) self.assertEqual(tm1, tm2)
@ -305,15 +306,15 @@ class DatetimeTests(ConnectingTestCase, CommonDatetimeTestsMixin):
def test_type_roundtrip_date_array(self): def test_type_roundtrip_date_array(self):
from datetime import date from datetime import date
self._test_type_roundtrip_array(date(2010,5,3)) self._test_type_roundtrip_array(date(2010, 5, 3))
def test_type_roundtrip_datetime_array(self): def test_type_roundtrip_datetime_array(self):
from datetime import datetime from datetime import datetime
self._test_type_roundtrip_array(datetime(2010,5,3,10,20,30)) self._test_type_roundtrip_array(datetime(2010, 5, 3, 10, 20, 30))
def test_type_roundtrip_time_array(self): def test_type_roundtrip_time_array(self):
from datetime import time from datetime import time
self._test_type_roundtrip_array(time(10,20,30)) self._test_type_roundtrip_array(time(10, 20, 30))
def test_type_roundtrip_interval_array(self): def test_type_roundtrip_interval_array(self):
from datetime import timedelta from datetime import timedelta
@ -355,8 +356,10 @@ class mxDateTimeTests(ConnectingTestCase, CommonDatetimeTestsMixin):
psycopg2.extensions.register_type(self.INTERVAL, self.conn) psycopg2.extensions.register_type(self.INTERVAL, self.conn)
psycopg2.extensions.register_type(psycopg2.extensions.MXDATEARRAY, self.conn) psycopg2.extensions.register_type(psycopg2.extensions.MXDATEARRAY, self.conn)
psycopg2.extensions.register_type(psycopg2.extensions.MXTIMEARRAY, self.conn) psycopg2.extensions.register_type(psycopg2.extensions.MXTIMEARRAY, self.conn)
psycopg2.extensions.register_type(psycopg2.extensions.MXDATETIMEARRAY, self.conn) psycopg2.extensions.register_type(
psycopg2.extensions.register_type(psycopg2.extensions.MXINTERVALARRAY, self.conn) psycopg2.extensions.MXDATETIMEARRAY, self.conn)
psycopg2.extensions.register_type(
psycopg2.extensions.MXINTERVALARRAY, self.conn)
def tearDown(self): def tearDown(self):
self.conn.close() self.conn.close()
@ -479,15 +482,15 @@ class mxDateTimeTests(ConnectingTestCase, CommonDatetimeTestsMixin):
def test_type_roundtrip_date(self): def test_type_roundtrip_date(self):
from mx.DateTime import Date from mx.DateTime import Date
self._test_type_roundtrip(Date(2010,5,3)) self._test_type_roundtrip(Date(2010, 5, 3))
def test_type_roundtrip_datetime(self): def test_type_roundtrip_datetime(self):
from mx.DateTime import DateTime from mx.DateTime import DateTime
self._test_type_roundtrip(DateTime(2010,5,3,10,20,30)) self._test_type_roundtrip(DateTime(2010, 5, 3, 10, 20, 30))
def test_type_roundtrip_time(self): def test_type_roundtrip_time(self):
from mx.DateTime import Time from mx.DateTime import Time
self._test_type_roundtrip(Time(10,20,30)) self._test_type_roundtrip(Time(10, 20, 30))
def test_type_roundtrip_interval(self): def test_type_roundtrip_interval(self):
from mx.DateTime import DateTimeDeltaFrom from mx.DateTime import DateTimeDeltaFrom
@ -495,15 +498,15 @@ class mxDateTimeTests(ConnectingTestCase, CommonDatetimeTestsMixin):
def test_type_roundtrip_date_array(self): def test_type_roundtrip_date_array(self):
from mx.DateTime import Date from mx.DateTime import Date
self._test_type_roundtrip_array(Date(2010,5,3)) self._test_type_roundtrip_array(Date(2010, 5, 3))
def test_type_roundtrip_datetime_array(self): def test_type_roundtrip_datetime_array(self):
from mx.DateTime import DateTime from mx.DateTime import DateTime
self._test_type_roundtrip_array(DateTime(2010,5,3,10,20,30)) self._test_type_roundtrip_array(DateTime(2010, 5, 3, 10, 20, 30))
def test_type_roundtrip_time_array(self): def test_type_roundtrip_time_array(self):
from mx.DateTime import Time from mx.DateTime import Time
self._test_type_roundtrip_array(Time(10,20,30)) self._test_type_roundtrip_array(Time(10, 20, 30))
def test_type_roundtrip_interval_array(self): def test_type_roundtrip_interval_array(self):
from mx.DateTime import DateTimeDeltaFrom from mx.DateTime import DateTimeDeltaFrom
@ -549,22 +552,30 @@ class FixedOffsetTimezoneTests(unittest.TestCase):
def test_repr_with_positive_offset(self): def test_repr_with_positive_offset(self):
tzinfo = FixedOffsetTimezone(5 * 60) tzinfo = FixedOffsetTimezone(5 * 60)
self.assertEqual(repr(tzinfo), "psycopg2.tz.FixedOffsetTimezone(offset=300, name=None)") self.assertEqual(repr(tzinfo),
"psycopg2.tz.FixedOffsetTimezone(offset=300, name=None)")
def test_repr_with_negative_offset(self): def test_repr_with_negative_offset(self):
tzinfo = FixedOffsetTimezone(-5 * 60) tzinfo = FixedOffsetTimezone(-5 * 60)
self.assertEqual(repr(tzinfo), "psycopg2.tz.FixedOffsetTimezone(offset=-300, name=None)") self.assertEqual(repr(tzinfo),
"psycopg2.tz.FixedOffsetTimezone(offset=-300, name=None)")
def test_repr_with_name(self): def test_repr_with_name(self):
tzinfo = FixedOffsetTimezone(name="FOO") tzinfo = FixedOffsetTimezone(name="FOO")
self.assertEqual(repr(tzinfo), "psycopg2.tz.FixedOffsetTimezone(offset=0, name='FOO')") self.assertEqual(repr(tzinfo),
"psycopg2.tz.FixedOffsetTimezone(offset=0, name='FOO')")
def test_instance_caching(self): def test_instance_caching(self):
self.assert_(FixedOffsetTimezone(name="FOO") is FixedOffsetTimezone(name="FOO")) self.assert_(FixedOffsetTimezone(name="FOO")
self.assert_(FixedOffsetTimezone(7 * 60) is FixedOffsetTimezone(7 * 60)) is FixedOffsetTimezone(name="FOO"))
self.assert_(FixedOffsetTimezone(-9 * 60, 'FOO') is FixedOffsetTimezone(-9 * 60, 'FOO')) self.assert_(FixedOffsetTimezone(7 * 60)
self.assert_(FixedOffsetTimezone(9 * 60) is not FixedOffsetTimezone(9 * 60, 'FOO')) is FixedOffsetTimezone(7 * 60))
self.assert_(FixedOffsetTimezone(name='FOO') is not FixedOffsetTimezone(9 * 60, 'FOO')) self.assert_(FixedOffsetTimezone(-9 * 60, 'FOO')
is FixedOffsetTimezone(-9 * 60, 'FOO'))
self.assert_(FixedOffsetTimezone(9 * 60)
is not FixedOffsetTimezone(9 * 60, 'FOO'))
self.assert_(FixedOffsetTimezone(name='FOO')
is not FixedOffsetTimezone(9 * 60, 'FOO'))
def test_pickle(self): def test_pickle(self):
# ticket #135 # ticket #135

View File

@ -32,6 +32,7 @@ except NameError:
from threading import Thread from threading import Thread
from psycopg2 import errorcodes from psycopg2 import errorcodes
class ErrocodeTests(ConnectingTestCase): class ErrocodeTests(ConnectingTestCase):
def test_lookup_threadsafe(self): def test_lookup_threadsafe(self):
@ -39,6 +40,7 @@ class ErrocodeTests(ConnectingTestCase):
MAX_CYCLES = 2000 MAX_CYCLES = 2000
errs = [] errs = []
def f(pg_code='40001'): def f(pg_code='40001'):
try: try:
errorcodes.lookup(pg_code) errorcodes.lookup(pg_code)

View File

@ -39,7 +39,8 @@ class ExtrasDictCursorTests(ConnectingTestCase):
self.assert_(isinstance(cur, psycopg2.extras.DictCursor)) self.assert_(isinstance(cur, psycopg2.extras.DictCursor))
self.assertEqual(cur.name, None) self.assertEqual(cur.name, None)
# overridable # overridable
cur = self.conn.cursor('foo', cursor_factory=psycopg2.extras.NamedTupleCursor) cur = self.conn.cursor('foo',
cursor_factory=psycopg2.extras.NamedTupleCursor)
self.assertEqual(cur.name, 'foo') self.assertEqual(cur.name, 'foo')
self.assert_(isinstance(cur, psycopg2.extras.NamedTupleCursor)) self.assert_(isinstance(cur, psycopg2.extras.NamedTupleCursor))
@ -80,7 +81,6 @@ class ExtrasDictCursorTests(ConnectingTestCase):
self.failUnless(row[0] == 'bar') self.failUnless(row[0] == 'bar')
return row return row
def testDictCursorWithPlainCursorRealFetchOne(self): def testDictCursorWithPlainCursorRealFetchOne(self):
self._testWithPlainCursorReal(lambda curs: curs.fetchone()) self._testWithPlainCursorReal(lambda curs: curs.fetchone())
@ -110,7 +110,6 @@ class ExtrasDictCursorTests(ConnectingTestCase):
row = getter(curs) row = getter(curs)
self.failUnless(row['foo'] == 'bar') self.failUnless(row['foo'] == 'bar')
def testDictCursorWithNamedCursorFetchOne(self): def testDictCursorWithNamedCursorFetchOne(self):
self._testWithNamedCursor(lambda curs: curs.fetchone()) self._testWithNamedCursor(lambda curs: curs.fetchone())
@ -146,7 +145,6 @@ class ExtrasDictCursorTests(ConnectingTestCase):
self.failUnless(row['foo'] == 'bar') self.failUnless(row['foo'] == 'bar')
self.failUnless(row[0] == 'bar') self.failUnless(row[0] == 'bar')
def testDictCursorRealWithNamedCursorFetchOne(self): def testDictCursorRealWithNamedCursorFetchOne(self):
self._testWithNamedCursorReal(lambda curs: curs.fetchone()) self._testWithNamedCursorReal(lambda curs: curs.fetchone())
@ -176,12 +174,12 @@ class ExtrasDictCursorTests(ConnectingTestCase):
self._testIterRowNumber(curs) self._testIterRowNumber(curs)
def _testWithNamedCursorReal(self, getter): def _testWithNamedCursorReal(self, getter):
curs = self.conn.cursor('aname', cursor_factory=psycopg2.extras.RealDictCursor) curs = self.conn.cursor('aname',
cursor_factory=psycopg2.extras.RealDictCursor)
curs.execute("SELECT * FROM ExtrasDictCursorTests") curs.execute("SELECT * FROM ExtrasDictCursorTests")
row = getter(curs) row = getter(curs)
self.failUnless(row['foo'] == 'bar') self.failUnless(row['foo'] == 'bar')
def _testNamedCursorNotGreedy(self, curs): def _testNamedCursorNotGreedy(self, curs):
curs.itersize = 2 curs.itersize = 2
curs.execute("""select clock_timestamp() as ts from generate_series(1,3)""") curs.execute("""select clock_timestamp() as ts from generate_series(1,3)""")
@ -235,7 +233,7 @@ class NamedTupleCursorTest(ConnectingTestCase):
from psycopg2.extras import NamedTupleConnection from psycopg2.extras import NamedTupleConnection
try: try:
from collections import namedtuple from collections import namedtuple # noqa
except ImportError: except ImportError:
return return
@ -346,7 +344,7 @@ class NamedTupleCursorTest(ConnectingTestCase):
def test_error_message(self): def test_error_message(self):
try: try:
from collections import namedtuple from collections import namedtuple # noqa
except ImportError: except ImportError:
# an import error somewhere # an import error somewhere
from psycopg2.extras import NamedTupleConnection from psycopg2.extras import NamedTupleConnection
@ -390,6 +388,7 @@ class NamedTupleCursorTest(ConnectingTestCase):
from psycopg2.extras import NamedTupleCursor from psycopg2.extras import NamedTupleCursor
f_orig = NamedTupleCursor._make_nt f_orig = NamedTupleCursor._make_nt
calls = [0] calls = [0]
def f_patched(self_): def f_patched(self_):
calls[0] += 1 calls[0] += 1
return f_orig(self_) return f_orig(self_)

View File

@ -29,6 +29,7 @@ import psycopg2.extras
from testutils import ConnectingTestCase from testutils import ConnectingTestCase
class ConnectionStub(object): class ConnectionStub(object):
"""A `connection` wrapper allowing analysis of the `poll()` calls.""" """A `connection` wrapper allowing analysis of the `poll()` calls."""
def __init__(self, conn): def __init__(self, conn):
@ -43,6 +44,7 @@ class ConnectionStub(object):
self.polls.append(rv) self.polls.append(rv)
return rv return rv
class GreenTestCase(ConnectingTestCase): class GreenTestCase(ConnectingTestCase):
def setUp(self): def setUp(self):
self._cb = psycopg2.extensions.get_wait_callback() self._cb = psycopg2.extensions.get_wait_callback()
@ -89,7 +91,7 @@ class GreenTestCase(ConnectingTestCase):
curs.fetchone() curs.fetchone()
# now try to do something that will fail in the callback # now try to do something that will fail in the callback
psycopg2.extensions.set_wait_callback(lambda conn: 1//0) psycopg2.extensions.set_wait_callback(lambda conn: 1 // 0)
self.assertRaises(ZeroDivisionError, curs.execute, "select 2") self.assertRaises(ZeroDivisionError, curs.execute, "select 2")
self.assert_(conn.closed) self.assert_(conn.closed)

View File

@ -32,6 +32,7 @@ import psycopg2.extensions
from testutils import unittest, decorate_all_tests, skip_if_tpc_disabled from testutils import unittest, decorate_all_tests, skip_if_tpc_disabled
from testutils import ConnectingTestCase, skip_if_green from testutils import ConnectingTestCase, skip_if_green
def skip_if_no_lo(f): def skip_if_no_lo(f):
@wraps(f) @wraps(f)
def skip_if_no_lo_(self): def skip_if_no_lo_(self):
@ -158,7 +159,7 @@ class LargeObjectTests(LargeObjectTestCase):
def test_read(self): def test_read(self):
lo = self.conn.lobject() lo = self.conn.lobject()
length = lo.write(b"some data") lo.write(b"some data")
lo.close() lo.close()
lo = self.conn.lobject(lo.oid) lo = self.conn.lobject(lo.oid)
@ -169,7 +170,7 @@ class LargeObjectTests(LargeObjectTestCase):
def test_read_binary(self): def test_read_binary(self):
lo = self.conn.lobject() lo = self.conn.lobject()
length = lo.write(b"some data") lo.write(b"some data")
lo.close() lo.close()
lo = self.conn.lobject(lo.oid, "rb") lo = self.conn.lobject(lo.oid, "rb")
@ -181,7 +182,7 @@ class LargeObjectTests(LargeObjectTestCase):
def test_read_text(self): def test_read_text(self):
lo = self.conn.lobject() lo = self.conn.lobject()
snowman = u"\u2603" snowman = u"\u2603"
length = lo.write(u"some data " + snowman) lo.write(u"some data " + snowman)
lo.close() lo.close()
lo = self.conn.lobject(lo.oid, "rt") lo = self.conn.lobject(lo.oid, "rt")
@ -193,7 +194,7 @@ class LargeObjectTests(LargeObjectTestCase):
def test_read_large(self): def test_read_large(self):
lo = self.conn.lobject() lo = self.conn.lobject()
data = "data" * 1000000 data = "data" * 1000000
length = lo.write("some" + data) lo.write("some" + data)
lo.close() lo.close()
lo = self.conn.lobject(lo.oid) lo = self.conn.lobject(lo.oid)
@ -399,6 +400,7 @@ def skip_if_no_truncate(f):
return skip_if_no_truncate_ return skip_if_no_truncate_
class LargeObjectTruncateTests(LargeObjectTestCase): class LargeObjectTruncateTests(LargeObjectTestCase):
def test_truncate(self): def test_truncate(self):
lo = self.conn.lobject() lo = self.conn.lobject()
@ -450,15 +452,19 @@ def _has_lo64(conn):
return (True, "this server and build support the lo64 API") return (True, "this server and build support the lo64 API")
def skip_if_no_lo64(f): def skip_if_no_lo64(f):
@wraps(f) @wraps(f)
def skip_if_no_lo64_(self): def skip_if_no_lo64_(self):
lo64, msg = _has_lo64(self.conn) lo64, msg = _has_lo64(self.conn)
if not lo64: return self.skipTest(msg) if not lo64:
else: return f(self) return self.skipTest(msg)
else:
return f(self)
return skip_if_no_lo64_ return skip_if_no_lo64_
class LargeObject64Tests(LargeObjectTestCase): class LargeObject64Tests(LargeObjectTestCase):
def test_seek_tell_truncate_greater_than_2gb(self): def test_seek_tell_truncate_greater_than_2gb(self):
lo = self.conn.lobject() lo = self.conn.lobject()
@ -477,11 +483,14 @@ def skip_if_lo64(f):
@wraps(f) @wraps(f)
def skip_if_lo64_(self): def skip_if_lo64_(self):
lo64, msg = _has_lo64(self.conn) lo64, msg = _has_lo64(self.conn)
if lo64: return self.skipTest(msg) if lo64:
else: return f(self) return self.skipTest(msg)
else:
return f(self)
return skip_if_lo64_ return skip_if_lo64_
class LargeObjectNot64Tests(LargeObjectTestCase): class LargeObjectNot64Tests(LargeObjectTestCase):
def test_seek_larger_than_2gb(self): def test_seek_larger_than_2gb(self):
lo = self.conn.lobject() lo = self.conn.lobject()

View File

@ -79,7 +79,7 @@ conn.close()
proc = self.notify('foo', 1) proc = self.notify('foo', 1)
t0 = time.time() t0 = time.time()
ready = select.select([self.conn], [], [], 5) select.select([self.conn], [], [], 5)
t1 = time.time() t1 = time.time()
self.assert_(0.99 < t1 - t0 < 4, t1 - t0) self.assert_(0.99 < t1 - t0 < 4, t1 - t0)
@ -217,6 +217,6 @@ conn.close()
def test_suite(): def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__) return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -30,6 +30,7 @@ import psycopg2
from testconfig import dsn from testconfig import dsn
class Psycopg2Tests(dbapi20.DatabaseAPI20Test): class Psycopg2Tests(dbapi20.DatabaseAPI20Test):
driver = psycopg2 driver = psycopg2
connect_args = () connect_args = ()

View File

@ -24,8 +24,8 @@
import psycopg2 import psycopg2
import psycopg2.extensions import psycopg2.extensions
from psycopg2.extras import PhysicalReplicationConnection, LogicalReplicationConnection from psycopg2.extras import (
from psycopg2.extras import StopReplication PhysicalReplicationConnection, LogicalReplicationConnection, StopReplication)
import testconfig import testconfig
from testutils import unittest from testutils import unittest
@ -70,14 +70,16 @@ class ReplicationTestCase(ConnectingTestCase):
# generate some events for our replication stream # generate some events for our replication stream
def make_replication_events(self): def make_replication_events(self):
conn = self.connect() conn = self.connect()
if conn is None: return if conn is None:
return
cur = conn.cursor() cur = conn.cursor()
try: try:
cur.execute("DROP TABLE dummy1") cur.execute("DROP TABLE dummy1")
except psycopg2.ProgrammingError: except psycopg2.ProgrammingError:
conn.rollback() conn.rollback()
cur.execute("CREATE TABLE dummy1 AS SELECT * FROM generate_series(1, 5) AS id") cur.execute(
"CREATE TABLE dummy1 AS SELECT * FROM generate_series(1, 5) AS id")
conn.commit() conn.commit()
@ -85,7 +87,8 @@ class ReplicationTest(ReplicationTestCase):
@skip_before_postgres(9, 0) @skip_before_postgres(9, 0)
def test_physical_replication_connection(self): def test_physical_replication_connection(self):
conn = self.repl_connect(connection_factory=PhysicalReplicationConnection) conn = self.repl_connect(connection_factory=PhysicalReplicationConnection)
if conn is None: return if conn is None:
return
cur = conn.cursor() cur = conn.cursor()
cur.execute("IDENTIFY_SYSTEM") cur.execute("IDENTIFY_SYSTEM")
cur.fetchall() cur.fetchall()
@ -93,7 +96,8 @@ class ReplicationTest(ReplicationTestCase):
@skip_before_postgres(9, 4) @skip_before_postgres(9, 4)
def test_logical_replication_connection(self): def test_logical_replication_connection(self):
conn = self.repl_connect(connection_factory=LogicalReplicationConnection) conn = self.repl_connect(connection_factory=LogicalReplicationConnection)
if conn is None: return if conn is None:
return
cur = conn.cursor() cur = conn.cursor()
cur.execute("IDENTIFY_SYSTEM") cur.execute("IDENTIFY_SYSTEM")
cur.fetchall() cur.fetchall()
@ -101,19 +105,23 @@ class ReplicationTest(ReplicationTestCase):
@skip_before_postgres(9, 4) # slots require 9.4 @skip_before_postgres(9, 4) # slots require 9.4
def test_create_replication_slot(self): def test_create_replication_slot(self):
conn = self.repl_connect(connection_factory=PhysicalReplicationConnection) conn = self.repl_connect(connection_factory=PhysicalReplicationConnection)
if conn is None: return if conn is None:
return
cur = conn.cursor() cur = conn.cursor()
self.create_replication_slot(cur) self.create_replication_slot(cur)
self.assertRaises(psycopg2.ProgrammingError, self.create_replication_slot, cur) self.assertRaises(
psycopg2.ProgrammingError, self.create_replication_slot, cur)
@skip_before_postgres(9, 4) # slots require 9.4 @skip_before_postgres(9, 4) # slots require 9.4
def test_start_on_missing_replication_slot(self): def test_start_on_missing_replication_slot(self):
conn = self.repl_connect(connection_factory=PhysicalReplicationConnection) conn = self.repl_connect(connection_factory=PhysicalReplicationConnection)
if conn is None: return if conn is None:
return
cur = conn.cursor() cur = conn.cursor()
self.assertRaises(psycopg2.ProgrammingError, cur.start_replication, self.slot) self.assertRaises(psycopg2.ProgrammingError,
cur.start_replication, self.slot)
self.create_replication_slot(cur) self.create_replication_slot(cur)
cur.start_replication(self.slot) cur.start_replication(self.slot)
@ -121,13 +129,16 @@ class ReplicationTest(ReplicationTestCase):
@skip_before_postgres(9, 4) # slots require 9.4 @skip_before_postgres(9, 4) # slots require 9.4
def test_start_and_recover_from_error(self): def test_start_and_recover_from_error(self):
conn = self.repl_connect(connection_factory=LogicalReplicationConnection) conn = self.repl_connect(connection_factory=LogicalReplicationConnection)
if conn is None: return if conn is None:
return
cur = conn.cursor() cur = conn.cursor()
self.create_replication_slot(cur, output_plugin='test_decoding') self.create_replication_slot(cur, output_plugin='test_decoding')
# try with invalid options # try with invalid options
cur.start_replication(slot_name=self.slot, options={'invalid_param': 'value'}) cur.start_replication(
slot_name=self.slot, options={'invalid_param': 'value'})
def consume(msg): def consume(msg):
pass pass
# we don't see the error from the server before we try to read the data # we don't see the error from the server before we try to read the data
@ -139,7 +150,8 @@ class ReplicationTest(ReplicationTestCase):
@skip_before_postgres(9, 4) # slots require 9.4 @skip_before_postgres(9, 4) # slots require 9.4
def test_stop_replication(self): def test_stop_replication(self):
conn = self.repl_connect(connection_factory=LogicalReplicationConnection) conn = self.repl_connect(connection_factory=LogicalReplicationConnection)
if conn is None: return if conn is None:
return
cur = conn.cursor() cur = conn.cursor()
self.create_replication_slot(cur, output_plugin='test_decoding') self.create_replication_slot(cur, output_plugin='test_decoding')
@ -147,6 +159,7 @@ class ReplicationTest(ReplicationTestCase):
self.make_replication_events() self.make_replication_events()
cur.start_replication(self.slot) cur.start_replication(self.slot)
def consume(msg): def consume(msg):
raise StopReplication() raise StopReplication()
self.assertRaises(StopReplication, cur.consume_stream, consume) self.assertRaises(StopReplication, cur.consume_stream, consume)
@ -155,8 +168,10 @@ class ReplicationTest(ReplicationTestCase):
class AsyncReplicationTest(ReplicationTestCase): class AsyncReplicationTest(ReplicationTestCase):
@skip_before_postgres(9, 4) # slots require 9.4 @skip_before_postgres(9, 4) # slots require 9.4
def test_async_replication(self): def test_async_replication(self):
conn = self.repl_connect(connection_factory=LogicalReplicationConnection, async=1) conn = self.repl_connect(
if conn is None: return connection_factory=LogicalReplicationConnection, async=1)
if conn is None:
return
self.wait(conn) self.wait(conn)
cur = conn.cursor() cur = conn.cursor()
@ -169,9 +184,10 @@ class AsyncReplicationTest(ReplicationTestCase):
self.make_replication_events() self.make_replication_events()
self.msg_count = 0 self.msg_count = 0
def consume(msg): def consume(msg):
# just check the methods # just check the methods
log = "%s: %s" % (cur.io_timestamp, repr(msg)) "%s: %s" % (cur.io_timestamp, repr(msg))
self.msg_count += 1 self.msg_count += 1
if self.msg_count > 3: if self.msg_count > 3:
@ -193,8 +209,10 @@ class AsyncReplicationTest(ReplicationTestCase):
select([cur], [], []) select([cur], [], [])
self.assertRaises(StopReplication, process_stream) self.assertRaises(StopReplication, process_stream)
def test_suite(): def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__) return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -29,6 +29,7 @@ import psycopg2
from psycopg2.extensions import ( from psycopg2.extensions import (
ISOLATION_LEVEL_SERIALIZABLE, STATUS_BEGIN, STATUS_READY) ISOLATION_LEVEL_SERIALIZABLE, STATUS_BEGIN, STATUS_READY)
class TransactionTests(ConnectingTestCase): class TransactionTests(ConnectingTestCase):
def setUp(self): def setUp(self):
@ -147,6 +148,7 @@ class DeadlockSerializationTests(ConnectingTestCase):
self.thread1_error = exc self.thread1_error = exc
step1.set() step1.set()
conn.close() conn.close()
def task2(): def task2():
try: try:
conn = self.connect() conn = self.connect()
@ -195,6 +197,7 @@ class DeadlockSerializationTests(ConnectingTestCase):
self.thread1_error = exc self.thread1_error = exc
step1.set() step1.set()
conn.close() conn.close()
def task2(): def task2():
try: try:
conn = self.connect() conn = self.connect()

View File

@ -68,13 +68,16 @@ class TypesBasicTests(ConnectingTestCase):
"wrong decimal quoting: " + str(s)) "wrong decimal quoting: " + str(s))
s = self.execute("SELECT %s AS foo", (decimal.Decimal("NaN"),)) s = self.execute("SELECT %s AS foo", (decimal.Decimal("NaN"),))
self.failUnless(str(s) == "NaN", "wrong decimal quoting: " + str(s)) self.failUnless(str(s) == "NaN", "wrong decimal quoting: " + str(s))
self.failUnless(type(s) == decimal.Decimal, "wrong decimal conversion: " + repr(s)) self.failUnless(type(s) == decimal.Decimal,
"wrong decimal conversion: " + repr(s))
s = self.execute("SELECT %s AS foo", (decimal.Decimal("infinity"),)) s = self.execute("SELECT %s AS foo", (decimal.Decimal("infinity"),))
self.failUnless(str(s) == "NaN", "wrong decimal quoting: " + str(s)) self.failUnless(str(s) == "NaN", "wrong decimal quoting: " + str(s))
self.failUnless(type(s) == decimal.Decimal, "wrong decimal conversion: " + repr(s)) self.failUnless(type(s) == decimal.Decimal,
"wrong decimal conversion: " + repr(s))
s = self.execute("SELECT %s AS foo", (decimal.Decimal("-infinity"),)) s = self.execute("SELECT %s AS foo", (decimal.Decimal("-infinity"),))
self.failUnless(str(s) == "NaN", "wrong decimal quoting: " + str(s)) self.failUnless(str(s) == "NaN", "wrong decimal quoting: " + str(s))
self.failUnless(type(s) == decimal.Decimal, "wrong decimal conversion: " + repr(s)) self.failUnless(type(s) == decimal.Decimal,
"wrong decimal conversion: " + repr(s))
def testFloatNan(self): def testFloatNan(self):
try: try:
@ -141,8 +144,8 @@ class TypesBasicTests(ConnectingTestCase):
self.assertEqual(s, buf2.tobytes()) self.assertEqual(s, buf2.tobytes())
def testArray(self): def testArray(self):
s = self.execute("SELECT %s AS foo", ([[1,2],[3,4]],)) s = self.execute("SELECT %s AS foo", ([[1, 2], [3, 4]],))
self.failUnlessEqual(s, [[1,2],[3,4]]) self.failUnlessEqual(s, [[1, 2], [3, 4]])
s = self.execute("SELECT %s AS foo", (['one', 'two', 'three'],)) s = self.execute("SELECT %s AS foo", (['one', 'two', 'three'],))
self.failUnlessEqual(s, ['one', 'two', 'three']) self.failUnlessEqual(s, ['one', 'two', 'three'])
@ -150,9 +153,12 @@ class TypesBasicTests(ConnectingTestCase):
# ticket #42 # ticket #42
import datetime import datetime
curs = self.conn.cursor() curs = self.conn.cursor()
curs.execute("create table array_test (id integer, col timestamp without time zone[])") curs.execute(
"create table array_test "
"(id integer, col timestamp without time zone[])")
curs.execute("insert into array_test values (%s, %s)", (1, [datetime.date(2011,2,14)])) curs.execute("insert into array_test values (%s, %s)",
(1, [datetime.date(2011, 2, 14)]))
curs.execute("select col from array_test where id = 1") curs.execute("select col from array_test where id = 1")
self.assertEqual(curs.fetchone()[0], [datetime.datetime(2011, 2, 14, 0, 0)]) self.assertEqual(curs.fetchone()[0], [datetime.datetime(2011, 2, 14, 0, 0)])
@ -323,30 +329,33 @@ class TypesBasicTests(ConnectingTestCase):
self.assertEqual(1, l1) self.assertEqual(1, l1)
def testGenericArray(self): def testGenericArray(self):
a = self.execute("select '{1,2,3}'::int4[]") a = self.execute("select '{1, 2, 3}'::int4[]")
self.assertEqual(a, [1,2,3]) self.assertEqual(a, [1, 2, 3])
a = self.execute("select array['a','b','''']::text[]") a = self.execute("select array['a', 'b', '''']::text[]")
self.assertEqual(a, ['a','b',"'"]) self.assertEqual(a, ['a', 'b', "'"])
@testutils.skip_before_postgres(8, 2) @testutils.skip_before_postgres(8, 2)
def testGenericArrayNull(self): def testGenericArrayNull(self):
def caster(s, cur): def caster(s, cur):
if s is None: return "nada" if s is None:
return "nada"
return int(s) * 2 return int(s) * 2
base = psycopg2.extensions.new_type((23,), "INT4", caster) base = psycopg2.extensions.new_type((23,), "INT4", caster)
array = psycopg2.extensions.new_array_type((1007,), "INT4ARRAY", base) array = psycopg2.extensions.new_array_type((1007,), "INT4ARRAY", base)
psycopg2.extensions.register_type(array, self.conn) psycopg2.extensions.register_type(array, self.conn)
a = self.execute("select '{1,2,3}'::int4[]") a = self.execute("select '{1, 2, 3}'::int4[]")
self.assertEqual(a, [2,4,6]) self.assertEqual(a, [2, 4, 6])
a = self.execute("select '{1,2,NULL}'::int4[]") a = self.execute("select '{1, 2, NULL}'::int4[]")
self.assertEqual(a, [2,4,'nada']) self.assertEqual(a, [2, 4, 'nada'])
class AdaptSubclassTest(unittest.TestCase): class AdaptSubclassTest(unittest.TestCase):
def test_adapt_subtype(self): def test_adapt_subtype(self):
from psycopg2.extensions import adapt from psycopg2.extensions import adapt
class Sub(str): pass
class Sub(str):
pass
s1 = "hel'lo" s1 = "hel'lo"
s2 = Sub(s1) s2 = Sub(s1)
self.assertEqual(adapt(s1).getquoted(), adapt(s2).getquoted()) self.assertEqual(adapt(s1).getquoted(), adapt(s2).getquoted())
@ -354,9 +363,14 @@ class AdaptSubclassTest(unittest.TestCase):
def test_adapt_most_specific(self): def test_adapt_most_specific(self):
from psycopg2.extensions import adapt, register_adapter, AsIs from psycopg2.extensions import adapt, register_adapter, AsIs
class A(object): pass class A(object):
class B(A): pass pass
class C(B): pass
class B(A):
pass
class C(B):
pass
register_adapter(A, lambda a: AsIs("a")) register_adapter(A, lambda a: AsIs("a"))
register_adapter(B, lambda b: AsIs("b")) register_adapter(B, lambda b: AsIs("b"))
@ -370,8 +384,11 @@ class AdaptSubclassTest(unittest.TestCase):
def test_no_mro_no_joy(self): def test_no_mro_no_joy(self):
from psycopg2.extensions import adapt, register_adapter, AsIs from psycopg2.extensions import adapt, register_adapter, AsIs
class A: pass class A:
class B(A): pass pass
class B(A):
pass
register_adapter(A, lambda a: AsIs("a")) register_adapter(A, lambda a: AsIs("a"))
try: try:
@ -383,8 +400,11 @@ class AdaptSubclassTest(unittest.TestCase):
def test_adapt_subtype_3(self): def test_adapt_subtype_3(self):
from psycopg2.extensions import adapt, register_adapter, AsIs from psycopg2.extensions import adapt, register_adapter, AsIs
class A: pass class A:
class B(A): pass pass
class B(A):
pass
register_adapter(A, lambda a: AsIs("a")) register_adapter(A, lambda a: AsIs("a"))
try: try:
@ -443,7 +463,8 @@ class ByteaParserTest(unittest.TestCase):
def test_full_hex(self, upper=False): def test_full_hex(self, upper=False):
buf = ''.join(("%02x" % i) for i in range(256)) buf = ''.join(("%02x" % i) for i in range(256))
if upper: buf = buf.upper() if upper:
buf = buf.upper()
buf = '\\x' + buf buf = '\\x' + buf
rv = self.cast(buf.encode('utf8')) rv = self.cast(buf.encode('utf8'))
if sys.version_info[0] < 3: if sys.version_info[0] < 3:

View File

@ -37,6 +37,7 @@ def filter_scs(conn, s):
else: else:
return s.replace(b"E'", b"'") return s.replace(b"E'", b"'")
class TypesExtrasTests(ConnectingTestCase): class TypesExtrasTests(ConnectingTestCase):
"""Test that all type conversions are working.""" """Test that all type conversions are working."""
@ -60,7 +61,8 @@ class TypesExtrasTests(ConnectingTestCase):
def testUUIDARRAY(self): def testUUIDARRAY(self):
import uuid import uuid
psycopg2.extras.register_uuid() psycopg2.extras.register_uuid()
u = [uuid.UUID('9c6d5a77-7256-457e-9461-347b4358e350'), uuid.UUID('9c6d5a77-7256-457e-9461-347b4358e352')] u = [uuid.UUID('9c6d5a77-7256-457e-9461-347b4358e350'),
uuid.UUID('9c6d5a77-7256-457e-9461-347b4358e352')]
s = self.execute("SELECT %s AS foo", (u,)) s = self.execute("SELECT %s AS foo", (u,))
self.failUnless(u == s) self.failUnless(u == s)
# array with a NULL element # array with a NULL element
@ -110,7 +112,8 @@ class TypesExtrasTests(ConnectingTestCase):
a.getquoted()) a.getquoted())
def test_adapt_fail(self): def test_adapt_fail(self):
class Foo(object): pass class Foo(object):
pass
self.assertRaises(psycopg2.ProgrammingError, self.assertRaises(psycopg2.ProgrammingError,
psycopg2.extensions.adapt, Foo(), ext.ISQLQuote, None) psycopg2.extensions.adapt, Foo(), ext.ISQLQuote, None)
try: try:
@ -130,6 +133,7 @@ def skip_if_no_hstore(f):
return skip_if_no_hstore_ return skip_if_no_hstore_
class HstoreTestCase(ConnectingTestCase): class HstoreTestCase(ConnectingTestCase):
def test_adapt_8(self): def test_adapt_8(self):
if self.conn.server_version >= 90000: if self.conn.server_version >= 90000:
@ -155,7 +159,8 @@ class HstoreTestCase(ConnectingTestCase):
self.assertEqual(ii[2], filter_scs(self.conn, b"(E'c' => NULL)")) self.assertEqual(ii[2], filter_scs(self.conn, b"(E'c' => NULL)"))
if 'd' in o: if 'd' in o:
encc = u'\xe0'.encode(psycopg2.extensions.encodings[self.conn.encoding]) encc = u'\xe0'.encode(psycopg2.extensions.encodings[self.conn.encoding])
self.assertEqual(ii[3], filter_scs(self.conn, b"(E'd' => E'" + encc + b"')")) self.assertEqual(ii[3],
filter_scs(self.conn, b"(E'd' => E'" + encc + b"')"))
def test_adapt_9(self): def test_adapt_9(self):
if self.conn.server_version < 90000: if self.conn.server_version < 90000:
@ -199,7 +204,7 @@ class HstoreTestCase(ConnectingTestCase):
ok(None, None) ok(None, None)
ok('', {}) ok('', {})
ok('"a"=>"1", "b"=>"2"', {'a': '1', 'b': '2'}) ok('"a"=>"1", "b"=>"2"', {'a': '1', 'b': '2'})
ok('"a" => "1" ,"b" => "2"', {'a': '1', 'b': '2'}) ok('"a" => "1" , "b" => "2"', {'a': '1', 'b': '2'})
ok('"a"=>NULL, "b"=>"2"', {'a': None, 'b': '2'}) ok('"a"=>NULL, "b"=>"2"', {'a': None, 'b': '2'})
ok(r'"a"=>"\"", "\""=>"2"', {'a': '"', '"': '2'}) ok(r'"a"=>"\"", "\""=>"2"', {'a': '"', '"': '2'})
ok('"a"=>"\'", "\'"=>"2"', {'a': "'", "'": '2'}) ok('"a"=>"\'", "\'"=>"2"', {'a': "'", "'": '2'})
@ -402,7 +407,9 @@ class HstoreTestCase(ConnectingTestCase):
from psycopg2.extras import register_hstore from psycopg2.extras import register_hstore
register_hstore(None, globally=True, oid=oid, array_oid=aoid) register_hstore(None, globally=True, oid=oid, array_oid=aoid)
try: try:
cur.execute("select null::hstore, ''::hstore, 'a => b'::hstore, '{a=>b}'::hstore[]") cur.execute("""
select null::hstore, ''::hstore,
'a => b'::hstore, '{a=>b}'::hstore[]""")
t = cur.fetchone() t = cur.fetchone()
self.assert_(t[0] is None) self.assert_(t[0] is None)
self.assertEqual(t[1], {}) self.assertEqual(t[1], {})
@ -449,6 +456,7 @@ def skip_if_no_composite(f):
return skip_if_no_composite_ return skip_if_no_composite_
class AdaptTypeTestCase(ConnectingTestCase): class AdaptTypeTestCase(ConnectingTestCase):
@skip_if_no_composite @skip_if_no_composite
def test_none_in_record(self): def test_none_in_record(self):
@ -463,8 +471,11 @@ class AdaptTypeTestCase(ConnectingTestCase):
# the None adapter is not actually invoked in regular adaptation # the None adapter is not actually invoked in regular adaptation
class WonkyAdapter(object): class WonkyAdapter(object):
def __init__(self, obj): pass def __init__(self, obj):
def getquoted(self): return "NOPE!" pass
def getquoted(self):
return "NOPE!"
curs = self.conn.cursor() curs = self.conn.cursor()
@ -481,6 +492,7 @@ class AdaptTypeTestCase(ConnectingTestCase):
def test_tokenization(self): def test_tokenization(self):
from psycopg2.extras import CompositeCaster from psycopg2.extras import CompositeCaster
def ok(s, v): def ok(s, v):
self.assertEqual(CompositeCaster.tokenize(s), v) self.assertEqual(CompositeCaster.tokenize(s), v)
@ -519,26 +531,26 @@ class AdaptTypeTestCase(ConnectingTestCase):
self.assertEqual(t.oid, oid) self.assertEqual(t.oid, oid)
self.assert_(issubclass(t.type, tuple)) self.assert_(issubclass(t.type, tuple))
self.assertEqual(t.attnames, ['anint', 'astring', 'adate']) self.assertEqual(t.attnames, ['anint', 'astring', 'adate'])
self.assertEqual(t.atttypes, [23,25,1082]) self.assertEqual(t.atttypes, [23, 25, 1082])
curs = self.conn.cursor() curs = self.conn.cursor()
r = (10, 'hello', date(2011,1,2)) r = (10, 'hello', date(2011, 1, 2))
curs.execute("select %s::type_isd;", (r,)) curs.execute("select %s::type_isd;", (r,))
v = curs.fetchone()[0] v = curs.fetchone()[0]
self.assert_(isinstance(v, t.type)) self.assert_(isinstance(v, t.type))
self.assertEqual(v[0], 10) self.assertEqual(v[0], 10)
self.assertEqual(v[1], "hello") self.assertEqual(v[1], "hello")
self.assertEqual(v[2], date(2011,1,2)) self.assertEqual(v[2], date(2011, 1, 2))
try: try:
from collections import namedtuple from collections import namedtuple # noqa
except ImportError: except ImportError:
pass pass
else: else:
self.assert_(t.type is not tuple) self.assert_(t.type is not tuple)
self.assertEqual(v.anint, 10) self.assertEqual(v.anint, 10)
self.assertEqual(v.astring, "hello") self.assertEqual(v.astring, "hello")
self.assertEqual(v.adate, date(2011,1,2)) self.assertEqual(v.adate, date(2011, 1, 2))
@skip_if_no_composite @skip_if_no_composite
def test_empty_string(self): def test_empty_string(self):
@ -574,14 +586,14 @@ class AdaptTypeTestCase(ConnectingTestCase):
psycopg2.extras.register_composite("type_r_ft", self.conn) psycopg2.extras.register_composite("type_r_ft", self.conn)
curs = self.conn.cursor() curs = self.conn.cursor()
r = (0.25, (date(2011,1,2), (42, "hello"))) r = (0.25, (date(2011, 1, 2), (42, "hello")))
curs.execute("select %s::type_r_ft;", (r,)) curs.execute("select %s::type_r_ft;", (r,))
v = curs.fetchone()[0] v = curs.fetchone()[0]
self.assertEqual(r, v) self.assertEqual(r, v)
try: try:
from collections import namedtuple from collections import namedtuple # noqa
except ImportError: except ImportError:
pass pass
else: else:
@ -595,7 +607,7 @@ class AdaptTypeTestCase(ConnectingTestCase):
curs2 = self.conn.cursor() curs2 = self.conn.cursor()
psycopg2.extras.register_composite("type_ii", curs1) psycopg2.extras.register_composite("type_ii", curs1)
curs1.execute("select (1,2)::type_ii") curs1.execute("select (1,2)::type_ii")
self.assertEqual(curs1.fetchone()[0], (1,2)) self.assertEqual(curs1.fetchone()[0], (1, 2))
curs2.execute("select (1,2)::type_ii") curs2.execute("select (1,2)::type_ii")
self.assertEqual(curs2.fetchone()[0], "(1,2)") self.assertEqual(curs2.fetchone()[0], "(1,2)")
@ -610,7 +622,7 @@ class AdaptTypeTestCase(ConnectingTestCase):
curs1 = conn1.cursor() curs1 = conn1.cursor()
curs2 = conn2.cursor() curs2 = conn2.cursor()
curs1.execute("select (1,2)::type_ii") curs1.execute("select (1,2)::type_ii")
self.assertEqual(curs1.fetchone()[0], (1,2)) self.assertEqual(curs1.fetchone()[0], (1, 2))
curs2.execute("select (1,2)::type_ii") curs2.execute("select (1,2)::type_ii")
self.assertEqual(curs2.fetchone()[0], "(1,2)") self.assertEqual(curs2.fetchone()[0], "(1,2)")
finally: finally:
@ -629,9 +641,9 @@ class AdaptTypeTestCase(ConnectingTestCase):
curs1 = conn1.cursor() curs1 = conn1.cursor()
curs2 = conn2.cursor() curs2 = conn2.cursor()
curs1.execute("select (1,2)::type_ii") curs1.execute("select (1,2)::type_ii")
self.assertEqual(curs1.fetchone()[0], (1,2)) self.assertEqual(curs1.fetchone()[0], (1, 2))
curs2.execute("select (1,2)::type_ii") curs2.execute("select (1,2)::type_ii")
self.assertEqual(curs2.fetchone()[0], (1,2)) self.assertEqual(curs2.fetchone()[0], (1, 2))
finally: finally:
# drop the registered typecasters to help the refcounting # drop the registered typecasters to help the refcounting
# script to return precise values. # script to return precise values.
@ -661,30 +673,30 @@ class AdaptTypeTestCase(ConnectingTestCase):
"typens.typens_ii", self.conn) "typens.typens_ii", self.conn)
self.assertEqual(t.schema, 'typens') self.assertEqual(t.schema, 'typens')
curs.execute("select (4,8)::typens.typens_ii") curs.execute("select (4,8)::typens.typens_ii")
self.assertEqual(curs.fetchone()[0], (4,8)) self.assertEqual(curs.fetchone()[0], (4, 8))
@skip_if_no_composite @skip_if_no_composite
@skip_before_postgres(8, 4) @skip_before_postgres(8, 4)
def test_composite_array(self): def test_composite_array(self):
oid = self._create_type("type_isd", self._create_type("type_isd",
[('anint', 'integer'), ('astring', 'text'), ('adate', 'date')]) [('anint', 'integer'), ('astring', 'text'), ('adate', 'date')])
t = psycopg2.extras.register_composite("type_isd", self.conn) t = psycopg2.extras.register_composite("type_isd", self.conn)
curs = self.conn.cursor() curs = self.conn.cursor()
r1 = (10, 'hello', date(2011,1,2)) r1 = (10, 'hello', date(2011, 1, 2))
r2 = (20, 'world', date(2011,1,3)) r2 = (20, 'world', date(2011, 1, 3))
curs.execute("select %s::type_isd[];", ([r1, r2],)) curs.execute("select %s::type_isd[];", ([r1, r2],))
v = curs.fetchone()[0] v = curs.fetchone()[0]
self.assertEqual(len(v), 2) self.assertEqual(len(v), 2)
self.assert_(isinstance(v[0], t.type)) self.assert_(isinstance(v[0], t.type))
self.assertEqual(v[0][0], 10) self.assertEqual(v[0][0], 10)
self.assertEqual(v[0][1], "hello") self.assertEqual(v[0][1], "hello")
self.assertEqual(v[0][2], date(2011,1,2)) self.assertEqual(v[0][2], date(2011, 1, 2))
self.assert_(isinstance(v[1], t.type)) self.assert_(isinstance(v[1], t.type))
self.assertEqual(v[1][0], 20) self.assertEqual(v[1][0], 20)
self.assertEqual(v[1][1], "world") self.assertEqual(v[1][1], "world")
self.assertEqual(v[1][2], date(2011,1,3)) self.assertEqual(v[1][2], date(2011, 1, 3))
@skip_if_no_composite @skip_if_no_composite
def test_wrong_schema(self): def test_wrong_schema(self):
@ -752,7 +764,7 @@ class AdaptTypeTestCase(ConnectingTestCase):
register_composite('type_ii', conn) register_composite('type_ii', conn)
curs = conn.cursor() curs = conn.cursor()
curs.execute("select '(1,2)'::type_ii as x") curs.execute("select '(1,2)'::type_ii as x")
self.assertEqual(curs.fetchone()['x'], (1,2)) self.assertEqual(curs.fetchone()['x'], (1, 2))
finally: finally:
conn.close() conn.close()
@ -761,7 +773,7 @@ class AdaptTypeTestCase(ConnectingTestCase):
curs = conn.cursor() curs = conn.cursor()
register_composite('type_ii', conn) register_composite('type_ii', conn)
curs.execute("select '(1,2)'::type_ii as x") curs.execute("select '(1,2)'::type_ii as x")
self.assertEqual(curs.fetchone()['x'], (1,2)) self.assertEqual(curs.fetchone()['x'], (1, 2))
finally: finally:
conn.close() conn.close()
@ -782,13 +794,13 @@ class AdaptTypeTestCase(ConnectingTestCase):
self.assertEqual(t.oid, oid) self.assertEqual(t.oid, oid)
curs = self.conn.cursor() curs = self.conn.cursor()
r = (10, 'hello', date(2011,1,2)) r = (10, 'hello', date(2011, 1, 2))
curs.execute("select %s::type_isd;", (r,)) curs.execute("select %s::type_isd;", (r,))
v = curs.fetchone()[0] v = curs.fetchone()[0]
self.assert_(isinstance(v, dict)) self.assert_(isinstance(v, dict))
self.assertEqual(v['anint'], 10) self.assertEqual(v['anint'], 10)
self.assertEqual(v['astring'], "hello") self.assertEqual(v['astring'], "hello")
self.assertEqual(v['adate'], date(2011,1,2)) self.assertEqual(v['adate'], date(2011, 1, 2))
def _create_type(self, name, fields): def _create_type(self, name, fields):
curs = self.conn.cursor() curs = self.conn.cursor()
@ -825,6 +837,7 @@ def skip_if_json_module(f):
return skip_if_json_module_ return skip_if_json_module_
def skip_if_no_json_module(f): def skip_if_no_json_module(f):
"""Skip a test if no Python json module is available""" """Skip a test if no Python json module is available"""
@wraps(f) @wraps(f)
@ -836,6 +849,7 @@ def skip_if_no_json_module(f):
return skip_if_no_json_module_ return skip_if_no_json_module_
def skip_if_no_json_type(f): def skip_if_no_json_type(f):
"""Skip a test if PostgreSQL json type is not available""" """Skip a test if PostgreSQL json type is not available"""
@wraps(f) @wraps(f)
@ -849,6 +863,7 @@ def skip_if_no_json_type(f):
return skip_if_no_json_type_ return skip_if_no_json_type_
class JsonTestCase(ConnectingTestCase): class JsonTestCase(ConnectingTestCase):
@skip_if_json_module @skip_if_json_module
def test_module_not_available(self): def test_module_not_available(self):
@ -858,6 +873,7 @@ class JsonTestCase(ConnectingTestCase):
@skip_if_json_module @skip_if_json_module
def test_customizable_with_module_not_available(self): def test_customizable_with_module_not_available(self):
from psycopg2.extras import Json from psycopg2.extras import Json
class MyJson(Json): class MyJson(Json):
def dumps(self, obj): def dumps(self, obj):
assert obj is None assert obj is None
@ -870,7 +886,7 @@ class JsonTestCase(ConnectingTestCase):
from psycopg2.extras import json, Json from psycopg2.extras import json, Json
objs = [None, "te'xt", 123, 123.45, objs = [None, "te'xt", 123, 123.45,
u'\xe0\u20ac', ['a', 100], {'a': 100} ] u'\xe0\u20ac', ['a', 100], {'a': 100}]
curs = self.conn.cursor() curs = self.conn.cursor()
for obj in enumerate(objs): for obj in enumerate(objs):
@ -889,7 +905,9 @@ class JsonTestCase(ConnectingTestCase):
curs = self.conn.cursor() curs = self.conn.cursor()
obj = Decimal('123.45') obj = Decimal('123.45')
dumps = lambda obj: json.dumps(obj, cls=DecimalEncoder)
def dumps(obj):
return json.dumps(obj, cls=DecimalEncoder)
self.assertEqual(curs.mogrify("%s", (Json(obj, dumps=dumps),)), self.assertEqual(curs.mogrify("%s", (Json(obj, dumps=dumps),)),
b"'123.45'") b"'123.45'")
@ -923,7 +941,6 @@ class JsonTestCase(ConnectingTestCase):
finally: finally:
del psycopg2.extensions.adapters[dict, ext.ISQLQuote] del psycopg2.extensions.adapters[dict, ext.ISQLQuote]
def test_type_not_available(self): def test_type_not_available(self):
curs = self.conn.cursor() curs = self.conn.cursor()
curs.execute("select oid from pg_type where typname = 'json'") curs.execute("select oid from pg_type where typname = 'json'")
@ -982,7 +999,9 @@ class JsonTestCase(ConnectingTestCase):
@skip_if_no_json_type @skip_if_no_json_type
def test_loads(self): def test_loads(self):
json = psycopg2.extras.json json = psycopg2.extras.json
loads = lambda x: json.loads(x, parse_float=Decimal)
def loads(s):
return json.loads(s, parse_float=Decimal)
psycopg2.extras.register_json(self.conn, loads=loads) psycopg2.extras.register_json(self.conn, loads=loads)
curs = self.conn.cursor() curs = self.conn.cursor()
curs.execute("""select '{"a": 100.0, "b": null}'::json""") curs.execute("""select '{"a": 100.0, "b": null}'::json""")
@ -998,7 +1017,9 @@ class JsonTestCase(ConnectingTestCase):
old = psycopg2.extensions.string_types.get(114) old = psycopg2.extensions.string_types.get(114)
olda = psycopg2.extensions.string_types.get(199) olda = psycopg2.extensions.string_types.get(199)
loads = lambda x: psycopg2.extras.json.loads(x, parse_float=Decimal)
def loads(s):
return psycopg2.extras.json.loads(s, parse_float=Decimal)
try: try:
new, newa = psycopg2.extras.register_json( new, newa = psycopg2.extras.register_json(
loads=loads, oid=oid, array_oid=array_oid) loads=loads, oid=oid, array_oid=array_oid)
@ -1020,7 +1041,8 @@ class JsonTestCase(ConnectingTestCase):
def test_register_default(self): def test_register_default(self):
curs = self.conn.cursor() curs = self.conn.cursor()
loads = lambda x: psycopg2.extras.json.loads(x, parse_float=Decimal) def loads(s):
return psycopg2.extras.json.loads(s, parse_float=Decimal)
psycopg2.extras.register_default_json(curs, loads=loads) psycopg2.extras.register_default_json(curs, loads=loads)
curs.execute("""select '{"a": 100.0, "b": null}'::json""") curs.execute("""select '{"a": 100.0, "b": null}'::json""")
@ -1070,6 +1092,7 @@ class JsonTestCase(ConnectingTestCase):
def skip_if_no_jsonb_type(f): def skip_if_no_jsonb_type(f):
return skip_before_postgres(9, 4)(f) return skip_before_postgres(9, 4)(f)
class JsonbTestCase(ConnectingTestCase): class JsonbTestCase(ConnectingTestCase):
@staticmethod @staticmethod
def myloads(s): def myloads(s):
@ -1118,7 +1141,10 @@ class JsonbTestCase(ConnectingTestCase):
def test_loads(self): def test_loads(self):
json = psycopg2.extras.json json = psycopg2.extras.json
loads = lambda x: json.loads(x, parse_float=Decimal)
def loads(s):
return json.loads(s, parse_float=Decimal)
psycopg2.extras.register_json(self.conn, loads=loads, name='jsonb') psycopg2.extras.register_json(self.conn, loads=loads, name='jsonb')
curs = self.conn.cursor() curs = self.conn.cursor()
curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""")
@ -1134,7 +1160,9 @@ class JsonbTestCase(ConnectingTestCase):
def test_register_default(self): def test_register_default(self):
curs = self.conn.cursor() curs = self.conn.cursor()
loads = lambda x: psycopg2.extras.json.loads(x, parse_float=Decimal) def loads(s):
return psycopg2.extras.json.loads(s, parse_float=Decimal)
psycopg2.extras.register_default_jsonb(curs, loads=loads) psycopg2.extras.register_default_jsonb(curs, loads=loads)
curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""") curs.execute("""select '{"a": 100.0, "b": null}'::jsonb""")
@ -1200,7 +1228,7 @@ class RangeTestCase(unittest.TestCase):
('[)', True, False), ('[)', True, False),
('(]', False, True), ('(]', False, True),
('()', False, False), ('()', False, False),
('[]', True, True),]: ('[]', True, True)]:
r = Range(10, 20, bounds) r = Range(10, 20, bounds)
self.assertEqual(r.lower, 10) self.assertEqual(r.lower, 10)
self.assertEqual(r.upper, 20) self.assertEqual(r.upper, 20)
@ -1294,11 +1322,11 @@ class RangeTestCase(unittest.TestCase):
self.assert_(not Range(empty=True)) self.assert_(not Range(empty=True))
def test_eq_hash(self): def test_eq_hash(self):
from psycopg2.extras import Range
def assert_equal(r1, r2): def assert_equal(r1, r2):
self.assert_(r1 == r2) self.assert_(r1 == r2)
self.assert_(hash(r1) == hash(r2)) self.assert_(hash(r1) == hash(r2))
from psycopg2.extras import Range
assert_equal(Range(empty=True), Range(empty=True)) assert_equal(Range(empty=True), Range(empty=True))
assert_equal(Range(), Range()) assert_equal(Range(), Range())
assert_equal(Range(10, None), Range(10, None)) assert_equal(Range(10, None), Range(10, None))
@ -1321,8 +1349,11 @@ class RangeTestCase(unittest.TestCase):
def test_eq_subclass(self): def test_eq_subclass(self):
from psycopg2.extras import Range, NumericRange from psycopg2.extras import Range, NumericRange
class IntRange(NumericRange): pass class IntRange(NumericRange):
class PositiveIntRange(IntRange): pass pass
class PositiveIntRange(IntRange):
pass
self.assertEqual(Range(10, 20), IntRange(10, 20)) self.assertEqual(Range(10, 20), IntRange(10, 20))
self.assertEqual(PositiveIntRange(10, 20), IntRange(10, 20)) self.assertEqual(PositiveIntRange(10, 20), IntRange(10, 20))
@ -1480,8 +1511,8 @@ class RangeCasterTestCase(ConnectingTestCase):
r = cur.fetchone()[0] r = cur.fetchone()[0]
self.assert_(isinstance(r, DateRange)) self.assert_(isinstance(r, DateRange))
self.assert_(not r.isempty) self.assert_(not r.isempty)
self.assertEqual(r.lower, date(2000,1,2)) self.assertEqual(r.lower, date(2000, 1, 2))
self.assertEqual(r.upper, date(2012,12,31)) self.assertEqual(r.upper, date(2012, 12, 31))
self.assert_(not r.lower_inf) self.assert_(not r.lower_inf)
self.assert_(not r.upper_inf) self.assert_(not r.upper_inf)
self.assert_(r.lower_inc) self.assert_(r.lower_inc)
@ -1490,8 +1521,8 @@ class RangeCasterTestCase(ConnectingTestCase):
def test_cast_timestamp(self): def test_cast_timestamp(self):
from psycopg2.extras import DateTimeRange from psycopg2.extras import DateTimeRange
cur = self.conn.cursor() cur = self.conn.cursor()
ts1 = datetime(2000,1,1) ts1 = datetime(2000, 1, 1)
ts2 = datetime(2000,12,31,23,59,59,999) ts2 = datetime(2000, 12, 31, 23, 59, 59, 999)
cur.execute("select tsrange(%s, %s, '()')", (ts1, ts2)) cur.execute("select tsrange(%s, %s, '()')", (ts1, ts2))
r = cur.fetchone()[0] r = cur.fetchone()[0]
self.assert_(isinstance(r, DateTimeRange)) self.assert_(isinstance(r, DateTimeRange))
@ -1507,8 +1538,9 @@ class RangeCasterTestCase(ConnectingTestCase):
from psycopg2.extras import DateTimeTZRange from psycopg2.extras import DateTimeTZRange
from psycopg2.tz import FixedOffsetTimezone from psycopg2.tz import FixedOffsetTimezone
cur = self.conn.cursor() cur = self.conn.cursor()
ts1 = datetime(2000,1,1, tzinfo=FixedOffsetTimezone(600)) ts1 = datetime(2000, 1, 1, tzinfo=FixedOffsetTimezone(600))
ts2 = datetime(2000,12,31,23,59,59,999, tzinfo=FixedOffsetTimezone(600)) ts2 = datetime(2000, 12, 31, 23, 59, 59, 999,
tzinfo=FixedOffsetTimezone(600))
cur.execute("select tstzrange(%s, %s, '[]')", (ts1, ts2)) cur.execute("select tstzrange(%s, %s, '[]')", (ts1, ts2))
r = cur.fetchone()[0] r = cur.fetchone()[0]
self.assert_(isinstance(r, DateTimeTZRange)) self.assert_(isinstance(r, DateTimeTZRange))
@ -1598,8 +1630,9 @@ class RangeCasterTestCase(ConnectingTestCase):
self.assert_(isinstance(r1, DateTimeRange)) self.assert_(isinstance(r1, DateTimeRange))
self.assert_(r1.isempty) self.assert_(r1.isempty)
ts1 = datetime(2000,1,1, tzinfo=FixedOffsetTimezone(600)) ts1 = datetime(2000, 1, 1, tzinfo=FixedOffsetTimezone(600))
ts2 = datetime(2000,12,31,23,59,59,999, tzinfo=FixedOffsetTimezone(600)) ts2 = datetime(2000, 12, 31, 23, 59, 59, 999,
tzinfo=FixedOffsetTimezone(600))
r = DateTimeTZRange(ts1, ts2, '(]') r = DateTimeTZRange(ts1, ts2, '(]')
cur.execute("select %s", (r,)) cur.execute("select %s", (r,))
r1 = cur.fetchone()[0] r1 = cur.fetchone()[0]
@ -1627,7 +1660,7 @@ class RangeCasterTestCase(ConnectingTestCase):
self.assert_(not r1.lower_inc) self.assert_(not r1.lower_inc)
self.assert_(r1.upper_inc) self.assert_(r1.upper_inc)
cur.execute("select %s", ([r,r,r],)) cur.execute("select %s", ([r, r, r],))
rs = cur.fetchone()[0] rs = cur.fetchone()[0]
self.assertEqual(len(rs), 3) self.assertEqual(len(rs), 3)
for r1 in rs: for r1 in rs:
@ -1651,11 +1684,11 @@ class RangeCasterTestCase(ConnectingTestCase):
id integer primary key, id integer primary key,
range textrange)""") range textrange)""")
bounds = [ '[)', '(]', '()', '[]' ] bounds = ['[)', '(]', '()', '[]']
ranges = [ TextRange(low, up, bounds[i % 4]) ranges = [TextRange(low, up, bounds[i % 4])
for i, (low, up) in enumerate(zip( for i, (low, up) in enumerate(zip(
[None] + map(chr, range(1, 128)), [None] + map(chr, range(1, 128)),
map(chr, range(1,128)) + [None], map(chr, range(1, 128)) + [None],
))] ))]
ranges.append(TextRange()) ranges.append(TextRange())
ranges.append(TextRange(empty=True)) ranges.append(TextRange(empty=True))
@ -1736,6 +1769,6 @@ decorate_all_tests(RangeCasterTestCase, skip_if_no_range)
def test_suite(): def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__) return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -30,6 +30,7 @@ import psycopg2.extensions as ext
from testutils import unittest, ConnectingTestCase from testutils import unittest, ConnectingTestCase
class WithTestCase(ConnectingTestCase): class WithTestCase(ConnectingTestCase):
def setUp(self): def setUp(self):
ConnectingTestCase.setUp(self) ConnectingTestCase.setUp(self)
@ -93,7 +94,7 @@ class WithConnectionTestCase(WithTestCase):
with self.conn as conn: with self.conn as conn:
curs = conn.cursor() curs = conn.cursor()
curs.execute("insert into test_with values (3)") curs.execute("insert into test_with values (3)")
1/0 1 / 0
self.assertRaises(ZeroDivisionError, f) self.assertRaises(ZeroDivisionError, f)
self.assertEqual(self.conn.status, ext.STATUS_READY) self.assertEqual(self.conn.status, ext.STATUS_READY)
@ -113,6 +114,7 @@ class WithConnectionTestCase(WithTestCase):
def test_subclass_commit(self): def test_subclass_commit(self):
commits = [] commits = []
class MyConn(ext.connection): class MyConn(ext.connection):
def commit(self): def commit(self):
commits.append(None) commits.append(None)
@ -131,6 +133,7 @@ class WithConnectionTestCase(WithTestCase):
def test_subclass_rollback(self): def test_subclass_rollback(self):
rollbacks = [] rollbacks = []
class MyConn(ext.connection): class MyConn(ext.connection):
def rollback(self): def rollback(self):
rollbacks.append(None) rollbacks.append(None)
@ -140,7 +143,7 @@ class WithConnectionTestCase(WithTestCase):
with self.connect(connection_factory=MyConn) as conn: with self.connect(connection_factory=MyConn) as conn:
curs = conn.cursor() curs = conn.cursor()
curs.execute("insert into test_with values (11)") curs.execute("insert into test_with values (11)")
1/0 1 / 0
except ZeroDivisionError: except ZeroDivisionError:
pass pass
else: else:
@ -175,7 +178,7 @@ class WithCursorTestCase(WithTestCase):
with self.conn as conn: with self.conn as conn:
with conn.cursor() as curs: with conn.cursor() as curs:
curs.execute("insert into test_with values (5)") curs.execute("insert into test_with values (5)")
1/0 1 / 0
except ZeroDivisionError: except ZeroDivisionError:
pass pass
@ -189,6 +192,7 @@ class WithCursorTestCase(WithTestCase):
def test_subclass(self): def test_subclass(self):
closes = [] closes = []
class MyCurs(ext.cursor): class MyCurs(ext.cursor):
def close(self): def close(self):
closes.append(None) closes.append(None)

View File

@ -69,8 +69,8 @@ else:
# Silence warnings caused by the stubbornness of the Python unittest # Silence warnings caused by the stubbornness of the Python unittest
# maintainers # maintainers
# http://bugs.python.org/issue9424 # http://bugs.python.org/issue9424
if not hasattr(unittest.TestCase, 'assert_') \ if (not hasattr(unittest.TestCase, 'assert_')
or unittest.TestCase.assert_ is not unittest.TestCase.assertTrue: or unittest.TestCase.assert_ is not unittest.TestCase.assertTrue):
# mavaff... # mavaff...
unittest.TestCase.assert_ = unittest.TestCase.assertTrue unittest.TestCase.assert_ = unittest.TestCase.assertTrue
unittest.TestCase.failUnless = unittest.TestCase.assertTrue unittest.TestCase.failUnless = unittest.TestCase.assertTrue
@ -175,7 +175,7 @@ def skip_if_no_uuid(f):
@wraps(f) @wraps(f)
def skip_if_no_uuid_(self): def skip_if_no_uuid_(self):
try: try:
import uuid import uuid # noqa
except ImportError: except ImportError:
return self.skipTest("uuid not available in this Python version") return self.skipTest("uuid not available in this Python version")
@ -223,7 +223,7 @@ def skip_if_no_namedtuple(f):
@wraps(f) @wraps(f)
def skip_if_no_namedtuple_(self): def skip_if_no_namedtuple_(self):
try: try:
from collections import namedtuple from collections import namedtuple # noqa
except ImportError: except ImportError:
return self.skipTest("collections.namedtuple not available") return self.skipTest("collections.namedtuple not available")
else: else:
@ -237,7 +237,7 @@ def skip_if_no_iobase(f):
@wraps(f) @wraps(f)
def skip_if_no_iobase_(self): def skip_if_no_iobase_(self):
try: try:
from io import TextIOBase from io import TextIOBase # noqa
except ImportError: except ImportError:
return self.skipTest("io.TextIOBase not found.") return self.skipTest("io.TextIOBase not found.")
else: else:
@ -249,6 +249,7 @@ def skip_if_no_iobase(f):
def skip_before_postgres(*ver): def skip_before_postgres(*ver):
"""Skip a test on PostgreSQL before a certain version.""" """Skip a test on PostgreSQL before a certain version."""
ver = ver + (0,) * (3 - len(ver)) ver = ver + (0,) * (3 - len(ver))
def skip_before_postgres_(f): def skip_before_postgres_(f):
@wraps(f) @wraps(f)
def skip_before_postgres__(self): def skip_before_postgres__(self):
@ -261,9 +262,11 @@ def skip_before_postgres(*ver):
return skip_before_postgres__ return skip_before_postgres__
return skip_before_postgres_ return skip_before_postgres_
def skip_after_postgres(*ver): def skip_after_postgres(*ver):
"""Skip a test on PostgreSQL after (including) a certain version.""" """Skip a test on PostgreSQL after (including) a certain version."""
ver = ver + (0,) * (3 - len(ver)) ver = ver + (0,) * (3 - len(ver))
def skip_after_postgres_(f): def skip_after_postgres_(f):
@wraps(f) @wraps(f)
def skip_after_postgres__(self): def skip_after_postgres__(self):
@ -276,6 +279,7 @@ def skip_after_postgres(*ver):
return skip_after_postgres__ return skip_after_postgres__
return skip_after_postgres_ return skip_after_postgres_
def libpq_version(): def libpq_version():
import psycopg2 import psycopg2
v = psycopg2.__libpq_version__ v = psycopg2.__libpq_version__
@ -283,9 +287,11 @@ def libpq_version():
v = psycopg2.extensions.libpq_version() v = psycopg2.extensions.libpq_version()
return v return v
def skip_before_libpq(*ver): def skip_before_libpq(*ver):
"""Skip a test if libpq we're linked to is older than a certain version.""" """Skip a test if libpq we're linked to is older than a certain version."""
ver = ver + (0,) * (3 - len(ver)) ver = ver + (0,) * (3 - len(ver))
def skip_before_libpq_(f): def skip_before_libpq_(f):
@wraps(f) @wraps(f)
def skip_before_libpq__(self): def skip_before_libpq__(self):
@ -298,9 +304,11 @@ def skip_before_libpq(*ver):
return skip_before_libpq__ return skip_before_libpq__
return skip_before_libpq_ return skip_before_libpq_
def skip_after_libpq(*ver): def skip_after_libpq(*ver):
"""Skip a test if libpq we're linked to is newer than a certain version.""" """Skip a test if libpq we're linked to is newer than a certain version."""
ver = ver + (0,) * (3 - len(ver)) ver = ver + (0,) * (3 - len(ver))
def skip_after_libpq_(f): def skip_after_libpq_(f):
@wraps(f) @wraps(f)
def skip_after_libpq__(self): def skip_after_libpq__(self):
@ -313,6 +321,7 @@ def skip_after_libpq(*ver):
return skip_after_libpq__ return skip_after_libpq__
return skip_after_libpq_ return skip_after_libpq_
def skip_before_python(*ver): def skip_before_python(*ver):
"""Skip a test on Python before a certain version.""" """Skip a test on Python before a certain version."""
def skip_before_python_(f): def skip_before_python_(f):
@ -327,6 +336,7 @@ def skip_before_python(*ver):
return skip_before_python__ return skip_before_python__
return skip_before_python_ return skip_before_python_
def skip_from_python(*ver): def skip_from_python(*ver):
"""Skip a test on Python after (including) a certain version.""" """Skip a test on Python after (including) a certain version."""
def skip_from_python_(f): def skip_from_python_(f):
@ -341,6 +351,7 @@ def skip_from_python(*ver):
return skip_from_python__ return skip_from_python__
return skip_from_python_ return skip_from_python_
def skip_if_no_superuser(f): def skip_if_no_superuser(f):
"""Skip a test if the database user running the test is not a superuser""" """Skip a test if the database user running the test is not a superuser"""
@wraps(f) @wraps(f)
@ -357,6 +368,7 @@ def skip_if_no_superuser(f):
return skip_if_no_superuser_ return skip_if_no_superuser_
def skip_if_green(reason): def skip_if_green(reason):
def skip_if_green_(f): def skip_if_green_(f):
@wraps(f) @wraps(f)
@ -372,6 +384,7 @@ def skip_if_green(reason):
skip_copy_if_green = skip_if_green("copy in async mode currently not supported") skip_copy_if_green = skip_if_green("copy in async mode currently not supported")
def skip_if_no_getrefcount(f): def skip_if_no_getrefcount(f):
@wraps(f) @wraps(f)
def skip_if_no_getrefcount_(self): def skip_if_no_getrefcount_(self):
@ -381,6 +394,7 @@ def skip_if_no_getrefcount(f):
return f(self) return f(self)
return skip_if_no_getrefcount_ return skip_if_no_getrefcount_
def skip_if_windows(f): def skip_if_windows(f):
"""Skip a test if run on windows""" """Skip a test if run on windows"""
@wraps(f) @wraps(f)
@ -419,6 +433,7 @@ def script_to_py3(script):
f2.close() f2.close()
os.remove(filename) os.remove(filename)
class py3_raises_typeerror(object): class py3_raises_typeerror(object):
def __enter__(self): def __enter__(self):

View File

@ -8,3 +8,8 @@ envlist = py26, py27
[testenv] [testenv]
commands = make check commands = make check
[flake8]
max-line-length = 85
ignore = E128, W503
exclude = build, doc, sandbox, examples, tests/dbapi20.py