Test cleanup.

Tests pass or fail gracefully on older PostgreSQL versions.

If unittest2 is available, skip tests instead of printing warnings.
This commit is contained in:
Daniele Varrazzo 2010-11-19 03:55:37 +00:00
parent 94348bfb78
commit 19ead4a5cb
14 changed files with 223 additions and 123 deletions

View File

@ -71,7 +71,7 @@ docs-txt: doc/psycopg2.txt
sdist: $(SDIST) sdist: $(SDIST)
runtests: package runtests: package
PSYCOPG2_TESTDB=$(TESTDB) PYTHONPATH=$(BUILD_DIR):. $(PYTHON) tests/__init__.py --verbose PSYCOPG2_TESTDB=$(TESTDB) PYTHONPATH=$(BUILD_DIR):.:$(PYTHONPATH) $(PYTHON) tests/__init__.py --verbose
# The environment is currently required to build the documentation. # The environment is currently required to build the documentation.

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python #!/usr/bin/env python
import os import os
import unittest from testutils import unittest
dbname = os.environ.get('PSYCOPG2_TESTDB', 'psycopg2_test') dbname = os.environ.get('PSYCOPG2_TESTDB', 'psycopg2_test')
dbhost = os.environ.get('PSYCOPG2_TESTDB_HOST', None) dbhost = os.environ.get('PSYCOPG2_TESTDB_HOST', None)
@ -59,14 +59,8 @@ def test_suite():
suite.addTest(test_transaction.test_suite()) suite.addTest(test_transaction.test_suite())
suite.addTest(types_basic.test_suite()) suite.addTest(types_basic.test_suite())
suite.addTest(types_extras.test_suite()) suite.addTest(types_extras.test_suite())
suite.addTest(test_lobject.test_suite())
if not green: suite.addTest(test_copy.test_suite())
suite.addTest(test_lobject.test_suite())
suite.addTest(test_copy.test_suite())
else:
import warnings
warnings.warn("copy/lobjects not implemented in green mode: skipping tests")
suite.addTest(test_notify.test_suite()) suite.addTest(test_notify.test_suite())
suite.addTest(test_async.test_suite()) suite.addTest(test_async.test_suite())
suite.addTest(test_green.test_suite()) suite.addTest(test_green.test_suite())

View File

@ -16,7 +16,7 @@
import psycopg2 import psycopg2
import psycopg2.extras import psycopg2.extras
import unittest from testutils import unittest
import tests import tests
@ -111,8 +111,7 @@ def if_has_namedtuple(f):
try: try:
from collections import namedtuple from collections import namedtuple
except ImportError: except ImportError:
import warnings return self.skipTest("collections.namedtuple not available")
warnings.warn("collections.namedtuple not available")
else: else:
return f(self) return f(self)
@ -133,8 +132,9 @@ class NamedTupleCursorTest(unittest.TestCase):
connection_factory=NamedTupleConnection) connection_factory=NamedTupleConnection)
curs = self.conn.cursor() curs = self.conn.cursor()
curs.execute("CREATE TEMPORARY TABLE nttest (i int, s text)") curs.execute("CREATE TEMPORARY TABLE nttest (i int, s text)")
curs.execute( curs.execute("INSERT INTO nttest VALUES (1, 'foo')")
"INSERT INTO nttest VALUES (1, 'foo'), (2, 'bar'), (3, 'baz')") curs.execute("INSERT INTO nttest VALUES (2, 'bar')")
curs.execute("INSERT INTO nttest VALUES (3, 'baz')")
self.conn.commit() self.conn.commit()
@if_has_namedtuple @if_has_namedtuple

View File

