Testing boilerplate unified in a single base class

The class makes a connection always available, allows creating
new connection and closes everything on tear down.
This commit is contained in:
Daniele Varrazzo 2013-04-07 00:23:30 +01:00
parent 408c76fdb6
commit 0e06addc9f
18 changed files with 162 additions and 271 deletions

View File

@ -32,7 +32,7 @@ import time
import select
import StringIO
from testconfig import dsn
from testutils import ConnectingTestCase
class PollableStub(object):
"""A 'pollable' wrapper allowing analysis of the `poll()` calls."""
@ -49,11 +49,13 @@ class PollableStub(object):
return rv
class AsyncTests(unittest.TestCase):
class AsyncTests(ConnectingTestCase):
def setUp(self):
self.sync_conn = psycopg2.connect(dsn)
self.conn = psycopg2.connect(dsn, async=True)
ConnectingTestCase.setUp(self)
self.sync_conn = self.conn
self.conn = self.connect(async=True)
self.wait(self.conn)
@ -64,10 +66,6 @@ class AsyncTests(unittest.TestCase):
)''')
self.wait(curs)
def tearDown(self):
self.sync_conn.close()
self.conn.close()
def wait(self, cur_or_conn):
pollable = cur_or_conn
if not hasattr(pollable, 'poll'):
@ -328,7 +326,7 @@ class AsyncTests(unittest.TestCase):
def __init__(self, dsn, async=0):
psycopg2.extensions.connection.__init__(self, dsn, async=async)
conn = psycopg2.connect(dsn, connection_factory=MyConn, async=True)
conn = self.connect(connection_factory=MyConn, async=True)
self.assert_(isinstance(conn, MyConn))
self.assert_(conn.async)
conn.close()

View File

@ -24,21 +24,12 @@
import psycopg2
import psycopg2.extensions
import time
import unittest
import gc
from testconfig import dsn
from testutils import skip_if_no_uuid
class StolenReferenceTestCase(unittest.TestCase):
def setUp(self):
self.conn = psycopg2.connect(dsn)
def tearDown(self):
self.conn.close()
from testutils import ConnectingTestCase, skip_if_no_uuid
class StolenReferenceTestCase(ConnectingTestCase):
@skip_if_no_uuid
def test_stolen_reference_bug(self):
def fish(val, cur):

View File

@ -23,7 +23,6 @@
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details.
import time
import threading
import psycopg2
@ -31,12 +30,13 @@ import psycopg2.extensions
from psycopg2 import extras
from testconfig import dsn
from testutils import unittest, skip_before_postgres
from testutils import unittest, ConnectingTestCase, skip_before_postgres
class CancelTests(unittest.TestCase):
class CancelTests(ConnectingTestCase):
def setUp(self):
self.conn = psycopg2.connect(dsn)
ConnectingTestCase.setUp(self)
cur = self.conn.cursor()
cur.execute('''
CREATE TEMPORARY TABLE table1 (
@ -44,9 +44,6 @@ class CancelTests(unittest.TestCase):
)''')
self.conn.commit()
def tearDown(self):
self.conn.close()
def test_empty_cancel(self):
self.conn.cancel()

View File

@ -25,24 +25,19 @@
import os
import time
import threading
from testutils import unittest, decorate_all_tests
from testutils import skip_before_postgres, skip_after_postgres, skip_if_no_superuser
from operator import attrgetter
import psycopg2
import psycopg2.errorcodes
import psycopg2.extensions
from testutils import unittest, decorate_all_tests, skip_if_no_superuser
from testutils import skip_before_postgres, skip_after_postgres
from testutils import ConnectingTestCase, skip_if_tpc_disabled
from testconfig import dsn, dbname
class ConnectionTests(unittest.TestCase):
def setUp(self):
self.conn = psycopg2.connect(dsn)
def tearDown(self):
if not self.conn.closed:
self.conn.close()
class ConnectionTests(ConnectingTestCase):
def test_closed_attribute(self):
conn = self.conn
self.assertEqual(conn.closed, False)
@ -153,7 +148,7 @@ class ConnectionTests(unittest.TestCase):
@skip_before_postgres(8, 2)
def test_concurrent_execution(self):
def slave():
cnn = psycopg2.connect(dsn)
cnn = self.connect()
cur = cnn.cursor()
cur.execute("select pg_sleep(4)")
cur.close()
@ -183,7 +178,7 @@ class ConnectionTests(unittest.TestCase):
oldenc = os.environ.get('PGCLIENTENCODING')
os.environ['PGCLIENTENCODING'] = 'utf-8' # malformed spelling
try:
self.conn = psycopg2.connect(dsn)
self.conn = self.connect()
finally:
if oldenc is not None:
os.environ['PGCLIENTENCODING'] = oldenc
@ -230,10 +225,11 @@ class ConnectionTests(unittest.TestCase):
self.assert_(not notices, "%d notices raised" % len(notices))
class IsolationLevelsTestCase(unittest.TestCase):
class IsolationLevelsTestCase(ConnectingTestCase):
def setUp(self):
self._conns = []
ConnectingTestCase.setUp(self)
conn = self.connect()
cur = conn.cursor()
try:
@ -244,17 +240,6 @@ class IsolationLevelsTestCase(unittest.TestCase):
conn.commit()
conn.close()
def tearDown(self):
# close the connections used in the test
for conn in self._conns:
if not conn.closed:
conn.close()
def connect(self):
conn = psycopg2.connect(dsn)
self._conns.append(conn)
return conn
def test_isolation_level(self):
conn = self.connect()
self.assertEqual(
@ -420,20 +405,16 @@ class IsolationLevelsTestCase(unittest.TestCase):
cnn.set_isolation_level, 1)
class ConnectionTwoPhaseTests(unittest.TestCase):
class ConnectionTwoPhaseTests(ConnectingTestCase):
def setUp(self):
self._conns = []
ConnectingTestCase.setUp(self)
self.make_test_table()
self.clear_test_xacts()
def tearDown(self):
self.clear_test_xacts()
# close the connections used in the test
for conn in self._conns:
if not conn.closed:
conn.close()
ConnectingTestCase.tearDown(self)
def clear_test_xacts(self):
"""Rollback all the prepared transaction in the testing db."""
@ -486,11 +467,6 @@ class ConnectionTwoPhaseTests(unittest.TestCase):
cnn.close()
return rv
def connect(self, **kwargs):
conn = psycopg2.connect(dsn, **kwargs)
self._conns.append(conn)
return conn
def test_tpc_commit(self):
cnn = self.connect()
xid = cnn.xid(1, "gtrid", "bqual")
@ -802,18 +778,10 @@ class ConnectionTwoPhaseTests(unittest.TestCase):
self.assertEqual(None, xid.bqual)
from testutils import skip_if_tpc_disabled
decorate_all_tests(ConnectionTwoPhaseTests, skip_if_tpc_disabled)
class TransactionControlTests(unittest.TestCase):
def setUp(self):
self.conn = psycopg2.connect(dsn)
def tearDown(self):
if not self.conn.closed:
self.conn.close()
class TransactionControlTests(ConnectingTestCase):
def test_closed(self):
self.conn.close()
self.assertRaises(psycopg2.InterfaceError,
@ -955,14 +923,7 @@ class TransactionControlTests(unittest.TestCase):
self.conn.set_session, readonly=True, deferrable=True)
class AutocommitTests(unittest.TestCase):
def setUp(self):
self.conn = psycopg2.connect(dsn)
def tearDown(self):
if not self.conn.closed:
self.conn.close()
class AutocommitTests(ConnectingTestCase):
def test_closed(self):
self.conn.close()
self.assertRaises(psycopg2.InterfaceError,

View File

@ -24,13 +24,13 @@
import sys
import string
from testutils import unittest, decorate_all_tests, skip_if_no_iobase
from testutils import unittest, ConnectingTestCase, decorate_all_tests
from testutils import skip_if_no_iobase
from cStringIO import StringIO
from itertools import cycle, izip
import psycopg2
import psycopg2.extensions
from testconfig import dsn
from testutils import skip_copy_if_green
if sys.version_info[0] < 3:
@ -58,10 +58,10 @@ class MinimalWrite(_base):
return self.f.write(data)
class CopyTests(unittest.TestCase):
class CopyTests(ConnectingTestCase):
def setUp(self):
self.conn = psycopg2.connect(dsn)
ConnectingTestCase.setUp(self)
self._create_temp_table()
def _create_temp_table(self):
@ -72,9 +72,6 @@ class CopyTests(unittest.TestCase):
data text
)''')
def tearDown(self):
self.conn.close()
def test_copy_from(self):
curs = self.conn.cursor()
try:

View File

@ -26,16 +26,10 @@ import time
import psycopg2
import psycopg2.extensions
from psycopg2.extensions import b
from testconfig import dsn
from testutils import unittest, skip_before_postgres, skip_if_no_namedtuple
from testutils import unittest, ConnectingTestCase, skip_before_postgres
from testutils import skip_if_no_namedtuple
class CursorTests(unittest.TestCase):
def setUp(self):
self.conn = psycopg2.connect(dsn)
def tearDown(self):
self.conn.close()
class CursorTests(ConnectingTestCase):
def test_close_idempotent(self):
cur = self.conn.cursor()

View File

@ -23,10 +23,9 @@
# License for more details.
import math
import unittest
import psycopg2
from psycopg2.tz import FixedOffsetTimezone, ZERO
from testconfig import dsn
from testutils import unittest, ConnectingTestCase
class CommonDatetimeTestsMixin:
@ -93,20 +92,17 @@ class CommonDatetimeTestsMixin:
self.assertEqual(value, None)
class DatetimeTests(unittest.TestCase, CommonDatetimeTestsMixin):
class DatetimeTests(ConnectingTestCase, CommonDatetimeTestsMixin):
"""Tests for the datetime based date handling in psycopg2."""
def setUp(self):
self.conn = psycopg2.connect(dsn)
ConnectingTestCase.setUp(self)
self.curs = self.conn.cursor()
self.DATE = psycopg2.extensions.PYDATE
self.TIME = psycopg2.extensions.PYTIME
self.DATETIME = psycopg2.extensions.PYDATETIME
self.INTERVAL = psycopg2.extensions.PYINTERVAL
def tearDown(self):
self.conn.close()
def test_parse_bc_date(self):
# datetime does not support BC dates
self.assertRaises(ValueError, self.DATE, '00042-01-01 BC', self.curs)
@ -311,11 +307,11 @@ if not hasattr(psycopg2.extensions, 'PYDATETIME'):
del DatetimeTests
class mxDateTimeTests(unittest.TestCase, CommonDatetimeTestsMixin):
class mxDateTimeTests(ConnectingTestCase, CommonDatetimeTestsMixin):
"""Tests for the mx.DateTime based date handling in psycopg2."""
def setUp(self):
self.conn = psycopg2.connect(dsn)
ConnectingTestCase.setUp(self)
self.curs = self.conn.cursor()
self.DATE = psycopg2._psycopg.MXDATE
self.TIME = psycopg2._psycopg.MXTIME
@ -557,6 +553,7 @@ class FixedOffsetTimezoneTests(unittest.TestCase):
self.assertEqual(tz11, tz21)
self.assertEqual(tz12, tz22)
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)