@ -1,5 +1,5 @@
#!/usr/bin/env python #!/usr/bin/env python
import unittest from testutils import unittest
import psycopg2 import psycopg2
from psycopg2 import extensions from psycopg2 import extensions
@ -98,14 +98,13 @@ class AsyncTests(unittest.TestCase):
cur = self.conn.cursor() cur = self.conn.cursor()
try: try:
cur.callproc("pg_sleep", (0.1, )) cur.callproc("pg_sleep", (0.1, ))
except psycopg2.ProgrammingError: self.assertTrue(self.conn.isexecuting())
# PG <8.1 did not have pg_sleep
return
self.assertTrue(self.conn.isexecuting())
self.wait(cur) self.wait(cur)
self.assertFalse(self.conn.isexecuting()) self.assertFalse(self.conn.isexecuting())
self.assertEquals(cur.fetchall()[0][0], '') self.assertEquals(cur.fetchall()[0][0], '')
except psycopg2.ProgrammingError:
return self.skipTest("PG < 8.1 did not have pg_sleep")
def test_async_after_async(self): def test_async_after_async(self):
cur = self.conn.cursor() cur = self.conn.cursor()
@ -213,7 +212,11 @@ class AsyncTests(unittest.TestCase):
cur.execute("begin") cur.execute("begin")
self.wait(cur) self.wait(cur)
cur.execute("insert into table1 values (1), (2), (3)") cur.execute("""
insert into table1 values (1);
insert into table1 values (2);
insert into table1 values (3);
""")
self.wait(cur) self.wait(cur)
cur.execute("select id from table1 order by id") cur.execute("select id from table1 order by id")
@ -247,7 +250,11 @@ class AsyncTests(unittest.TestCase):
def test_async_scroll(self): def test_async_scroll(self):
cur = self.conn.cursor() cur = self.conn.cursor()
cur.execute("insert into table1 values (1), (2), (3)") cur.execute("""
insert into table1 values (1);
insert into table1 values (2);
insert into table1 values (3);
""")
self.wait(cur) self.wait(cur)
cur.execute("select id from table1 order by id") cur.execute("select id from table1 order by id")
@ -279,7 +286,11 @@ class AsyncTests(unittest.TestCase):
def test_scroll(self): def test_scroll(self):
cur = self.sync_conn.cursor() cur = self.sync_conn.cursor()
cur.execute("create table table1 (id int)") cur.execute("create table table1 (id int)")
cur.execute("insert into table1 values (1), (2), (3)") cur.execute("""
insert into table1 values (1);
insert into table1 values (2);
insert into table1 values (3);
""")
cur.execute("select id from table1 order by id") cur.execute("select id from table1 order by id")
cur.scroll(2) cur.scroll(2)
cur.scroll(-1) cur.scroll(-1)

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
import unittest from testutils import unittest, decorate_all_tests
from operator import attrgetter from operator import attrgetter
import psycopg2 import psycopg2
@ -82,13 +82,24 @@ class ConnectionTests(unittest.TestCase):
conn = self.connect() conn = self.connect()
self.assert_(conn.protocol_version in (2,3), conn.protocol_version) self.assert_(conn.protocol_version in (2,3), conn.protocol_version)
def test_tpc_unsupported(self):
cnn = self.connect()
if cnn.server_version >= 80100:
return self.skipTest("tpc is supported")
self.assertRaises(psycopg2.NotSupportedError,
cnn.xid, 42, "foo", "bar")
class IsolationLevelsTestCase(unittest.TestCase): class IsolationLevelsTestCase(unittest.TestCase):
def setUp(self): def setUp(self):
conn = self.connect() conn = self.connect()
cur = conn.cursor() cur = conn.cursor()
cur.execute("drop table if exists isolevel;") try:
cur.execute("drop table isolevel;")
except psycopg2.ProgrammingError:
conn.rollback()
cur.execute("create table isolevel (id integer);") cur.execute("create table isolevel (id integer);")
conn.commit() conn.commit()
conn.close() conn.close()
@ -244,18 +255,16 @@ def skip_if_tpc_disabled(f):
try: try:
cur.execute("SHOW max_prepared_transactions;") cur.execute("SHOW max_prepared_transactions;")
except psycopg2.ProgrammingError: except psycopg2.ProgrammingError:
# Server version too old: let's die a different death return self.skipTest(
mtp = 1 "server too old: two phase transactions not supported.")
else: else:
mtp = int(cur.fetchone()[0]) mtp = int(cur.fetchone()[0])
cnn.close() cnn.close()
if not mtp: if not mtp:
import warnings return self.skipTest(
warnings.warn(
"server not configured for two phase transactions. " "server not configured for two phase transactions. "
"set max_prepared_transactions to > 0 to run the test") "set max_prepared_transactions to > 0 to run the test")
return
return f(self) return f(self)
skip_if_tpc_disabled_.__name__ = f.__name__ skip_if_tpc_disabled_.__name__ = f.__name__
@ -274,9 +283,15 @@ class ConnectionTwoPhaseTests(unittest.TestCase):
cnn = self.connect() cnn = self.connect()
cnn.set_isolation_level(0) cnn.set_isolation_level(0)
cur = cnn.cursor() cur = cnn.cursor()
cur.execute( try:
"select gid from pg_prepared_xacts where database = %s", cur.execute(
(tests.dbname,)) "select gid from pg_prepared_xacts where database = %s",
(tests.dbname,))
except psycopg2.ProgrammingError:
cnn.rollback()
cnn.close()
return
gids = [ r[0] for r in cur ] gids = [ r[0] for r in cur ]
for gid in gids: for gid in gids:
cur.execute("rollback prepared %s;", (gid,)) cur.execute("rollback prepared %s;", (gid,))
@ -285,7 +300,10 @@ class ConnectionTwoPhaseTests(unittest.TestCase):
def make_test_table(self): def make_test_table(self):
cnn = self.connect() cnn = self.connect()
cur = cnn.cursor() cur = cnn.cursor()
cur.execute("DROP TABLE IF EXISTS test_tpc;") try:
cur.execute("DROP TABLE test_tpc;")
except psycopg2.ProgrammingError:
cnn.rollback()
cur.execute("CREATE TABLE test_tpc (data text);") cur.execute("CREATE TABLE test_tpc (data text);")
cnn.commit() cnn.commit()
cnn.close() cnn.close()
@ -314,7 +332,6 @@ class ConnectionTwoPhaseTests(unittest.TestCase):
def connect(self): def connect(self):
return psycopg2.connect(tests.dsn) return psycopg2.connect(tests.dsn)
@skip_if_tpc_disabled
def test_tpc_commit(self): def test_tpc_commit(self):
cnn = self.connect() cnn = self.connect()
xid = cnn.xid(1, "gtrid", "bqual") xid = cnn.xid(1, "gtrid", "bqual")
@ -356,7 +373,6 @@ class ConnectionTwoPhaseTests(unittest.TestCase):
self.assertEqual(0, self.count_xacts()) self.assertEqual(0, self.count_xacts())
self.assertEqual(1, self.count_test_records()) self.assertEqual(1, self.count_test_records())
@skip_if_tpc_disabled
def test_tpc_commit_recovered(self): def test_tpc_commit_recovered(self):
cnn = self.connect() cnn = self.connect()
xid = cnn.xid(1, "gtrid", "bqual") xid = cnn.xid(1, "gtrid", "bqual")
@ -383,7 +399,6 @@ class ConnectionTwoPhaseTests(unittest.TestCase):
self.assertEqual(0, self.count_xacts()) self.assertEqual(0, self.count_xacts())
self.assertEqual(1, self.count_test_records()) self.assertEqual(1, self.count_test_records())
@skip_if_tpc_disabled
def test_tpc_rollback(self): def test_tpc_rollback(self):
cnn = self.connect() cnn = self.connect()
xid = cnn.xid(1, "gtrid", "bqual") xid = cnn.xid(1, "gtrid", "bqual")
@ -425,7 +440,6 @@ class ConnectionTwoPhaseTests(unittest.TestCase):
self.assertEqual(0, self.count_xacts()) self.assertEqual(0, self.count_xacts())
self.assertEqual(0, self.count_test_records()) self.assertEqual(0, self.count_test_records())
@skip_if_tpc_disabled
def test_tpc_rollback_recovered(self): def test_tpc_rollback_recovered(self):
cnn = self.connect() cnn = self.connect()
xid = cnn.xid(1, "gtrid", "bqual") xid = cnn.xid(1, "gtrid", "bqual")
@ -464,7 +478,6 @@ class ConnectionTwoPhaseTests(unittest.TestCase):
xns = cnn.tpc_recover() xns = cnn.tpc_recover()
self.assertEqual(psycopg2.extensions.STATUS_BEGIN, cnn.status) self.assertEqual(psycopg2.extensions.STATUS_BEGIN, cnn.status)
@skip_if_tpc_disabled
def test_recovered_xids(self): def test_recovered_xids(self):
# insert a few test xns # insert a few test xns
cnn = self.connect() cnn = self.connect()
@ -495,7 +508,6 @@ class ConnectionTwoPhaseTests(unittest.TestCase):
self.assertEqual(xid.owner, owner) self.assertEqual(xid.owner, owner)
self.assertEqual(xid.database, database) self.assertEqual(xid.database, database)
@skip_if_tpc_disabled
def test_xid_encoding(self): def test_xid_encoding(self):
cnn = self.connect() cnn = self.connect()
xid = cnn.xid(42, "gtrid", "bqual") xid = cnn.xid(42, "gtrid", "bqual")
@ -508,7 +520,6 @@ class ConnectionTwoPhaseTests(unittest.TestCase):
(tests.dbname,)) (tests.dbname,))
self.assertEqual('42_Z3RyaWQ=_YnF1YWw=', cur.fetchone()[0]) self.assertEqual('42_Z3RyaWQ=_YnF1YWw=', cur.fetchone()[0])
@skip_if_tpc_disabled
def test_xid_roundtrip(self): def test_xid_roundtrip(self):
for fid, gtrid, bqual in [ for fid, gtrid, bqual in [
(0, "", ""), (0, "", ""),
@ -532,7 +543,6 @@ class ConnectionTwoPhaseTests(unittest.TestCase):
cnn.tpc_rollback(xid) cnn.tpc_rollback(xid)
@skip_if_tpc_disabled
def test_unparsed_roundtrip(self): def test_unparsed_roundtrip(self):
for tid in [ for tid in [
'', '',
@ -585,7 +595,6 @@ class ConnectionTwoPhaseTests(unittest.TestCase):
x2 = Xid.from_string('99_xxx_yyy') x2 = Xid.from_string('99_xxx_yyy')
self.assertEqual(str(x2), '99_xxx_yyy') self.assertEqual(str(x2), '99_xxx_yyy')
@skip_if_tpc_disabled
def test_xid_unicode(self): def test_xid_unicode(self):
cnn = self.connect() cnn = self.connect()
x1 = cnn.xid(10, u'uni', u'code') x1 = cnn.xid(10, u'uni', u'code')
@ -598,7 +607,6 @@ class ConnectionTwoPhaseTests(unittest.TestCase):
self.assertEqual('uni', xid.gtrid) self.assertEqual('uni', xid.gtrid)
self.assertEqual('code', xid.bqual) self.assertEqual('code', xid.bqual)
@skip_if_tpc_disabled
def test_xid_unicode_unparsed(self): def test_xid_unicode_unparsed(self):
# We don't expect people shooting snowmen as transaction ids, # We don't expect people shooting snowmen as transaction ids,
# so if something explodes in an encode error I don't mind. # so if something explodes in an encode error I don't mind.
@ -615,6 +623,8 @@ class ConnectionTwoPhaseTests(unittest.TestCase):
self.assertEqual('transaction-id', xid.gtrid) self.assertEqual('transaction-id', xid.gtrid)
self.assertEqual(None, xid.bqual) self.assertEqual(None, xid.bqual)
decorate_all_tests(ConnectionTwoPhaseTests, skip_if_tpc_disabled)
def test_suite(): def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__) return unittest.TestLoader().loadTestsFromName(__name__)

View File

@ -1,7 +1,7 @@
#!/usr/bin/env python #!/usr/bin/env python
import os import os
import string import string
import unittest from testutils import unittest, decorate_all_tests
from cStringIO import StringIO from cStringIO import StringIO
from itertools import cycle, izip from itertools import cycle, izip
@ -9,6 +9,15 @@ import psycopg2
import psycopg2.extensions import psycopg2.extensions
import tests import tests
def skip_if_green(f):
def skip_if_green_(self):
if tests.green:
return self.skipTest("copy in async mode currently not supported")
else:
return f(self)
return skip_if_green_
class MinimalRead(object): class MinimalRead(object):
"""A file wrapper exposing the minimal interface to copy from.""" """A file wrapper exposing the minimal interface to copy from."""
@ -29,6 +38,7 @@ class MinimalWrite(object):
def write(self, data): def write(self, data):
return self.f.write(data) return self.f.write(data)
class CopyTests(unittest.TestCase): class CopyTests(unittest.TestCase):
def setUp(self): def setUp(self):
@ -124,6 +134,9 @@ class CopyTests(unittest.TestCase):
self.assertEqual(ntests, len(string.letters)) self.assertEqual(ntests, len(string.letters))
decorate_all_tests(CopyTests, skip_if_green)
def test_suite(): def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__) return unittest.TestLoader().loadTestsFromName(__name__)

View File

@ -2,16 +2,33 @@
import os import os
import shutil import shutil
import tempfile import tempfile
import unittest from testutils import unittest, decorate_all_tests
import warnings
import psycopg2 import psycopg2
import psycopg2.extensions import psycopg2.extensions
import tests import tests
def skip_if_no_lo(f):
def skip_if_no_lo_(self):
if self.conn.server_version < 80100:
return self.skipTest("large objects only supported from PG 8.1")
else:
return f(self)
class LargeObjectTests(unittest.TestCase): return skip_if_no_lo_
def skip_if_green(f):
def skip_if_green_(self):
if tests.green:
return self.skipTest("libpq doesn't support LO in async mode")
else:
return f(self)
return skip_if_green_
class LargeObjectMixin(object):
# doesn't derive from TestCase to avoid repeating tests twice.
def setUp(self): def setUp(self):
self.conn = psycopg2.connect(tests.dsn) self.conn = psycopg2.connect(tests.dsn)
self.lo_oid = None self.lo_oid = None
@ -30,6 +47,8 @@ class LargeObjectTests(unittest.TestCase):
lo.unlink() lo.unlink()
self.conn.close() self.conn.close()
class LargeObjectTests(LargeObjectMixin, unittest.TestCase):
def test_create(self): def test_create(self):
lo = self.conn.lobject() lo = self.conn.lobject()
self.assertNotEqual(lo, None) self.assertNotEqual(lo, None)
@ -261,30 +280,25 @@ class LargeObjectTests(unittest.TestCase):
self.assertTrue(os.path.exists(filename)) self.assertTrue(os.path.exists(filename))
self.assertEqual(open(filename, "rb").read(), "some data") self.assertEqual(open(filename, "rb").read(), "some data")
decorate_all_tests(LargeObjectTests, skip_if_no_lo)
decorate_all_tests(LargeObjectTests, skip_if_green)
class LargeObjectTruncateTests(LargeObjectTests):
skip = None def skip_if_no_truncate(f):
def skip_if_no_truncate_(self):
if self.conn.server_version < 80300:
return self.skipTest(
"the server doesn't support large object truncate")
def setUp(self): if not hasattr(psycopg2.extensions.lobject, 'truncate'):
LargeObjectTests.setUp(self) return self.skipTest(
"psycopg2 has been built against a libpq "
"without large object truncate support.")
if self.skip is None: return f(self)
self.skip = False
if self.conn.server_version < 80300:
warnings.warn("Large object truncate tests skipped, "
"the server does not support them")
self.skip = True
if not hasattr(psycopg2.extensions.lobject, 'truncate'):
warnings.warn("Large object truncate tests skipped, "
"psycopg2 has been built against an old library")
self.skip = True
class LargeObjectTruncateTests(LargeObjectMixin, unittest.TestCase):
def test_truncate(self): def test_truncate(self):
if self.skip:
return
lo = self.conn.lobject() lo = self.conn.lobject()
lo.write("some data") lo.write("some data")
lo.close() lo.close()
@ -308,23 +322,22 @@ class LargeObjectTruncateTests(LargeObjectTests):
self.assertEqual(lo.read(), "") self.assertEqual(lo.read(), "")
def test_truncate_after_close(self): def test_truncate_after_close(self):
if self.skip:
return
lo = self.conn.lobject() lo = self.conn.lobject()
lo.close() lo.close()
self.assertRaises(psycopg2.InterfaceError, lo.truncate) self.assertRaises(psycopg2.InterfaceError, lo.truncate)
def test_truncate_after_commit(self): def test_truncate_after_commit(self):
if self.skip:
return
lo = self.conn.lobject() lo = self.conn.lobject()
self.lo_oid = lo.oid self.lo_oid = lo.oid
self.conn.commit() self.conn.commit()
self.assertRaises(psycopg2.ProgrammingError, lo.truncate) self.assertRaises(psycopg2.ProgrammingError, lo.truncate)
decorate_all_tests(LargeObjectTruncateTests, skip_if_no_lo)
decorate_all_tests(LargeObjectTruncateTests, skip_if_green)
decorate_all_tests(LargeObjectTruncateTests, skip_if_no_truncate)
def test_suite(): def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__) return unittest.TestLoader().loadTestsFromName(__name__)

View File

@ -1,6 +1,5 @@
#!/usr/bin/env python #!/usr/bin/env python
import unittest from testutils import unittest
import warnings
import psycopg2 import psycopg2
from psycopg2 import extensions from psycopg2 import extensions
@ -127,9 +126,8 @@ conn.close()
def test_notify_payload(self): def test_notify_payload(self):
if self.conn.server_version < 90000: if self.conn.server_version < 90000:
warnings.warn("server version %s doesn't support notify payload: skipping test" return self.skipTest("server version %s doesn't support notify payload"
% self.conn.server_version) % self.conn.server_version)
return
self.autocommit(self.conn) self.autocommit(self.conn)
self.listen('foo') self.listen('foo')
pid = int(self.notify('foo', payload="Hello, world!").communicate()[0]) pid = int(self.notify('foo', payload="Hello, world!").communicate()[0])

View File

@ -2,7 +2,7 @@
import dbapi20 import dbapi20
import dbapi20_tpc import dbapi20_tpc
from test_connection import skip_if_tpc_disabled from test_connection import skip_if_tpc_disabled
import unittest from testutils import unittest, decorate_all_tests
import psycopg2 import psycopg2
import tests import tests
@ -23,19 +23,14 @@ class Psycopg2Tests(dbapi20.DatabaseAPI20Test):
pass pass
class Psycopg2TPCTests(dbapi20_tpc.TwoPhaseCommitTests): class Psycopg2TPCTests(dbapi20_tpc.TwoPhaseCommitTests, unittest.TestCase):
driver = psycopg2 driver = psycopg2
def connect(self): def connect(self):
return psycopg2.connect(dsn=tests.dsn) return psycopg2.connect(dsn=tests.dsn)
@skip_if_tpc_disabled decorate_all_tests(Psycopg2TPCTests, skip_if_tpc_disabled)
def test_tpc_commit_with_prepare(self):
super(Psycopg2TPCTests, self).test_tpc_commit_with_prepare()
@skip_if_tpc_disabled
def test_tpc_rollback_with_prepare(self):
super(Psycopg2TPCTests, self).test_tpc_rollback_with_prepare()
def test_suite(): def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__) return unittest.TestLoader().loadTestsFromName(__name__)