View File

@ -18,26 +18,23 @@ import time
from datetime import timedelta
import psycopg2
import psycopg2.extras
from testutils import unittest, skip_before_postgres, skip_if_no_namedtuple
from testconfig import dsn
from testutils import unittest, ConnectingTestCase, skip_before_postgres
from testutils import skip_if_no_namedtuple
class ExtrasDictCursorTests(unittest.TestCase):
class ExtrasDictCursorTests(ConnectingTestCase):
"""Test if DictCursor extension class works."""
def setUp(self):
self.conn = psycopg2.connect(dsn)
ConnectingTestCase.setUp(self)
curs = self.conn.cursor()
curs.execute("CREATE TEMPORARY TABLE ExtrasDictCursorTests (foo text)")
curs.execute("INSERT INTO ExtrasDictCursorTests VALUES ('bar')")
self.conn.commit()
def tearDown(self):
self.conn.close()
def testDictConnCursorArgs(self):
self.conn.close()
self.conn = psycopg2.connect(dsn, connection_factory=psycopg2.extras.DictConnection)
self.conn = self.connect(connection_factory=psycopg2.extras.DictConnection)
cur = self.conn.cursor()
self.assert_(isinstance(cur, psycopg2.extras.DictCursor))
self.assertEqual(cur.name, None)
@ -232,18 +229,17 @@ class ExtrasDictCursorTests(unittest.TestCase):
self.assertEqual(r._column_mapping, r1._column_mapping)
class NamedTupleCursorTest(unittest.TestCase):
class NamedTupleCursorTest(ConnectingTestCase):
def setUp(self):
ConnectingTestCase.setUp(self)
from psycopg2.extras import NamedTupleConnection
try:
from collections import namedtuple
except ImportError:
self.conn = None
return
self.conn = psycopg2.connect(dsn,
connection_factory=NamedTupleConnection)
self.conn = self.connect(connection_factory=NamedTupleConnection)
curs = self.conn.cursor()
curs.execute("CREATE TEMPORARY TABLE nttest (i int, s text)")
curs.execute("INSERT INTO nttest VALUES (1, 'foo')")
@ -251,10 +247,6 @@ class NamedTupleCursorTest(unittest.TestCase):
curs.execute("INSERT INTO nttest VALUES (3, 'baz')")
self.conn.commit()
def tearDown(self):
if self.conn is not None:
self.conn.close()
@skip_if_no_namedtuple
def test_cursor_args(self):
cur = self.conn.cursor('foo', cursor_factory=psycopg2.extras.DictCursor)
@ -359,9 +351,7 @@ class NamedTupleCursorTest(unittest.TestCase):
# an import error somewhere
from psycopg2.extras import NamedTupleConnection
try:
if self.conn is not None:
self.conn.close()
self.conn = psycopg2.connect(dsn,
self.conn = self.connect(
connection_factory=NamedTupleConnection)
curs = self.conn.cursor()
curs.execute("select 1")
@ -371,8 +361,7 @@ class NamedTupleCursorTest(unittest.TestCase):
else:
self.fail("expecting ImportError")
else:
# skip the test
pass
return self.skipTest("namedtuple available")
@skip_if_no_namedtuple
def test_record_updated(self):

View File

@ -26,7 +26,8 @@ import unittest
import psycopg2
import psycopg2.extensions
import psycopg2.extras
from testconfig import dsn
from testutils import ConnectingTestCase
class ConnectionStub(object):
"""A `connection` wrapper allowing analysis of the `poll()` calls."""
@ -42,14 +43,14 @@ class ConnectionStub(object):
self.polls.append(rv)
return rv
class GreenTests(unittest.TestCase):
class GreenTestCase(ConnectingTestCase):
def setUp(self):
self._cb = psycopg2.extensions.get_wait_callback()
psycopg2.extensions.set_wait_callback(psycopg2.extras.wait_select)
self.conn = psycopg2.connect(dsn)
ConnectingTestCase.setUp(self)
def tearDown(self):
self.conn.close()
ConnectingTestCase.tearDown(self)
psycopg2.extensions.set_wait_callback(self._cb)
def set_stub_wait_callback(self, conn):

View File

@ -30,9 +30,8 @@ from functools import wraps
import psycopg2
import psycopg2.extensions
from psycopg2.extensions import b
from testconfig import dsn
from testutils import unittest, decorate_all_tests, skip_if_tpc_disabled
from testutils import skip_if_green
from testutils import ConnectingTestCase, skip_if_green
def skip_if_no_lo(f):
@wraps(f)
@ -47,10 +46,9 @@ def skip_if_no_lo(f):
skip_lo_if_green = skip_if_green("libpq doesn't support LO in async mode")
class LargeObjectMixin(object):
# doesn't derive from TestCase to avoid repeating tests twice.
class LargeObjectTestCase(ConnectingTestCase):
def setUp(self):
self.conn = self.connect()
ConnectingTestCase.setUp(self)
self.lo_oid = None
self.tmpdir = None
@ -69,13 +67,11 @@ class LargeObjectMixin(object):
pass
else:
lo.unlink()
self.conn.close()
def connect(self):
return psycopg2.connect(dsn)
ConnectingTestCase.tearDown(self)
class LargeObjectTests(LargeObjectMixin, unittest.TestCase):
class LargeObjectTests(LargeObjectTestCase):
def test_create(self):
lo = self.conn.lobject()
self.assertNotEqual(lo, None)
@ -374,8 +370,7 @@ class LargeObjectTests(LargeObjectMixin, unittest.TestCase):
self.conn.tpc_commit()
decorate_all_tests(LargeObjectTests, skip_if_no_lo)
decorate_all_tests(LargeObjectTests, skip_lo_if_green)
decorate_all_tests(LargeObjectTests, skip_if_no_lo, skip_lo_if_green)
def skip_if_no_truncate(f):
@ -394,7 +389,7 @@ def skip_if_no_truncate(f):
return skip_if_no_truncate_
class LargeObjectTruncateTests(LargeObjectMixin, unittest.TestCase):
class LargeObjectTruncateTests(LargeObjectTestCase):
def test_truncate(self):
lo = self.conn.lobject()
lo.write("some data")
@ -430,9 +425,8 @@ class LargeObjectTruncateTests(LargeObjectMixin, unittest.TestCase):
self.assertRaises(psycopg2.ProgrammingError, lo.truncate)
decorate_all_tests(LargeObjectTruncateTests, skip_if_no_lo)
decorate_all_tests(LargeObjectTruncateTests, skip_lo_if_green)
decorate_all_tests(LargeObjectTruncateTests, skip_if_no_truncate)
decorate_all_tests(LargeObjectTruncateTests,
skip_if_no_lo, skip_lo_if_green, skip_if_no_truncate)
def test_suite():

View File

@ -23,9 +23,7 @@
# License for more details.
from testutils import unittest, skip_before_python, skip_before_postgres
from testutils import skip_copy_if_green
from testconfig import dsn
from testutils import ConnectingTestCase, skip_copy_if_green
import psycopg2
@ -138,13 +136,7 @@ class ConnectTestCase(unittest.TestCase):
psycopg2.connect, 'dbname=foo', no_such_param='meh')
class ExceptionsTestCase(unittest.TestCase):
def setUp(self):
self.conn = psycopg2.connect(dsn)
def tearDown(self):
self.conn.close()
class ExceptionsTestCase(ConnectingTestCase):
def test_attributes(self):
cur = self.conn.cursor()
try:

View File

@ -26,23 +26,16 @@ from testutils import unittest
import psycopg2
from psycopg2 import extensions
from testutils import ConnectingTestCase, script_to_py3
from testconfig import dsn
from testutils import script_to_py3
import sys
import time
import select
import signal
from subprocess import Popen, PIPE
class NotifiesTests(unittest.TestCase):
def setUp(self):
self.conn = psycopg2.connect(dsn)
def tearDown(self):
self.conn.close()
class NotifiesTests(ConnectingTestCase):
def autocommit(self, conn):
"""Set a connection in autocommit mode."""

View File

@ -23,14 +23,13 @@
# License for more details.
import sys
from testutils import unittest
from testconfig import dsn
from testutils import unittest, ConnectingTestCase
import psycopg2
import psycopg2.extensions
from psycopg2.extensions import b
class QuotingTestCase(unittest.TestCase):
class QuotingTestCase(ConnectingTestCase):
r"""Checks the correct quoting of strings and binary objects.
Since ver. 8.1, PostgreSQL is moving towards SQL standard conforming
@ -48,12 +47,6 @@ class QuotingTestCase(unittest.TestCase):
http://www.postgresql.org/docs/current/static/sql-syntax-lexical.html#SQL-SYNTAX-STRINGS
http://www.postgresql.org/docs/current/static/runtime-config-compatible.html
"""
def setUp(self):
self.conn = psycopg2.connect(dsn)
def tearDown(self):
self.conn.close()
def test_string(self):
data = """some data with \t chars
to escape into, 'quotes' and \\ a backslash too.
@ -162,13 +155,7 @@ class QuotingTestCase(unittest.TestCase):
self.assert_(not self.conn.notices)
class TestQuotedString(unittest.TestCase):
def setUp(self):
self.conn = psycopg2.connect(dsn)
def tearDown(self):
self.conn.close()
class TestQuotedString(ConnectingTestCase):
def test_encoding(self):
q = psycopg2.extensions.QuotedString('hi')
self.assertEqual(q.encoding, 'latin1')

View File

@ -23,17 +23,16 @@
# License for more details.
import threading
from testutils import unittest, skip_before_postgres
from testutils import unittest, ConnectingTestCase, skip_before_postgres
import psycopg2
from psycopg2.extensions import (
ISOLATION_LEVEL_SERIALIZABLE, STATUS_BEGIN, STATUS_READY)
from testconfig import dsn
class TransactionTests(unittest.TestCase):
class TransactionTests(ConnectingTestCase):
def setUp(self):
self.conn = psycopg2.connect(dsn)
ConnectingTestCase.setUp(self)
self.conn.set_isolation_level(ISOLATION_LEVEL_SERIALIZABLE)
curs = self.conn.cursor()
curs.execute('''
@ -51,9 +50,6 @@ class TransactionTests(unittest.TestCase):
curs.execute('INSERT INTO table2 VALUES (1, 1)')
self.conn.commit()
def tearDown(self):
self.conn.close()
def test_rollback(self):
# Test that rollback undoes changes
curs = self.conn.cursor()
@ -93,16 +89,17 @@ class TransactionTests(unittest.TestCase):
self.assertEqual(curs.fetchone()[0], 1)
class DeadlockSerializationTests(unittest.TestCase):
class DeadlockSerializationTests(ConnectingTestCase):
"""Test deadlock and serialization failure errors."""
def connect(self):
conn = psycopg2.connect(dsn)
conn = ConnectingTestCase.connect(self)
conn.set_isolation_level(ISOLATION_LEVEL_SERIALIZABLE)
return conn
def setUp(self):
self.conn = self.connect()
ConnectingTestCase.setUp(self)
curs = self.conn.cursor()
# Drop table if it already exists
try:
@ -130,7 +127,8 @@ class DeadlockSerializationTests(unittest.TestCase):
curs.execute("DROP TABLE table1")
curs.execute("DROP TABLE table2")
self.conn.commit()
self.conn.close()
ConnectingTestCase.tearDown(self)
def test_deadlock(self):
self.thread1_error = self.thread2_error = None
@ -226,16 +224,13 @@ class DeadlockSerializationTests(unittest.TestCase):
error, psycopg2.extensions.TransactionRollbackError))
class QueryCancellationTests(unittest.TestCase):
class QueryCancellationTests(ConnectingTestCase):
"""Tests for query cancellation."""
def setUp(self):
self.conn = psycopg2.connect(dsn)
ConnectingTestCase.setUp(self)
self.conn.set_isolation_level(ISOLATION_LEVEL_SERIALIZABLE)
def tearDown(self):
self.conn.close()
@skip_before_postgres(8, 2)
def test_statement_timeout(self):
curs = self.conn.cursor()

View File

@ -27,22 +27,15 @@ import decimal
import sys
from functools import wraps
import testutils
from testutils import unittest, decorate_all_tests
from testconfig import dsn
from testutils import unittest, ConnectingTestCase, decorate_all_tests
import psycopg2
from psycopg2.extensions import b
class TypesBasicTests(unittest.TestCase):
class TypesBasicTests(ConnectingTestCase):
"""Test that all type conversions are working."""
def setUp(self):
self.conn = psycopg2.connect(dsn)
def tearDown(self):
self.conn.close()
def execute(self, *args):
curs = self.conn.cursor()
curs.execute(*args)

View File

@ -21,14 +21,12 @@ from datetime import date, datetime
from functools import wraps
from testutils import unittest, skip_if_no_uuid, skip_before_postgres
from testutils import decorate_all_tests
from testutils import ConnectingTestCase, decorate_all_tests
import psycopg2
import psycopg2.extras
from psycopg2.extensions import b
from testconfig import dsn
def filter_scs(conn, s):
if conn.get_parameter_status("standard_conforming_strings") == 'off':
@ -36,15 +34,9 @@ def filter_scs(conn, s):
else:
return s.replace(b("E'"), b("'"))
class TypesExtrasTests(unittest.TestCase):
class TypesExtrasTests(ConnectingTestCase):
"""Test that all type conversions are working."""
def setUp(self):
self.conn = psycopg2.connect(dsn)
def tearDown(self):
self.conn.close()
def execute(self, *args):
curs = self.conn.cursor()
curs.execute(*args)
@ -135,13 +127,7 @@ def skip_if_no_hstore(f):
return skip_if_no_hstore_
class HstoreTestCase(unittest.TestCase):
def setUp(self):
self.conn = psycopg2.connect(dsn)
def tearDown(self):
self.conn.close()
class HstoreTestCase(ConnectingTestCase):
def test_adapt_8(self):
if self.conn.server_version >= 90000:
return self.skipTest("skipping dict adaptation with PG pre-9 syntax")
@ -276,7 +262,7 @@ class HstoreTestCase(unittest.TestCase):
oids = HstoreAdapter.get_oids(self.conn)
try:
register_hstore(self.conn, globally=True)
conn2 = psycopg2.connect(dsn)
conn2 = self.connect()
try:
cur2 = self.conn.cursor()
cur2.execute("select 'a => b'::hstore")
@ -429,7 +415,7 @@ class HstoreTestCase(unittest.TestCase):
from psycopg2.extras import RealDictConnection
from psycopg2.extras import register_hstore
conn = psycopg2.connect(dsn, connection_factory=RealDictConnection)
conn = self.connect(connection_factory=RealDictConnection)
try:
register_hstore(conn)
curs = conn.cursor()
@ -438,7 +424,7 @@ class HstoreTestCase(unittest.TestCase):
finally:
conn.close()
conn = psycopg2.connect(dsn, connection_factory=RealDictConnection)
conn = self.connect(connection_factory=RealDictConnection)
try:
curs = conn.cursor()
register_hstore(curs)
@ -460,13 +446,7 @@ def skip_if_no_composite(f):
return skip_if_no_composite_
class AdaptTypeTestCase(unittest.TestCase):
def setUp(self):
self.conn = psycopg2.connect(dsn)
def tearDown(self):
self.conn.close()
class AdaptTypeTestCase(ConnectingTestCase):
@skip_if_no_composite
def test_none_in_record(self):
curs = self.conn.cursor()
@ -621,8 +601,8 @@ class AdaptTypeTestCase(unittest.TestCase):
def test_register_on_connection(self):
self._create_type("type_ii", [("a", "integer"), ("b", "integer")])
conn1 = psycopg2.connect(dsn)
conn2 = psycopg2.connect(dsn)
conn1 = self.connect()
conn2 = self.connect()
try:
psycopg2.extras.register_composite("type_ii", conn1)
curs1 = conn1.cursor()
@ -639,8 +619,8 @@ class AdaptTypeTestCase(unittest.TestCase):
def test_register_globally(self):
self._create_type("type_ii", [("a", "integer"), ("b", "integer")])
conn1 = psycopg2.connect(dsn)
conn2 = psycopg2.connect(dsn)
conn1 = self.connect()
conn2 = self.connect()
try:
t = psycopg2.extras.register_composite("type_ii", conn1, globally=True)
try:
@ -765,7 +745,7 @@ class AdaptTypeTestCase(unittest.TestCase):
from psycopg2.extras import register_composite
self._create_type("type_ii", [("a", "integer"), ("b", "integer")])
conn = psycopg2.connect(dsn, connection_factory=RealDictConnection)
conn = self.connect(connection_factory=RealDictConnection)
try:
register_composite('type_ii', conn)
curs = conn.cursor()
@ -774,7 +754,7 @@ class AdaptTypeTestCase(unittest.TestCase):
finally:
conn.close()
conn = psycopg2.connect(dsn, connection_factory=RealDictConnection)
conn = self.connect(connection_factory=RealDictConnection)
try:
curs = conn.cursor()
register_composite('type_ii', conn)
@ -867,13 +847,7 @@ def skip_if_no_json_type(f):
return skip_if_no_json_type_
class JsonTestCase(unittest.TestCase):
def setUp(self):
self.conn = psycopg2.connect(dsn)
def tearDown(self):
self.conn.close()
class JsonTestCase(ConnectingTestCase):
@skip_if_json_module
def test_module_not_available(self):
from psycopg2.extras import Json
@ -1259,12 +1233,7 @@ def skip_if_no_range(f):
return skip_if_no_range_
class RangeCasterTestCase(unittest.TestCase):
def setUp(self):
self.conn = psycopg2.connect(dsn)
def tearDown(self):
self.conn.close()
class RangeCasterTestCase(ConnectingTestCase):
builtin_ranges = ('int4range', 'int8range', 'numrange',
'daterange', 'tsrange', 'tstzrange')

View File

@ -28,27 +28,23 @@ from __future__ import with_statement
import psycopg2
import psycopg2.extensions as ext
from testconfig import dsn
from testutils import unittest
from testutils import unittest, ConnectingTestCase
class TestMixin(object):
class WithTestCase(ConnectingTestCase):
def setUp(self):
self.conn = conn = psycopg2.connect(dsn)
curs = conn.cursor()
ConnectingTestCase.setUp(self)
curs = self.conn.cursor()
try:
curs.execute("delete from test_with")
conn.commit()
self.conn.commit()
except psycopg2.ProgrammingError:
# assume table doesn't exist
conn.rollback()
self.conn.rollback()
curs.execute("create table test_with (id integer primary key)")
conn.commit()
def tearDown(self):
self.conn.close()
self.conn.commit()
class WithConnectionTestCase(TestMixin, unittest.TestCase):
class WithConnectionTestCase(WithTestCase):
def test_with_ok(self):
with self.conn as conn:
self.assert_(self.conn is conn)
@ -65,7 +61,7 @@ class WithConnectionTestCase(TestMixin, unittest.TestCase):
self.assertEqual(curs.fetchall(), [(1,)])
def test_with_connect_idiom(self):
with psycopg2.connect(dsn) as conn:
with self.connect() as conn:
self.assertEqual(conn.status, ext.STATUS_READY)
curs = conn.cursor()
curs.execute("insert into test_with values (2)")
@ -122,7 +118,7 @@ class WithConnectionTestCase(TestMixin, unittest.TestCase):
commits.append(None)
super(MyConn, self).commit()
with psycopg2.connect(dsn, connection_factory=MyConn) as conn:
with self.connect(connection_factory=MyConn) as conn:
curs = conn.cursor()
curs.execute("insert into test_with values (10)")
@ -141,7 +137,7 @@ class WithConnectionTestCase(TestMixin, unittest.TestCase):
super(MyConn, self).rollback()
try:
with psycopg2.connect(dsn, connection_factory=MyConn) as conn:
with self.connect(connection_factory=MyConn) as conn:
curs = conn.cursor()
curs.execute("insert into test_with values (11)")
1/0
@ -158,7 +154,7 @@ class WithConnectionTestCase(TestMixin, unittest.TestCase):
self.assertEqual(curs.fetchall(), [])
class WithCursorTestCase(TestMixin, unittest.TestCase):
class WithCursorTestCase(WithTestCase):
def test_with_ok(self):
with self.conn as conn:
with conn.cursor() as curs:

View File

@ -27,6 +27,7 @@
import os
import sys
from functools import wraps
from testconfig import dsn
try:
import unittest2
@ -74,11 +75,57 @@ or unittest.TestCase.assert_ is not unittest.TestCase.assertTrue:
unittest.TestCase.failUnlessEqual = unittest.TestCase.assertEqual
def decorate_all_tests(cls, decorator):
"""Apply *decorator* to all the tests defined in the TestCase *cls*."""
class ConnectingTestCase(unittest.TestCase):
"""A test case providing connections for tests.
A connection for the test is always available as `self.conn`. Others can be
created with `self.connect()`. All are closed on tearDown.
Subclasses needing to customize setUp and tearDown should remember to call
the base class implementations.
"""
def setUp(self):
self._conns = []
def tearDown(self):
# close the connections used in the test
for conn in self._conns:
if not conn.closed:
conn.close()
def connect(self, **kwargs):
try:
self._conns
except AttributeError, e:
raise AttributeError(
"%s (did you remember calling ConnectingTestCase.setUp()?)"
% e)
import psycopg2
conn = psycopg2.connect(dsn, **kwargs)
self._conns.append(conn)
return conn
def _get_conn(self):
if not hasattr(self, '_the_conn'):
self._the_conn = self.connect()
return self._the_conn
def _set_conn(self, conn):
self._the_conn = conn
conn = property(_get_conn, _set_conn)
def decorate_all_tests(cls, *decorators):
"""
Apply all the *decorators* to all the tests defined in the TestCase *cls*.
"""
for n in dir(cls):
if n.startswith('test'):
setattr(cls, n, decorator(getattr(cls, n)))
for d in decorators:
setattr(cls, n, d(getattr(cls, n)))
def skip_if_no_uuid(f):