View File

@ -1,6 +1,5 @@
#!/usr/bin/env python #!/usr/bin/env python
import unittest from testutils import unittest
import warnings
import psycopg2 import psycopg2
import psycopg2.extensions import psycopg2.extensions
@ -61,9 +60,9 @@ class QuotingTestCase(unittest.TestCase):
curs.execute("SHOW server_encoding") curs.execute("SHOW server_encoding")
server_encoding = curs.fetchone()[0] server_encoding = curs.fetchone()[0]
if server_encoding != "UTF8": if server_encoding != "UTF8":
warnings.warn("Unicode test skipped since server encoding is %s" return self.skipTest(
% server_encoding) "Unicode test skipped since server encoding is %s"
return % server_encoding)
data = u"""some data with \t chars data = u"""some data with \t chars
to escape into, 'quotes', \u20ac euro sign and \\ a backslash too. to escape into, 'quotes', \u20ac euro sign and \\ a backslash too.

View File

@ -1,6 +1,6 @@
#!/usr/bin/env python #!/usr/bin/env python
import threading import threading
import unittest from testutils import unittest
import psycopg2 import psycopg2
from psycopg2.extensions import ( from psycopg2.extensions import (
@ -215,8 +215,11 @@ class QueryCancellationTests(unittest.TestCase):
curs = self.conn.cursor() curs = self.conn.cursor()
# Set a low statement timeout, then sleep for a longer period. # Set a low statement timeout, then sleep for a longer period.
curs.execute('SET statement_timeout TO 10') curs.execute('SET statement_timeout TO 10')
self.assertRaises(psycopg2.extensions.QueryCanceledError, try:
curs.execute, 'SELECT pg_sleep(50)') self.assertRaises(psycopg2.extensions.QueryCanceledError,
curs.execute, 'SELECT pg_sleep(50)')
except psycopg2.ProgrammingError:
return self.skipTest("pg_sleep not available")
def test_suite(): def test_suite():

46
tests/testutils.py Normal file
View File

@ -0,0 +1,46 @@
# Utility module for psycopg2 testing.
#
# Copyright (C) 2010 Daniele Varrazzo <daniele.varrazzo@gmail.com>
# Use unittest2 if available. Otherwise mock a skip facility with warnings.
try:
import unittest2
unittest = unittest2
except ImportError:
import unittest
unittest2 = None
if hasattr(unittest, 'skipIf'):
from unittest2 import skip, skipIf
else:
import warnings
def skipIf(cond, msg):
def skipIf_(f):
def skipIf__(self):
if cond:
warnings.warn(msg)
return
else:
return f(self)
return skipIf__
return skipIf_
def skip(msg):
return skipIf(True, msg)
def skipTest(self, msg):
warnings.warn(msg)
return
unittest.TestCase.skipTest = skipTest
def decorate_all_tests(cls, decorator):
"""Apply *decorator* to all the tests defined in the TestCase *cls*."""
for n in dir(cls):
if n.startswith('test'):
setattr(cls, n, decorator(getattr(cls, n)))

View File

@ -27,7 +27,7 @@ try:
except: except:
pass pass
import sys import sys
import unittest from testutils import unittest
import psycopg2 import psycopg2
import tests import tests
@ -78,14 +78,14 @@ class TypesBasicTests(unittest.TestCase):
s = self.execute("SELECT %s AS foo", (decimal.Decimal("-infinity"),)) s = self.execute("SELECT %s AS foo", (decimal.Decimal("-infinity"),))
self.failUnless(str(s) == "NaN", "wrong decimal quoting: " + str(s)) self.failUnless(str(s) == "NaN", "wrong decimal quoting: " + str(s))
self.failUnless(type(s) == decimal.Decimal, "wrong decimal conversion: " + repr(s)) self.failUnless(type(s) == decimal.Decimal, "wrong decimal conversion: " + repr(s))
else:
return self.skipTest("decimal not available")
def testFloat(self): def testFloat(self):
try: try:
float("nan") float("nan")
except ValueError: except ValueError:
import warnings return self.skipTest("nan not available on this platform")
warnings.warn("nan not available on this platform")
return
s = self.execute("SELECT %s AS foo", (float("nan"),)) s = self.execute("SELECT %s AS foo", (float("nan"),))
self.failUnless(str(s) == "nan", "wrong float quoting: " + str(s)) self.failUnless(str(s) == "nan", "wrong float quoting: " + str(s))

View File

@ -20,14 +20,35 @@ except:
pass pass
import re import re
import sys import sys
import unittest from testutils import unittest
import warnings
import psycopg2 import psycopg2
import psycopg2.extras import psycopg2.extras
import tests import tests
def skip_if_no_uuid(f):
def skip_if_no_uuid_(self):
try:
cur = self.conn.cursor()
cur.execute("select typname from pg_type where typname = 'uuid'")
has = cur.fetchone()
finally:
self.conn.rollback()
if has:
return f(self)
else:
return self.skipTest("uuid type not available")
return skip_if_no_uuid_
def filter_scs(conn, s):
if conn.get_parameter_status("standard_conforming_strings") == 'off':
return s
else:
return s.replace("E'", "'")
class TypesExtrasTests(unittest.TestCase): class TypesExtrasTests(unittest.TestCase):
"""Test that all type conversions are working.""" """Test that all type conversions are working."""
@ -39,12 +60,10 @@ class TypesExtrasTests(unittest.TestCase):
curs.execute(*args) curs.execute(*args)
return curs.fetchone()[0] return curs.fetchone()[0]
@skip_if_no_uuid
def testUUID(self): def testUUID(self):
try: import uuid
import uuid psycopg2.extras.register_uuid()
psycopg2.extras.register_uuid()
except:
return
u = uuid.UUID('9c6d5a77-7256-457e-9461-347b4358e350') u = uuid.UUID('9c6d5a77-7256-457e-9461-347b4358e350')
s = self.execute("SELECT %s AS foo", (u,)) s = self.execute("SELECT %s AS foo", (u,))
self.failUnless(u == s) self.failUnless(u == s)
@ -52,12 +71,10 @@ class TypesExtrasTests(unittest.TestCase):
s = self.execute("SELECT NULL::uuid AS foo") s = self.execute("SELECT NULL::uuid AS foo")
self.failUnless(s is None) self.failUnless(s is None)
@skip_if_no_uuid
def testUUIDARRAY(self): def testUUIDARRAY(self):
try: import uuid
import uuid psycopg2.extras.register_uuid()
psycopg2.extras.register_uuid()
except:
return
u = [uuid.UUID('9c6d5a77-7256-457e-9461-347b4358e350'), uuid.UUID('9c6d5a77-7256-457e-9461-347b4358e352')] u = [uuid.UUID('9c6d5a77-7256-457e-9461-347b4358e350'), uuid.UUID('9c6d5a77-7256-457e-9461-347b4358e352')]
s = self.execute("SELECT %s AS foo", (u,)) s = self.execute("SELECT %s AS foo", (u,))
self.failUnless(u == s) self.failUnless(u == s)
@ -86,13 +103,17 @@ class TypesExtrasTests(unittest.TestCase):
i = Inet("192.168.1.0/24") i = Inet("192.168.1.0/24")
a = psycopg2.extensions.adapt(i) a = psycopg2.extensions.adapt(i)
a.prepare(self.conn) a.prepare(self.conn)
self.assertEqual("E'192.168.1.0/24'::inet", a.getquoted()) self.assertEqual(
filter_scs(self.conn, "E'192.168.1.0/24'::inet"),
a.getquoted())
# adapts ok with unicode too # adapts ok with unicode too
i = Inet(u"192.168.1.0/24") i = Inet(u"192.168.1.0/24")
a = psycopg2.extensions.adapt(i) a = psycopg2.extensions.adapt(i)
a.prepare(self.conn) a.prepare(self.conn)
self.assertEqual("E'192.168.1.0/24'::inet", a.getquoted()) self.assertEqual(
filter_scs(self.conn, "E'192.168.1.0/24'::inet"),
a.getquoted())
def test_adapt_fail(self): def test_adapt_fail(self):
class Foo(object): pass class Foo(object): pass
@ -109,8 +130,7 @@ def skip_if_no_hstore(f):
from psycopg2.extras import HstoreAdapter from psycopg2.extras import HstoreAdapter
oids = HstoreAdapter.get_oids(self.conn) oids = HstoreAdapter.get_oids(self.conn)
if oids is None: if oids is None:
warnings.warn("hstore not available in test database: skipping test") return self.skipTest("hstore not available in test database")
return
return f(self) return f(self)
return skip_if_no_hstore_ return skip_if_no_hstore_
@ -121,8 +141,7 @@ class HstoreTestCase(unittest.TestCase):
def test_adapt_8(self): def test_adapt_8(self):
if self.conn.server_version >= 90000: if self.conn.server_version >= 90000:
warnings.warn("skipping dict adaptation with PG pre-9 syntax") return self.skipTest("skipping dict adaptation with PG pre-9 syntax")
return
from psycopg2.extras import HstoreAdapter from psycopg2.extras import HstoreAdapter
@ -136,16 +155,15 @@ class HstoreTestCase(unittest.TestCase):
ii = q[1:-1].split("||") ii = q[1:-1].split("||")
ii.sort() ii.sort()
self.assertEqual(ii[0], "(E'a' => E'1')") self.assertEqual(ii[0], filter_scs(self.conn, "(E'a' => E'1')"))
self.assertEqual(ii[1], "(E'b' => E'''')") self.assertEqual(ii[1], filter_scs(self.conn, "(E'b' => E'''')"))
self.assertEqual(ii[2], "(E'c' => NULL)") self.assertEqual(ii[2], filter_scs(self.conn, "(E'c' => NULL)"))
encc = u'\xe0'.encode(psycopg2.extensions.encodings[self.conn.encoding]) encc = u'\xe0'.encode(psycopg2.extensions.encodings[self.conn.encoding])
self.assertEqual(ii[3], "(E'd' => E'%s')" % encc) self.assertEqual(ii[3], filter_scs(self.conn, "(E'd' => E'%s')" % encc))
def test_adapt_9(self): def test_adapt_9(self):
if self.conn.server_version < 90000: if self.conn.server_version < 90000:
warnings.warn("skipping dict adaptation with PG 9 syntax") return self.skipTest("skipping dict adaptation with PG 9 syntax")
return
from psycopg2.extras import HstoreAdapter from psycopg2.extras import HstoreAdapter