This commit is contained in:
Jon Dufresne 2020-07-13 23:52:28 +00:00 committed by GitHub
commit 5f93150707
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 3315 additions and 3052 deletions

View File

@ -180,23 +180,22 @@ def _get_json_oids(conn_or_curs, name='json'):
from psycopg2.extensions import STATUS_IN_TRANSACTION from psycopg2.extensions import STATUS_IN_TRANSACTION
from psycopg2.extras import _solve_conn_curs from psycopg2.extras import _solve_conn_curs
conn, curs = _solve_conn_curs(conn_or_curs) with _solve_conn_curs(conn_or_curs) as (conn, curs):
# Store the transaction status of the connection to revert it after use
conn_status = conn.status
# Store the transaction status of the connection to revert it after use # column typarray not available before PG 8.3
conn_status = conn.status typarray = conn.info.server_version >= 80300 and "typarray" or "NULL"
# column typarray not available before PG 8.3 # get the oid for the hstore
typarray = conn.info.server_version >= 80300 and "typarray" or "NULL" curs.execute(
"SELECT t.oid, %s FROM pg_type t WHERE t.typname = %%s;"
% typarray, (name,))
r = curs.fetchone()
# get the oid for the hstore # revert the status of the connection as before the command
curs.execute( if conn_status != STATUS_IN_TRANSACTION and not conn.autocommit:
"SELECT t.oid, %s FROM pg_type t WHERE t.typname = %%s;" conn.rollback()
% typarray, (name,))
r = curs.fetchone()
# revert the status of the connection as before the command
if conn_status != STATUS_IN_TRANSACTION and not conn.autocommit:
conn.rollback()
if not r: if not r:
raise conn.ProgrammingError("%s data type not found" % name) raise conn.ProgrammingError("%s data type not found" % name)

View File

@ -351,25 +351,24 @@ class RangeCaster(object):
""" """
from psycopg2.extensions import STATUS_IN_TRANSACTION from psycopg2.extensions import STATUS_IN_TRANSACTION
from psycopg2.extras import _solve_conn_curs from psycopg2.extras import _solve_conn_curs
conn, curs = _solve_conn_curs(conn_or_curs) with _solve_conn_curs(conn_or_curs) as (conn, curs):
if conn.info.server_version < 90200:
raise ProgrammingError("range types not available in version %s"
% conn.info.server_version)
if conn.info.server_version < 90200: # Store the transaction status of the connection to revert it after use
raise ProgrammingError("range types not available in version %s" conn_status = conn.status
% conn.info.server_version)
# Store the transaction status of the connection to revert it after use # Use the correct schema
conn_status = conn.status if '.' in name:
schema, tname = name.split('.', 1)
else:
tname = name
schema = 'public'
# Use the correct schema # get the type oid and attributes
if '.' in name: try:
schema, tname = name.split('.', 1) curs.execute("""\
else:
tname = name
schema = 'public'
# get the type oid and attributes
try:
curs.execute("""\
select rngtypid, rngsubtype, select rngtypid, rngsubtype,
(select typarray from pg_type where oid = rngtypid) (select typarray from pg_type where oid = rngtypid)
from pg_range r from pg_range r
@ -378,17 +377,17 @@ join pg_namespace ns on ns.oid = typnamespace
where typname = %s and ns.nspname = %s; where typname = %s and ns.nspname = %s;
""", (tname, schema)) """, (tname, schema))
except ProgrammingError: except ProgrammingError:
if not conn.autocommit: if not conn.autocommit:
conn.rollback() conn.rollback()
raise raise
else: else:
rec = curs.fetchone() rec = curs.fetchone()
# revert the status of the connection as before the command # revert the status of the connection as before the command
if (conn_status != STATUS_IN_TRANSACTION if (conn_status != STATUS_IN_TRANSACTION
and not conn.autocommit): and not conn.autocommit):
conn.rollback() conn.rollback()
if not rec: if not rec:
raise ProgrammingError( raise ProgrammingError(

View File

@ -26,6 +26,7 @@ and classes until a better place in the distribution is found.
# FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public # FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
# License for more details. # License for more details.
import contextlib
import os as _os import os as _os
import time as _time import time as _time
import re as _re import re as _re
@ -787,6 +788,7 @@ def wait_select(conn):
continue continue
@contextlib.contextmanager
def _solve_conn_curs(conn_or_curs): def _solve_conn_curs(conn_or_curs):
"""Return the connection and a DBAPI cursor from a connection or cursor.""" """Return the connection and a DBAPI cursor from a connection or cursor."""
if conn_or_curs is None: if conn_or_curs is None:
@ -799,7 +801,8 @@ def _solve_conn_curs(conn_or_curs):
conn = conn_or_curs conn = conn_or_curs
curs = conn.cursor(cursor_factory=_cursor) curs = conn.cursor(cursor_factory=_cursor)
return conn, curs with curs:
yield conn, curs
class HstoreAdapter(object): class HstoreAdapter(object):
@ -910,31 +913,30 @@ class HstoreAdapter(object):
def get_oids(self, conn_or_curs): def get_oids(self, conn_or_curs):
"""Return the lists of OID of the hstore and hstore[] types. """Return the lists of OID of the hstore and hstore[] types.
""" """
conn, curs = _solve_conn_curs(conn_or_curs) with _solve_conn_curs(conn_or_curs) as (conn, curs):
# Store the transaction status of the connection to revert it after use
conn_status = conn.status
# Store the transaction status of the connection to revert it after use # column typarray not available before PG 8.3
conn_status = conn.status typarray = conn.info.server_version >= 80300 and "typarray" or "NULL"
# column typarray not available before PG 8.3 rv0, rv1 = [], []
typarray = conn.info.server_version >= 80300 and "typarray" or "NULL"
rv0, rv1 = [], [] # get the oid for the hstore
curs.execute("""\
# get the oid for the hstore
curs.execute("""\
SELECT t.oid, %s SELECT t.oid, %s
FROM pg_type t JOIN pg_namespace ns FROM pg_type t JOIN pg_namespace ns
ON typnamespace = ns.oid ON typnamespace = ns.oid
WHERE typname = 'hstore'; WHERE typname = 'hstore';
""" % typarray) """ % typarray)
for oids in curs: for oids in curs:
rv0.append(oids[0]) rv0.append(oids[0])
rv1.append(oids[1]) rv1.append(oids[1])
# revert the status of the connection as before the command # revert the status of the connection as before the command
if (conn_status != _ext.STATUS_IN_TRANSACTION if (conn_status != _ext.STATUS_IN_TRANSACTION
and not conn.autocommit): and not conn.autocommit):
conn.rollback() conn.rollback()
return tuple(rv0), tuple(rv1) return tuple(rv0), tuple(rv1)
@ -1089,23 +1091,22 @@ class CompositeCaster(object):
Raise `ProgrammingError` if the type is not found. Raise `ProgrammingError` if the type is not found.
""" """
conn, curs = _solve_conn_curs(conn_or_curs) with _solve_conn_curs(conn_or_curs) as (conn, curs):
# Store the transaction status of the connection to revert it after use
conn_status = conn.status
# Store the transaction status of the connection to revert it after use # Use the correct schema
conn_status = conn.status if '.' in name:
schema, tname = name.split('.', 1)
else:
tname = name
schema = 'public'
# Use the correct schema # column typarray not available before PG 8.3
if '.' in name: typarray = conn.info.server_version >= 80300 and "typarray" or "NULL"
schema, tname = name.split('.', 1)
else:
tname = name
schema = 'public'
# column typarray not available before PG 8.3 # get the type oid and attributes
typarray = conn.info.server_version >= 80300 and "typarray" or "NULL" curs.execute("""\
# get the type oid and attributes
curs.execute("""\
SELECT t.oid, %s, attname, atttypid SELECT t.oid, %s, attname, atttypid
FROM pg_type t FROM pg_type t
JOIN pg_namespace ns ON typnamespace = ns.oid JOIN pg_namespace ns ON typnamespace = ns.oid
@ -1115,12 +1116,12 @@ WHERE typname = %%s AND nspname = %%s
ORDER BY attnum; ORDER BY attnum;
""" % typarray, (tname, schema)) """ % typarray, (tname, schema))
recs = curs.fetchall() recs = curs.fetchall()
# revert the status of the connection as before the command # revert the status of the connection as before the command
if (conn_status != _ext.STATUS_IN_TRANSACTION if (conn_status != _ext.STATUS_IN_TRANSACTION
and not conn.autocommit): and not conn.autocommit):
conn.rollback() conn.rollback()
if not recs: if not recs:
raise psycopg2.ProgrammingError( raise psycopg2.ProgrammingError(

View File

@ -59,6 +59,7 @@ static PyObject *
psyco_conn_cursor(connectionObject *self, PyObject *args, PyObject *kwargs) psyco_conn_cursor(connectionObject *self, PyObject *args, PyObject *kwargs)
{ {
PyObject *obj = NULL; PyObject *obj = NULL;
cursorObject *curs = NULL;
PyObject *rv = NULL; PyObject *rv = NULL;
PyObject *name = Py_None; PyObject *name = Py_None;
PyObject *factory = Py_None; PyObject *factory = Py_None;
@ -113,23 +114,46 @@ psyco_conn_cursor(connectionObject *self, PyObject *args, PyObject *kwargs)
goto exit; goto exit;
} }
if (0 > curs_withhold_set((cursorObject *)obj, withhold)) { /* pass ownership from obj to curs */
goto exit; curs = (cursorObject *)obj;
obj = NULL;
if (0 > curs_withhold_set(curs, withhold)) {
goto error;
} }
if (0 > curs_scrollable_set((cursorObject *)obj, scrollable)) { if (0 > curs_scrollable_set(curs, scrollable)) {
goto exit; goto error;
} }
Dprintf("psyco_conn_cursor: new cursor at %p: refcnt = " Dprintf("psyco_conn_cursor: new cursor at %p: refcnt = "
FORMAT_CODE_PY_SSIZE_T, FORMAT_CODE_PY_SSIZE_T,
obj, Py_REFCNT(obj) curs, Py_REFCNT(curs)
); );
rv = obj; /* pass ownership from curs to rv */
obj = NULL; rv = (PyObject *)curs;
curs = NULL;
goto exit;
error:
{
PyObject *error_type, *error_value, *error_traceback;
PyObject *close;
PyErr_Fetch(&error_type, &error_value, &error_traceback);
if (curs) {
close = PyObject_CallMethod((PyObject *)curs, "close", NULL);
if (close)
Py_DECREF(close);
else
PyErr_WriteUnraisable((PyObject *)curs);
}
PyErr_Restore(error_type, error_value, error_traceback);
}
exit: exit:
Py_XDECREF(obj); Py_XDECREF(obj);
Py_XDECREF(curs);
return rv; return rv;
} }
@ -1366,6 +1390,9 @@ connection_dealloc(PyObject* obj)
{ {
connectionObject *self = (connectionObject *)obj; connectionObject *self = (connectionObject *)obj;
if (PyObject_CallFinalizerFromDealloc(obj) < 0)
return;
/* Make sure to untrack the connection before calling conn_close, which may /* Make sure to untrack the connection before calling conn_close, which may
* allow a different thread to try and dealloc the connection again, * allow a different thread to try and dealloc the connection again,
* resulting in a double-free segfault (ticket #166). */ * resulting in a double-free segfault (ticket #166). */
@ -1405,6 +1432,31 @@ connection_dealloc(PyObject* obj)
Py_TYPE(obj)->tp_free(obj); Py_TYPE(obj)->tp_free(obj);
} }
#if PY_3
static void
connection_finalize(PyObject *obj)
{
connectionObject *self = (connectionObject *)obj;
#ifdef CONN_CHECK_PID
if (self->procpid == getpid())
#endif
{
if (!self->closed) {
PyObject *error_type, *error_value, *error_traceback;
/* Save the current exception, if any. */
PyErr_Fetch(&error_type, &error_value, &error_traceback);
if (PyErr_WarnFormat(PyExc_ResourceWarning, 1, "unclosed connection %R", obj))
PyErr_WriteUnraisable(obj);
/* Restore the saved exception. */
PyErr_Restore(error_type, error_value, error_traceback);
}
}
}
#endif
static int static int
connection_init(PyObject *obj, PyObject *args, PyObject *kwds) connection_init(PyObject *obj, PyObject *args, PyObject *kwds)
{ {
@ -1479,7 +1531,7 @@ PyTypeObject connectionType = {
0, /*tp_setattro*/ 0, /*tp_setattro*/
0, /*tp_as_buffer*/ 0, /*tp_as_buffer*/
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_GC |
Py_TPFLAGS_HAVE_WEAKREFS, Py_TPFLAGS_HAVE_WEAKREFS | Py_TPFLAGS_HAVE_FINALIZE,
/*tp_flags*/ /*tp_flags*/
connectionType_doc, /*tp_doc*/ connectionType_doc, /*tp_doc*/
(traverseproc)connection_traverse, /*tp_traverse*/ (traverseproc)connection_traverse, /*tp_traverse*/
@ -1499,4 +1551,16 @@ PyTypeObject connectionType = {
connection_init, /*tp_init*/ connection_init, /*tp_init*/
0, /*tp_alloc*/ 0, /*tp_alloc*/
connection_new, /*tp_new*/ connection_new, /*tp_new*/
0, /* tp_free */
0, /* tp_is_gc */
0, /* tp_bases */
0, /* tp_mro */
0, /* tp_cache */
0, /* tp_subclasses */
0, /* tp_weaklist */
#if PY_3
0, /* tp_del */
0, /* tp_version_tag */
connection_finalize, /* tp_finalize */
#endif
}; };

View File

@ -39,6 +39,7 @@
#include <stdlib.h> #include <stdlib.h>
#define cursor_closed(self) ((self)->closed || ((self)->conn && (self)->conn->closed))
/** DBAPI methods **/ /** DBAPI methods **/
@ -1637,7 +1638,7 @@ exit:
static PyObject * static PyObject *
curs_closed_get(cursorObject *self, void *closure) curs_closed_get(cursorObject *self, void *closure)
{ {
return PyBool_FromLong(self->closed || (self->conn && self->conn->closed)); return PyBool_FromLong(cursor_closed(self));
} }
/* extension: withhold - get or set "WITH HOLD" for named cursors */ /* extension: withhold - get or set "WITH HOLD" for named cursors */
@ -1945,6 +1946,9 @@ cursor_dealloc(PyObject* obj)
{ {
cursorObject *self = (cursorObject *)obj; cursorObject *self = (cursorObject *)obj;
if (PyObject_CallFinalizerFromDealloc(obj) < 0)
return;
PyObject_GC_UnTrack(self); PyObject_GC_UnTrack(self);
if (self->weakreflist) { if (self->weakreflist) {
@ -1965,6 +1969,28 @@ cursor_dealloc(PyObject* obj)
Py_TYPE(obj)->tp_free(obj); Py_TYPE(obj)->tp_free(obj);
} }
#if PY_3
static void
cursor_finalize(PyObject *obj)
{
cursorObject *self = (cursorObject *)obj;
if (!cursor_closed(self)) {
PyObject *error_type, *error_value, *error_traceback;
/* Save the current exception, if any. */
PyErr_Fetch(&error_type, &error_value, &error_traceback);
if (PyErr_WarnFormat(PyExc_ResourceWarning, 1,
"unclosed cursor %R for connection %R",
obj, (PyObject *)self->conn))
PyErr_WriteUnraisable(obj);
/* Restore the saved exception. */
PyErr_Restore(error_type, error_value, error_traceback);
}
}
#endif
static int static int
cursor_init(PyObject *obj, PyObject *args, PyObject *kwargs) cursor_init(PyObject *obj, PyObject *args, PyObject *kwargs)
{ {
@ -2056,7 +2082,7 @@ PyTypeObject cursorType = {
0, /*tp_setattro*/ 0, /*tp_setattro*/
0, /*tp_as_buffer*/ 0, /*tp_as_buffer*/
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_ITER | Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE | Py_TPFLAGS_HAVE_ITER |
Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_HAVE_WEAKREFS , Py_TPFLAGS_HAVE_GC | Py_TPFLAGS_HAVE_WEAKREFS | Py_TPFLAGS_HAVE_FINALIZE,
/*tp_flags*/ /*tp_flags*/
cursorType_doc, /*tp_doc*/ cursorType_doc, /*tp_doc*/
(traverseproc)cursor_traverse, /*tp_traverse*/ (traverseproc)cursor_traverse, /*tp_traverse*/
@ -2076,4 +2102,16 @@ PyTypeObject cursorType = {
cursor_init, /*tp_init*/ cursor_init, /*tp_init*/
0, /*tp_alloc*/ 0, /*tp_alloc*/
cursor_new, /*tp_new*/ cursor_new, /*tp_new*/
0, /* tp_free */
0, /* tp_is_gc */
0, /* tp_bases */
0, /* tp_mro */
0, /* tp_cache */
0, /* tp_subclasses */
0, /* tp_weaklist */
#if PY_3
0, /* tp_del */
0, /* tp_version_tag */
cursor_finalize, /* tp_finalize */
#endif
}; };

View File

@ -86,6 +86,8 @@ typedef unsigned long Py_uhash_t;
#define Bytes_ConcatAndDel PyString_ConcatAndDel #define Bytes_ConcatAndDel PyString_ConcatAndDel
#define _Bytes_Resize _PyString_Resize #define _Bytes_Resize _PyString_Resize
#define Py_TPFLAGS_HAVE_FINALIZE 0L
#define PyDateTime_DELTA_GET_DAYS(o) (((PyDateTime_Delta*)o)->days) #define PyDateTime_DELTA_GET_DAYS(o) (((PyDateTime_Delta*)o)->days)
#define PyDateTime_DELTA_GET_SECONDS(o) (((PyDateTime_Delta*)o)->seconds) #define PyDateTime_DELTA_GET_SECONDS(o) (((PyDateTime_Delta*)o)->seconds)
#define PyDateTime_DELTA_GET_MICROSECONDS(o) (((PyDateTime_Delta*)o)->microseconds) #define PyDateTime_DELTA_GET_MICROSECONDS(o) (((PyDateTime_Delta*)o)->microseconds)
@ -97,6 +99,8 @@ typedef unsigned long Py_uhash_t;
PyLong_FromUnsignedLong((unsigned long)(x)) : \ PyLong_FromUnsignedLong((unsigned long)(x)) : \
PyInt_FromLong((x))) PyInt_FromLong((x)))
#define PyObject_CallFinalizerFromDealloc(obj) 0
#endif /* PY_2 */ #endif /* PY_2 */
#if PY_3 #if PY_3

View File

@ -56,6 +56,7 @@ from . import test_sql
from . import test_transaction from . import test_transaction
from . import test_types_basic from . import test_types_basic
from . import test_types_extras from . import test_types_extras
from . import test_warnings
from . import test_with from . import test_with
if sys.version_info[:2] < (3, 6): if sys.version_info[:2] < (3, 6):
@ -101,6 +102,7 @@ def test_suite():
suite.addTest(test_transaction.test_suite()) suite.addTest(test_transaction.test_suite())
suite.addTest(test_types_basic.test_suite()) suite.addTest(test_types_basic.test_suite())
suite.addTest(test_types_extras.test_suite()) suite.addTest(test_types_extras.test_suite())
suite.addTest(test_warnings.test_suite())
suite.addTest(test_with.test_suite()) suite.addTest(test_with.test_suite())
return suite return suite

View File

@ -232,6 +232,7 @@ class DatabaseAPI20Test(unittest.TestCase):
self.failUnless(con.InternalError is drv.InternalError) self.failUnless(con.InternalError is drv.InternalError)
self.failUnless(con.ProgrammingError is drv.ProgrammingError) self.failUnless(con.ProgrammingError is drv.ProgrammingError)
self.failUnless(con.NotSupportedError is drv.NotSupportedError) self.failUnless(con.NotSupportedError is drv.NotSupportedError)
con.close()
def test_commit(self): def test_commit(self):
@ -251,6 +252,7 @@ class DatabaseAPI20Test(unittest.TestCase):
con.rollback() con.rollback()
except self.driver.NotSupportedError: except self.driver.NotSupportedError:
pass pass
con.close()
def test_cursor(self): def test_cursor(self):
con = self._connect() con = self._connect()

View File

@ -24,19 +24,22 @@ class TwoPhaseCommitTests(unittest.TestCase):
def test_xid(self): def test_xid(self):
con = self.connect() con = self.connect()
try: try:
xid = con.xid(42, "global", "bqual") try:
except self.driver.NotSupportedError: xid = con.xid(42, "global", "bqual")
self.fail("Driver does not support transaction IDs.") except self.driver.NotSupportedError:
self.fail("Driver does not support transaction IDs.")
self.assertEquals(xid[0], 42) self.assertEquals(xid[0], 42)
self.assertEquals(xid[1], "global") self.assertEquals(xid[1], "global")
self.assertEquals(xid[2], "bqual") self.assertEquals(xid[2], "bqual")
# Try some extremes for the transaction ID: # Try some extremes for the transaction ID:
xid = con.xid(0, "", "") xid = con.xid(0, "", "")
self.assertEquals(tuple(xid), (0, "", "")) self.assertEquals(tuple(xid), (0, "", ""))
xid = con.xid(0x7fffffff, "a" * 64, "b" * 64) xid = con.xid(0x7fffffff, "a" * 64, "b" * 64)
self.assertEquals(tuple(xid), (0x7fffffff, "a" * 64, "b" * 64)) self.assertEquals(tuple(xid), (0x7fffffff, "a" * 64, "b" * 64))
finally:
con.close()
def test_tpc_begin(self): def test_tpc_begin(self):
con = self.connect() con = self.connect()

View File

@ -61,16 +61,18 @@ class AsyncTests(ConnectingTestCase):
self.wait(self.conn) self.wait(self.conn)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute(''' curs.execute('''
CREATE TEMPORARY TABLE table1 ( CREATE TEMPORARY TABLE table1 (
id int PRIMARY KEY id int PRIMARY KEY
)''') )''')
self.wait(curs) self.wait(curs)
def test_connection_setup(self): def test_connection_setup(self):
cur = self.conn.cursor() cur = self.conn.cursor()
sync_cur = self.sync_conn.cursor() sync_cur = self.sync_conn.cursor()
cur.close()
sync_cur.close()
del cur, sync_cur del cur, sync_cur
self.assert_(self.conn.async_) self.assert_(self.conn.async_)
@ -90,159 +92,156 @@ class AsyncTests(ConnectingTestCase):
self.conn.cursor, "name") self.conn.cursor, "name")
def test_async_select(self): def test_async_select(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
self.assertFalse(self.conn.isexecuting()) self.assertFalse(self.conn.isexecuting())
cur.execute("select 'a'") cur.execute("select 'a'")
self.assertTrue(self.conn.isexecuting()) self.assertTrue(self.conn.isexecuting())
self.wait(cur) self.wait(cur)
self.assertFalse(self.conn.isexecuting()) self.assertFalse(self.conn.isexecuting())
self.assertEquals(cur.fetchone()[0], "a") self.assertEquals(cur.fetchone()[0], "a")
@slow @slow
@skip_before_postgres(8, 2) @skip_before_postgres(8, 2)
def test_async_callproc(self): def test_async_callproc(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.callproc("pg_sleep", (0.1, )) cur.callproc("pg_sleep", (0.1, ))
self.assertTrue(self.conn.isexecuting()) self.assertTrue(self.conn.isexecuting())
self.wait(cur) self.wait(cur)
self.assertFalse(self.conn.isexecuting()) self.assertFalse(self.conn.isexecuting())
self.assertEquals(cur.fetchall()[0][0], '') self.assertEquals(cur.fetchall()[0][0], '')
@slow @slow
def test_async_after_async(self): def test_async_after_async(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur2 = self.conn.cursor() cur2 = self.conn.cursor()
del cur2 cur2.close()
del cur2
cur.execute("insert into table1 values (1)") cur.execute("insert into table1 values (1)")
# an async execute after an async one raises an exception # an async execute after an async one raises an exception
self.assertRaises(psycopg2.ProgrammingError, self.assertRaises(psycopg2.ProgrammingError,
cur.execute, "select * from table1") cur.execute, "select * from table1")
# same for callproc # same for callproc
self.assertRaises(psycopg2.ProgrammingError, self.assertRaises(psycopg2.ProgrammingError,
cur.callproc, "version") cur.callproc, "version")
# but after you've waited it should be good # but after you've waited it should be good
self.wait(cur) self.wait(cur)
cur.execute("select * from table1") cur.execute("select * from table1")
self.wait(cur) self.wait(cur)
self.assertEquals(cur.fetchall()[0][0], 1) self.assertEquals(cur.fetchall()[0][0], 1)
cur.execute("delete from table1") cur.execute("delete from table1")
self.wait(cur) self.wait(cur)
cur.execute("select * from table1") cur.execute("select * from table1")
self.wait(cur) self.wait(cur)
self.assertEquals(cur.fetchone(), None) self.assertEquals(cur.fetchone(), None)
def test_fetch_after_async(self): def test_fetch_after_async(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select 'a'") cur.execute("select 'a'")
# a fetch after an asynchronous query should raise an error # a fetch after an asynchronous query should raise an error
self.assertRaises(psycopg2.ProgrammingError, self.assertRaises(psycopg2.ProgrammingError,
cur.fetchall) cur.fetchall)
# but after waiting it should work # but after waiting it should work
self.wait(cur) self.wait(cur)
self.assertEquals(cur.fetchall()[0][0], "a") self.assertEquals(cur.fetchall()[0][0], "a")
def test_rollback_while_async(self): def test_rollback_while_async(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select 'a'")
cur.execute("select 'a'") # a rollback should not work in asynchronous mode
self.assertRaises(psycopg2.ProgrammingError, self.conn.rollback)
# a rollback should not work in asynchronous mode
self.assertRaises(psycopg2.ProgrammingError, self.conn.rollback)
def test_commit_while_async(self): def test_commit_while_async(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("begin")
self.wait(cur)
cur.execute("begin") cur.execute("insert into table1 values (1)")
self.wait(cur)
cur.execute("insert into table1 values (1)") # a commit should not work in asynchronous mode
self.assertRaises(psycopg2.ProgrammingError, self.conn.commit)
self.assertTrue(self.conn.isexecuting())
# a commit should not work in asynchronous mode # but a manual commit should
self.assertRaises(psycopg2.ProgrammingError, self.conn.commit) self.wait(cur)
self.assertTrue(self.conn.isexecuting()) cur.execute("commit")
self.wait(cur)
# but a manual commit should cur.execute("select * from table1")
self.wait(cur) self.wait(cur)
cur.execute("commit") self.assertEquals(cur.fetchall()[0][0], 1)
self.wait(cur)
cur.execute("select * from table1") cur.execute("delete from table1")
self.wait(cur) self.wait(cur)
self.assertEquals(cur.fetchall()[0][0], 1)
cur.execute("delete from table1") cur.execute("select * from table1")
self.wait(cur) self.wait(cur)
self.assertEquals(cur.fetchone(), None)
cur.execute("select * from table1")
self.wait(cur)
self.assertEquals(cur.fetchone(), None)
def test_set_parameters_while_async(self): def test_set_parameters_while_async(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select 'c'")
self.assertTrue(self.conn.isexecuting())
cur.execute("select 'c'") # getting transaction status works
self.assertTrue(self.conn.isexecuting()) self.assertEquals(self.conn.info.transaction_status,
ext.TRANSACTION_STATUS_ACTIVE)
self.assertTrue(self.conn.isexecuting())
# getting transaction status works # setting connection encoding should fail
self.assertEquals(self.conn.info.transaction_status, self.assertRaises(psycopg2.ProgrammingError,
ext.TRANSACTION_STATUS_ACTIVE) self.conn.set_client_encoding, "LATIN1")
self.assertTrue(self.conn.isexecuting())
# setting connection encoding should fail # same for transaction isolation
self.assertRaises(psycopg2.ProgrammingError, self.assertRaises(psycopg2.ProgrammingError,
self.conn.set_client_encoding, "LATIN1") self.conn.set_isolation_level, 1)
# same for transaction isolation
self.assertRaises(psycopg2.ProgrammingError,
self.conn.set_isolation_level, 1)
def test_reset_while_async(self): def test_reset_while_async(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select 'c'") cur.execute("select 'c'")
self.assertTrue(self.conn.isexecuting()) self.assertTrue(self.conn.isexecuting())
# a reset should fail # a reset should fail
self.assertRaises(psycopg2.ProgrammingError, self.conn.reset) self.assertRaises(psycopg2.ProgrammingError, self.conn.reset)
def test_async_iter(self): def test_async_iter(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("begin")
self.wait(cur)
cur.execute("""
insert into table1 values (1);
insert into table1 values (2);
insert into table1 values (3);
""")
self.wait(cur)
cur.execute("select id from table1 order by id")
cur.execute("begin") # iteration fails if a query is underway
self.wait(cur) self.assertRaises(psycopg2.ProgrammingError, list, cur)
cur.execute("""
insert into table1 values (1);
insert into table1 values (2);
insert into table1 values (3);
""")
self.wait(cur)
cur.execute("select id from table1 order by id")
# iteration fails if a query is underway # but after it's done it should work
self.assertRaises(psycopg2.ProgrammingError, list, cur) self.wait(cur)
self.assertEquals(list(cur), [(1, ), (2, ), (3, )])
# but after it's done it should work self.assertFalse(self.conn.isexecuting())
self.wait(cur)
self.assertEquals(list(cur), [(1, ), (2, ), (3, )])
self.assertFalse(self.conn.isexecuting())
def test_copy_while_async(self): def test_copy_while_async(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select 'a'") cur.execute("select 'a'")
# copy should fail # copy should fail
self.assertRaises(psycopg2.ProgrammingError, self.assertRaises(psycopg2.ProgrammingError,
cur.copy_from, cur.copy_from,
StringIO("1\n3\n5\n\\.\n"), "table1") StringIO("1\n3\n5\n\\.\n"), "table1")
def test_lobject_while_async(self): def test_lobject_while_async(self):
# large objects should be prohibited # large objects should be prohibited
@ -250,68 +249,68 @@ class AsyncTests(ConnectingTestCase):
self.conn.lobject) self.conn.lobject)
def test_async_executemany(self): def test_async_executemany(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
self.assertRaises( self.assertRaises(
psycopg2.ProgrammingError, psycopg2.ProgrammingError,
cur.executemany, "insert into table1 values (%s)", [1, 2, 3]) cur.executemany, "insert into table1 values (%s)", [1, 2, 3])
def test_async_scroll(self): def test_async_scroll(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute(""" cur.execute("""
insert into table1 values (1); insert into table1 values (1);
insert into table1 values (2); insert into table1 values (2);
insert into table1 values (3); insert into table1 values (3);
""") """)
self.wait(cur) self.wait(cur)
cur.execute("select id from table1 order by id") cur.execute("select id from table1 order by id")
# scroll should fail if a query is underway # scroll should fail if a query is underway
self.assertRaises(psycopg2.ProgrammingError, cur.scroll, 1) self.assertRaises(psycopg2.ProgrammingError, cur.scroll, 1)
self.assertTrue(self.conn.isexecuting()) self.assertTrue(self.conn.isexecuting())
# but after it's done it should work # but after it's done it should work
self.wait(cur) self.wait(cur)
cur.scroll(1) cur.scroll(1)
self.assertEquals(cur.fetchall(), [(2, ), (3, )]) self.assertEquals(cur.fetchall(), [(2, ), (3, )])
cur = self.conn.cursor() with self.conn.cursor() as cur2:
cur.execute("select id from table1 order by id") cur.execute("select id from table1 order by id")
self.wait(cur) self.wait(cur)
cur2 = self.conn.cursor() with self.conn.cursor() as cur2:
self.assertRaises(psycopg2.ProgrammingError, cur2.scroll, 1) self.assertRaises(psycopg2.ProgrammingError, cur2.scroll, 1)
self.assertRaises(psycopg2.ProgrammingError, cur.scroll, 4) self.assertRaises(psycopg2.ProgrammingError, cur.scroll, 4)
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select id from table1 order by id") cur.execute("select id from table1 order by id")
self.wait(cur) self.wait(cur)
cur.scroll(2) cur.scroll(2)
cur.scroll(-1) cur.scroll(-1)
self.assertEquals(cur.fetchall(), [(2, ), (3, )]) self.assertEquals(cur.fetchall(), [(2, ), (3, )])
def test_scroll(self): def test_scroll(self):
cur = self.sync_conn.cursor() with self.sync_conn.cursor() as cur:
cur.execute("create table table1 (id int)") cur.execute("create table table1 (id int)")
cur.execute(""" cur.execute("""
insert into table1 values (1); insert into table1 values (1);
insert into table1 values (2); insert into table1 values (2);
insert into table1 values (3); insert into table1 values (3);
""") """)
cur.execute("select id from table1 order by id") cur.execute("select id from table1 order by id")
cur.scroll(2) cur.scroll(2)
cur.scroll(-1) cur.scroll(-1)
self.assertEquals(cur.fetchall(), [(2, ), (3, )]) self.assertEquals(cur.fetchall(), [(2, ), (3, )])
def test_async_dont_read_all(self): def test_async_dont_read_all(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select repeat('a', 10000); select repeat('b', 10000)") cur.execute("select repeat('a', 10000); select repeat('b', 10000)")
# fetch the result # fetch the result
self.wait(cur) self.wait(cur)
# it should be the result of the second query # it should be the result of the second query
self.assertEquals(cur.fetchone()[0], "b" * 10000) self.assertEquals(cur.fetchone()[0], "b" * 10000)
def test_async_subclass(self): def test_async_subclass(self):
class MyConn(ext.connection): class MyConn(ext.connection):
@ -326,15 +325,15 @@ class AsyncTests(ConnectingTestCase):
@slow @slow
def test_flush_on_write(self): def test_flush_on_write(self):
# a very large query requires a flush loop to be sent to the backend # a very large query requires a flush loop to be sent to the backend
curs = self.conn.cursor() with self.conn.cursor() as curs:
for mb in 1, 5, 10, 20, 50: for mb in 1, 5, 10, 20, 50:
size = mb * 1024 * 1024 size = mb * 1024 * 1024
stub = PollableStub(self.conn) stub = PollableStub(self.conn)
curs.execute("select %s;", ('x' * size,)) curs.execute("select %s;", ('x' * size,))
self.wait(stub) self.wait(stub)
self.assertEqual(size, len(curs.fetchone()[0])) self.assertEqual(size, len(curs.fetchone()[0]))
if stub.polls.count(ext.POLL_WRITE) > 1: if stub.polls.count(ext.POLL_WRITE) > 1:
return return
# This is more a testing glitch than an error: it happens # This is more a testing glitch than an error: it happens
# on high load on linux: probably because the kernel has more # on high load on linux: probably because the kernel has more
@ -343,159 +342,157 @@ class AsyncTests(ConnectingTestCase):
warnings.warn("sending a large query didn't trigger block on write.") warnings.warn("sending a large query didn't trigger block on write.")
def test_sync_poll(self): def test_sync_poll(self):
cur = self.sync_conn.cursor() with self.sync_conn.cursor() as cur:
cur.execute("select 1") cur.execute("select 1")
# polling with a sync query works # polling with a sync query works
cur.connection.poll() cur.connection.poll()
self.assertEquals(cur.fetchone()[0], 1) self.assertEquals(cur.fetchone()[0], 1)
@slow @slow
def test_notify(self): def test_notify(self):
cur = self.conn.cursor() with self.conn.cursor() as cur, self.sync_conn.cursor() as sync_cur:
sync_cur = self.sync_conn.cursor() sync_cur.execute("listen test_notify")
self.sync_conn.commit()
cur.execute("notify test_notify")
self.wait(cur)
sync_cur.execute("listen test_notify") self.assertEquals(self.sync_conn.notifies, [])
self.sync_conn.commit()
cur.execute("notify test_notify")
self.wait(cur)
self.assertEquals(self.sync_conn.notifies, []) pid = self.conn.info.backend_pid
for _ in range(5):
pid = self.conn.info.backend_pid self.wait(self.sync_conn)
for _ in range(5): if not self.sync_conn.notifies:
self.wait(self.sync_conn) time.sleep(0.5)
if not self.sync_conn.notifies: continue
time.sleep(0.5) self.assertEquals(len(self.sync_conn.notifies), 1)
continue self.assertEquals(self.sync_conn.notifies.pop(),
self.assertEquals(len(self.sync_conn.notifies), 1) (pid, "test_notify"))
self.assertEquals(self.sync_conn.notifies.pop(), return
(pid, "test_notify"))
return
self.fail("No NOTIFY in 2.5 seconds") self.fail("No NOTIFY in 2.5 seconds")
def test_async_fetch_wrong_cursor(self): def test_async_fetch_wrong_cursor(self):
cur1 = self.conn.cursor() with self.conn.cursor() as cur1, self.conn.cursor() as cur2:
cur2 = self.conn.cursor() cur1.execute("select 1")
cur1.execute("select 1")
self.wait(cur1) self.wait(cur1)
self.assertFalse(self.conn.isexecuting()) self.assertFalse(self.conn.isexecuting())
# fetching from a cursor with no results is an error # fetching from a cursor with no results is an error
self.assertRaises(psycopg2.ProgrammingError, cur2.fetchone) self.assertRaises(psycopg2.ProgrammingError, cur2.fetchone)
# fetching from the correct cursor works # fetching from the correct cursor works
self.assertEquals(cur1.fetchone()[0], 1) self.assertEquals(cur1.fetchone()[0], 1)
def test_error(self): def test_error(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("insert into table1 values (%s)", (1, )) cur.execute("insert into table1 values (%s)", (1, ))
self.wait(cur) self.wait(cur)
cur.execute("insert into table1 values (%s)", (1, )) cur.execute("insert into table1 values (%s)", (1, ))
# this should fail # this should fail
self.assertRaises(psycopg2.IntegrityError, self.wait, cur) self.assertRaises(psycopg2.IntegrityError, self.wait, cur)
cur.execute("insert into table1 values (%s); " cur.execute("insert into table1 values (%s); "
"insert into table1 values (%s)", (2, 2)) "insert into table1 values (%s)", (2, 2))
# this should fail as well # this should fail as well
self.assertRaises(psycopg2.IntegrityError, self.wait, cur) self.assertRaises(psycopg2.IntegrityError, self.wait, cur)
# but this should work # but this should work
cur.execute("insert into table1 values (%s)", (2, )) cur.execute("insert into table1 values (%s)", (2, ))
self.wait(cur) self.wait(cur)
# and the cursor should be usable afterwards # and the cursor should be usable afterwards
cur.execute("insert into table1 values (%s)", (3, )) cur.execute("insert into table1 values (%s)", (3, ))
self.wait(cur) self.wait(cur)
cur.execute("select * from table1 order by id") cur.execute("select * from table1 order by id")
self.wait(cur) self.wait(cur)
self.assertEquals(cur.fetchall(), [(1, ), (2, ), (3, )]) self.assertEquals(cur.fetchall(), [(1, ), (2, ), (3, )])
cur.execute("delete from table1") cur.execute("delete from table1")
self.wait(cur) self.wait(cur)
def test_stop_on_first_error(self): def test_stop_on_first_error(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select 1; select x; select 1/0; select 2") cur.execute("select 1; select x; select 1/0; select 2")
self.assertRaises(psycopg2.errors.UndefinedColumn, self.wait, cur) self.assertRaises(psycopg2.errors.UndefinedColumn, self.wait, cur)
cur.execute("select 1") cur.execute("select 1")
self.wait(cur) self.wait(cur)
self.assertEqual(cur.fetchone(), (1,)) self.assertEqual(cur.fetchone(), (1,))
def test_error_two_cursors(self): def test_error_two_cursors(self):
cur = self.conn.cursor() with self.conn.cursor() as cur, self.conn.cursor() as cur2:
cur2 = self.conn.cursor() cur.execute("select * from no_such_table")
cur.execute("select * from no_such_table") self.assertRaises(psycopg2.ProgrammingError, self.wait, cur)
self.assertRaises(psycopg2.ProgrammingError, self.wait, cur) cur2.execute("select 1")
cur2.execute("select 1") self.wait(cur2)
self.wait(cur2) self.assertEquals(cur2.fetchone()[0], 1)
self.assertEquals(cur2.fetchone()[0], 1)
def test_notices(self): def test_notices(self):
del self.conn.notices[:] del self.conn.notices[:]
cur = self.conn.cursor() with self.conn.cursor() as cur:
if self.conn.info.server_version >= 90300: if self.conn.info.server_version >= 90300:
cur.execute("set client_min_messages=debug1") cur.execute("set client_min_messages=debug1")
self.wait(cur)
cur.execute("create temp table chatty (id serial primary key);")
self.wait(cur) self.wait(cur)
cur.execute("create temp table chatty (id serial primary key);") self.assertEqual("CREATE TABLE", cur.statusmessage)
self.wait(cur) self.assert_(self.conn.notices)
self.assertEqual("CREATE TABLE", cur.statusmessage)
self.assert_(self.conn.notices)
def test_async_cursor_gone(self): def test_async_cursor_gone(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select 42;") cur.execute("select 42;")
del cur del cur
gc.collect() gc.collect()
self.assertRaises(psycopg2.InterfaceError, self.wait, self.conn) self.assertRaises(psycopg2.InterfaceError, self.wait, self.conn)
# The connection is still usable # The connection is still usable
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select 42;") cur.execute("select 42;")
self.wait(self.conn) self.wait(self.conn)
self.assertEqual(cur.fetchone(), (42,)) self.assertEqual(cur.fetchone(), (42,))
def test_async_connection_error_message(self): def test_async_connection_error_message(self):
cnn = psycopg2.connect('dbname=thisdatabasedoesntexist', async_=True)
try: try:
cnn = psycopg2.connect('dbname=thisdatabasedoesntexist', async_=True)
self.wait(cnn) self.wait(cnn)
except psycopg2.Error as e: except psycopg2.Error as e:
self.assertNotEqual(str(e), "asynchronous connection failed", self.assertNotEqual(str(e), "asynchronous connection failed",
"connection error reason lost") "connection error reason lost")
else: else:
self.fail("no exception raised") self.fail("no exception raised")
finally:
cnn.close()
@skip_before_postgres(8, 2) @skip_before_postgres(8, 2)
def test_copy_no_hang(self): def test_copy_no_hang(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("copy (select 1) to stdout") cur.execute("copy (select 1) to stdout")
self.assertRaises(psycopg2.ProgrammingError, self.wait, self.conn) self.assertRaises(psycopg2.ProgrammingError, self.wait, self.conn)
@slow @slow
@skip_before_postgres(9, 0) @skip_before_postgres(9, 0)
def test_non_block_after_notification(self): def test_non_block_after_notification(self):
from select import select from select import select
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute(""" cur.execute("""
select 1; select 1;
do $$ do $$
begin begin
raise notice 'hello'; raise notice 'hello';
end end
$$ language plpgsql; $$ language plpgsql;
select pg_sleep(1); select pg_sleep(1);
""") """)
polls = 0 polls = 0
while True: while True:
state = self.conn.poll() state = self.conn.poll()
if state == psycopg2.extensions.POLL_OK: if state == psycopg2.extensions.POLL_OK:
break break
elif state == psycopg2.extensions.POLL_READ: elif state == psycopg2.extensions.POLL_READ:
select([self.conn], [], [], 0.1) select([self.conn], [], [], 0.1)
elif state == psycopg2.extensions.POLL_WRITE: elif state == psycopg2.extensions.POLL_WRITE:
select([], [self.conn], [], 0.1) select([], [self.conn], [], 0.1)
else: else:
raise Exception("Unexpected result from poll: %r", state) raise Exception("Unexpected result from poll: %r", state)
polls += 1 polls += 1
self.assert_(polls >= 8, polls) self.assert_(polls >= 8, polls)
def test_poll_noop(self): def test_poll_noop(self):
self.conn.poll() self.conn.poll()

View File

@ -47,16 +47,18 @@ class AsyncTests(ConnectingTestCase):
self.wait(self.conn) self.wait(self.conn)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute(''' curs.execute('''
CREATE TEMPORARY TABLE table1 ( CREATE TEMPORARY TABLE table1 (
id int PRIMARY KEY id int PRIMARY KEY
)''') )''')
self.wait(curs) self.wait(curs)
def test_connection_setup(self): def test_connection_setup(self):
cur = self.conn.cursor() cur = self.conn.cursor()
sync_cur = self.sync_conn.cursor() sync_cur = self.sync_conn.cursor()
cur.close()
sync_cur.close()
del cur, sync_cur del cur, sync_cur
self.assert_(self.conn.async) self.assert_(self.conn.async)
@ -89,17 +91,19 @@ class AsyncTests(ConnectingTestCase):
"connection error reason lost") "connection error reason lost")
else: else:
self.fail("no exception raised") self.fail("no exception raised")
finally:
cnn.close()
class CancelTests(ConnectingTestCase): class CancelTests(ConnectingTestCase):
def setUp(self): def setUp(self):
ConnectingTestCase.setUp(self) ConnectingTestCase.setUp(self)
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute(''' cur.execute('''
CREATE TEMPORARY TABLE table1 ( CREATE TEMPORARY TABLE table1 (
id int PRIMARY KEY id int PRIMARY KEY
)''') )''')
self.conn.commit() self.conn.commit()
@slow @slow
@ -108,16 +112,17 @@ class CancelTests(ConnectingTestCase):
async_conn = psycopg2.connect(dsn, async=True) async_conn = psycopg2.connect(dsn, async=True)
self.assertRaises(psycopg2.OperationalError, async_conn.cancel) self.assertRaises(psycopg2.OperationalError, async_conn.cancel)
extras.wait_select(async_conn) extras.wait_select(async_conn)
cur = async_conn.cursor() with async_conn.cursor() as cur:
cur.execute("select pg_sleep(10)") cur.execute("select pg_sleep(10)")
time.sleep(1) time.sleep(1)
self.assertTrue(async_conn.isexecuting()) self.assertTrue(async_conn.isexecuting())
async_conn.cancel() async_conn.cancel()
self.assertRaises(psycopg2.extensions.QueryCanceledError, self.assertRaises(psycopg2.extensions.QueryCanceledError,
extras.wait_select, async_conn) extras.wait_select, async_conn)
cur.execute("select 1") cur.execute("select 1")
extras.wait_select(async_conn) extras.wait_select(async_conn)
self.assertEqual(cur.fetchall(), [(1, )]) self.assertEqual(cur.fetchall(), [(1, )])
async_conn.close()
def test_async_connection_cancel(self): def test_async_connection_cancel(self):
async_conn = psycopg2.connect(dsn, async=True) async_conn = psycopg2.connect(dsn, async=True)
@ -180,41 +185,40 @@ class AsyncReplicationTest(ReplicationTestCase):
if conn is None: if conn is None:
return return
cur = conn.cursor() with conn.cursor() as cur:
self.create_replication_slot(cur, output_plugin='test_decoding')
self.wait(cur)
self.create_replication_slot(cur, output_plugin='test_decoding') cur.start_replication(self.slot)
self.wait(cur) self.wait(cur)
cur.start_replication(self.slot) self.make_replication_events()
self.wait(cur)
self.make_replication_events() self.msg_count = 0
self.msg_count = 0 def consume(msg):
# just check the methods
"%s: %s" % (cur.io_timestamp, repr(msg))
"%s: %s" % (cur.feedback_timestamp, repr(msg))
def consume(msg): self.msg_count += 1
# just check the methods if self.msg_count > 3:
"%s: %s" % (cur.io_timestamp, repr(msg)) cur.send_feedback(reply=True)
"%s: %s" % (cur.feedback_timestamp, repr(msg)) raise StopReplication()
self.msg_count += 1 cur.send_feedback(flush_lsn=msg.data_start)
if self.msg_count > 3:
cur.send_feedback(reply=True)
raise StopReplication()
cur.send_feedback(flush_lsn=msg.data_start) # cannot be used in asynchronous mode
self.assertRaises(psycopg2.ProgrammingError, cur.consume_stream, consume)
# cannot be used in asynchronous mode def process_stream():
self.assertRaises(psycopg2.ProgrammingError, cur.consume_stream, consume) while True:
msg = cur.read_message()
def process_stream(): if msg:
while True: consume(msg)
msg = cur.read_message() else:
if msg: select([cur], [], [])
consume(msg) self.assertRaises(StopReplication, process_stream)
else:
select([cur], [], [])
self.assertRaises(StopReplication, process_stream)
def test_suite(): def test_suite():

View File

@ -39,9 +39,9 @@ class StolenReferenceTestCase(ConnectingTestCase):
return 42 return 42
UUID = psycopg2.extensions.new_type((2950,), "UUID", fish) UUID = psycopg2.extensions.new_type((2950,), "UUID", fish)
psycopg2.extensions.register_type(UUID, self.conn) psycopg2.extensions.register_type(UUID, self.conn)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select 'b5219e01-19ab-4994-b71e-149225dc51e4'::uuid") curs.execute("select 'b5219e01-19ab-4994-b71e-149225dc51e4'::uuid")
curs.fetchone() curs.fetchone()
def test_suite(): def test_suite():

View File

@ -41,11 +41,11 @@ class CancelTests(ConnectingTestCase):
def setUp(self): def setUp(self):
ConnectingTestCase.setUp(self) ConnectingTestCase.setUp(self)
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute(''' cur.execute('''
CREATE TEMPORARY TABLE table1 ( CREATE TEMPORARY TABLE table1 (
id int PRIMARY KEY id int PRIMARY KEY
)''') )''')
self.conn.commit() self.conn.commit()
def test_empty_cancel(self): def test_empty_cancel(self):
@ -57,25 +57,25 @@ class CancelTests(ConnectingTestCase):
errors = [] errors = []
def neverending(conn): def neverending(conn):
cur = conn.cursor() with conn.cursor() as cur:
try: try:
self.assertRaises(psycopg2.extensions.QueryCanceledError, self.assertRaises(psycopg2.extensions.QueryCanceledError,
cur.execute, "select pg_sleep(60)") cur.execute, "select pg_sleep(60)")
# make sure the connection still works # make sure the connection still works
conn.rollback() conn.rollback()
cur.execute("select 1") cur.execute("select 1")
self.assertEqual(cur.fetchall(), [(1, )]) self.assertEqual(cur.fetchall(), [(1, )])
except Exception as e: except Exception as e:
errors.append(e) errors.append(e)
raise raise
def canceller(conn): def canceller(conn):
cur = conn.cursor() with conn.cursor() as cur:
try: try:
conn.cancel() conn.cancel()
except Exception as e: except Exception as e:
errors.append(e) errors.append(e)
raise raise
del cur del cur
thread1 = threading.Thread(target=neverending, args=(self.conn, )) thread1 = threading.Thread(target=neverending, args=(self.conn, ))
@ -95,16 +95,17 @@ class CancelTests(ConnectingTestCase):
async_conn = psycopg2.connect(dsn, async_=True) async_conn = psycopg2.connect(dsn, async_=True)
self.assertRaises(psycopg2.OperationalError, async_conn.cancel) self.assertRaises(psycopg2.OperationalError, async_conn.cancel)
extras.wait_select(async_conn) extras.wait_select(async_conn)
cur = async_conn.cursor() with async_conn.cursor() as cur:
cur.execute("select pg_sleep(10)") cur.execute("select pg_sleep(10)")
time.sleep(1) time.sleep(1)
self.assertTrue(async_conn.isexecuting()) self.assertTrue(async_conn.isexecuting())
async_conn.cancel() async_conn.cancel()
self.assertRaises(psycopg2.extensions.QueryCanceledError, self.assertRaises(psycopg2.extensions.QueryCanceledError,
extras.wait_select, async_conn) extras.wait_select, async_conn)
cur.execute("select 1") cur.execute("select 1")
extras.wait_select(async_conn) extras.wait_select(async_conn)
self.assertEqual(cur.fetchall(), [(1, )]) self.assertEqual(cur.fetchall(), [(1, )])
async_conn.close()
def test_async_connection_cancel(self): def test_async_connection_cancel(self):
async_conn = psycopg2.connect(dsn, async_=True) async_conn = psycopg2.connect(dsn, async_=True)

File diff suppressed because it is too large Load Diff

View File

@ -66,12 +66,12 @@ class CopyTests(ConnectingTestCase):
self._create_temp_table() self._create_temp_table()
def _create_temp_table(self): def _create_temp_table(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute(''' curs.execute('''
CREATE TEMPORARY TABLE tcopy ( CREATE TEMPORARY TABLE tcopy (
id serial PRIMARY KEY, id serial PRIMARY KEY,
data text data text
)''') )''')
@slow @slow
def test_copy_from(self): def test_copy_from(self):
@ -92,31 +92,31 @@ class CopyTests(ConnectingTestCase):
curs.close() curs.close()
def test_copy_from_cols(self): def test_copy_from_cols(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
f = StringIO() f = StringIO()
for i in range(10): for i in range(10):
f.write("%s\n" % (i,)) f.write("%s\n" % (i,))
f.seek(0) f.seek(0)
curs.copy_from(MinimalRead(f), "tcopy", columns=['id']) curs.copy_from(MinimalRead(f), "tcopy", columns=['id'])
curs.execute("select * from tcopy order by id") curs.execute("select * from tcopy order by id")
self.assertEqual([(i, None) for i in range(10)], curs.fetchall()) self.assertEqual([(i, None) for i in range(10)], curs.fetchall())
def test_copy_from_cols_err(self): def test_copy_from_cols_err(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
f = StringIO() f = StringIO()
for i in range(10): for i in range(10):
f.write("%s\n" % (i,)) f.write("%s\n" % (i,))
f.seek(0) f.seek(0)
def cols(): def cols():
raise ZeroDivisionError() raise ZeroDivisionError()
yield 'id' yield 'id'
self.assertRaises(ZeroDivisionError, self.assertRaises(ZeroDivisionError,
curs.copy_from, MinimalRead(f), "tcopy", columns=cols()) curs.copy_from, MinimalRead(f), "tcopy", columns=cols())
@slow @slow
def test_copy_to(self): def test_copy_to(self):
@ -140,14 +140,14 @@ class CopyTests(ConnectingTestCase):
+ list(range(160, 256))).decode('latin1') + list(range(160, 256))).decode('latin1')
about = abin.replace('\\', '\\\\') about = abin.replace('\\', '\\\\')
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute('insert into tcopy values (%s, %s)', curs.execute('insert into tcopy values (%s, %s)',
(42, abin)) (42, abin))
f = io.StringIO() f = io.StringIO()
curs.copy_to(f, 'tcopy', columns=('data',)) curs.copy_to(f, 'tcopy', columns=('data',))
f.seek(0) f.seek(0)
self.assertEqual(f.readline().rstrip(), about) self.assertEqual(f.readline().rstrip(), about)
def test_copy_bytes(self): def test_copy_bytes(self):
self.conn.set_client_encoding('latin1') self.conn.set_client_encoding('latin1')
@ -161,14 +161,14 @@ class CopyTests(ConnectingTestCase):
+ list(range(160, 255))).decode('latin1') + list(range(160, 255))).decode('latin1')
about = abin.replace('\\', '\\\\').encode('latin1') about = abin.replace('\\', '\\\\').encode('latin1')
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute('insert into tcopy values (%s, %s)', curs.execute('insert into tcopy values (%s, %s)',
(42, abin)) (42, abin))
f = io.BytesIO() f = io.BytesIO()
curs.copy_to(f, 'tcopy', columns=('data',)) curs.copy_to(f, 'tcopy', columns=('data',))
f.seek(0) f.seek(0)
self.assertEqual(f.readline().rstrip(), about) self.assertEqual(f.readline().rstrip(), about)
def test_copy_expert_textiobase(self): def test_copy_expert_textiobase(self):
self.conn.set_client_encoding('latin1') self.conn.set_client_encoding('latin1')
@ -188,35 +188,35 @@ class CopyTests(ConnectingTestCase):
f.write(about) f.write(about)
f.seek(0) f.seek(0)
curs = self.conn.cursor() with self.conn.cursor() as curs:
psycopg2.extensions.register_type( psycopg2.extensions.register_type(
psycopg2.extensions.UNICODE, curs) psycopg2.extensions.UNICODE, curs)
curs.copy_expert('COPY tcopy (data) FROM STDIN', f) curs.copy_expert('COPY tcopy (data) FROM STDIN', f)
curs.execute("select data from tcopy;") curs.execute("select data from tcopy;")
self.assertEqual(curs.fetchone()[0], abin) self.assertEqual(curs.fetchone()[0], abin)
f = io.StringIO() f = io.StringIO()
curs.copy_expert('COPY tcopy (data) TO STDOUT', f) curs.copy_expert('COPY tcopy (data) TO STDOUT', f)
f.seek(0) f.seek(0)
self.assertEqual(f.readline().rstrip(), about) self.assertEqual(f.readline().rstrip(), about)
# same tests with setting size # same tests with setting size
f = io.StringIO() f = io.StringIO()
f.write(about) f.write(about)
f.seek(0) f.seek(0)
exp_size = 123 exp_size = 123
# hack here to leave file as is, only check size when reading # hack here to leave file as is, only check size when reading
real_read = f.read real_read = f.read
def read(_size, f=f, exp_size=exp_size): def read(_size, f=f, exp_size=exp_size):
self.assertEqual(_size, exp_size) self.assertEqual(_size, exp_size)
return real_read(_size) return real_read(_size)
f.read = read f.read = read
curs.copy_expert('COPY tcopy (data) FROM STDIN', f, size=exp_size) curs.copy_expert('COPY tcopy (data) FROM STDIN', f, size=exp_size)
curs.execute("select data from tcopy;") curs.execute("select data from tcopy;")
self.assertEqual(curs.fetchone()[0], abin) self.assertEqual(curs.fetchone()[0], abin)
def _copy_from(self, curs, nrecs, srec, copykw): def _copy_from(self, curs, nrecs, srec, copykw):
f = StringIO() f = StringIO()
@ -254,56 +254,54 @@ class CopyTests(ConnectingTestCase):
pass pass
f = Whatever() f = Whatever()
curs = self.conn.cursor() with self.conn.cursor() as curs:
self.assertRaises(TypeError, self.assertRaises(TypeError,
curs.copy_expert, 'COPY tcopy (data) FROM STDIN', f) curs.copy_expert, 'COPY tcopy (data) FROM STDIN', f)
def test_copy_no_column_limit(self): def test_copy_no_column_limit(self):
cols = ["c%050d" % i for i in range(200)] cols = ["c%050d" % i for i in range(200)]
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute('CREATE TEMPORARY TABLE manycols (%s)' % ',\n'.join( curs.execute('CREATE TEMPORARY TABLE manycols (%s)' % ',\n'.join(
["%s int" % c for c in cols])) ["%s int" % c for c in cols]))
curs.execute("INSERT INTO manycols DEFAULT VALUES") curs.execute("INSERT INTO manycols DEFAULT VALUES")
f = StringIO() f = StringIO()
curs.copy_to(f, "manycols", columns=cols) curs.copy_to(f, "manycols", columns=cols)
f.seek(0) f.seek(0)
self.assertEqual(f.read().split(), ['\\N'] * len(cols)) self.assertEqual(f.read().split(), ['\\N'] * len(cols))
f.seek(0) f.seek(0)
curs.copy_from(f, "manycols", columns=cols) curs.copy_from(f, "manycols", columns=cols)
curs.execute("select count(*) from manycols;") curs.execute("select count(*) from manycols;")
self.assertEqual(curs.fetchone()[0], 2) self.assertEqual(curs.fetchone()[0], 2)
@skip_before_postgres(8, 2) # they don't send the count @skip_before_postgres(8, 2) # they don't send the count
def test_copy_rowcount(self): def test_copy_rowcount(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.copy_from(StringIO('aaa\nbbb\nccc\n'), 'tcopy', columns=['data'])
self.assertEqual(curs.rowcount, 3)
curs.copy_from(StringIO('aaa\nbbb\nccc\n'), 'tcopy', columns=['data']) curs.copy_expert(
self.assertEqual(curs.rowcount, 3) "copy tcopy (data) from stdin",
StringIO('ddd\neee\n'))
self.assertEqual(curs.rowcount, 2)
curs.copy_expert( curs.copy_to(StringIO(), "tcopy")
"copy tcopy (data) from stdin", self.assertEqual(curs.rowcount, 5)
StringIO('ddd\neee\n'))
self.assertEqual(curs.rowcount, 2)
curs.copy_to(StringIO(), "tcopy") curs.execute("insert into tcopy (data) values ('fff')")
self.assertEqual(curs.rowcount, 5) curs.copy_expert("copy tcopy to stdout", StringIO())
self.assertEqual(curs.rowcount, 6)
curs.execute("insert into tcopy (data) values ('fff')")
curs.copy_expert("copy tcopy to stdout", StringIO())
self.assertEqual(curs.rowcount, 6)
def test_copy_rowcount_error(self): def test_copy_rowcount_error(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("insert into tcopy (data) values ('fff')")
self.assertEqual(curs.rowcount, 1)
curs.execute("insert into tcopy (data) values ('fff')") self.assertRaises(psycopg2.DataError,
self.assertEqual(curs.rowcount, 1) curs.copy_from, StringIO('aaa\nbbb\nccc\n'), 'tcopy')
self.assertEqual(curs.rowcount, -1)
self.assertRaises(psycopg2.DataError,
curs.copy_from, StringIO('aaa\nbbb\nccc\n'), 'tcopy')
self.assertEqual(curs.rowcount, -1)
@slow @slow
def test_copy_from_segfault(self): def test_copy_from_segfault(self):
@ -317,6 +315,7 @@ try:
curs.execute("copy copy_segf from stdin") curs.execute("copy copy_segf from stdin")
except psycopg2.ProgrammingError: except psycopg2.ProgrammingError:
pass pass
curs.close()
conn.close() conn.close()
""" % {'dsn': dsn}) """ % {'dsn': dsn})
@ -336,6 +335,7 @@ try:
curs.execute("copy copy_segf to stdout") curs.execute("copy copy_segf to stdout")
except psycopg2.ProgrammingError: except psycopg2.ProgrammingError:
pass pass
curs.close()
conn.close() conn.close()
""" % {'dsn': dsn}) """ % {'dsn': dsn})
@ -351,24 +351,24 @@ conn.close()
def readline(self): def readline(self):
return 1 / 0 return 1 / 0
curs = self.conn.cursor() with self.conn.cursor() as curs:
# It seems we cannot do this, but now at least we propagate the error # It seems we cannot do this, but now at least we propagate the error
# self.assertRaises(ZeroDivisionError, # self.assertRaises(ZeroDivisionError,
# curs.copy_from, BrokenRead(), "tcopy") # curs.copy_from, BrokenRead(), "tcopy")
try: try:
curs.copy_from(BrokenRead(), "tcopy") curs.copy_from(BrokenRead(), "tcopy")
except Exception as e: except Exception as e:
self.assert_('ZeroDivisionError' in str(e)) self.assert_('ZeroDivisionError' in str(e))
def test_copy_to_propagate_error(self): def test_copy_to_propagate_error(self):
class BrokenWrite(TextIOBase): class BrokenWrite(TextIOBase):
def write(self, data): def write(self, data):
return 1 / 0 return 1 / 0
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("insert into tcopy values (10, 'hi')") curs.execute("insert into tcopy values (10, 'hi')")
self.assertRaises(ZeroDivisionError, self.assertRaises(ZeroDivisionError,
curs.copy_to, BrokenWrite(), "tcopy") curs.copy_to, BrokenWrite(), "tcopy")
def test_suite(): def test_suite():

View File

@ -51,10 +51,10 @@ class CursorTests(ConnectingTestCase):
self.assert_(cur.closed) self.assert_(cur.closed)
def test_empty_query(self): def test_empty_query(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
self.assertRaises(psycopg2.ProgrammingError, cur.execute, "") self.assertRaises(psycopg2.ProgrammingError, cur.execute, "")
self.assertRaises(psycopg2.ProgrammingError, cur.execute, " ") self.assertRaises(psycopg2.ProgrammingError, cur.execute, " ")
self.assertRaises(psycopg2.ProgrammingError, cur.execute, ";") self.assertRaises(psycopg2.ProgrammingError, cur.execute, ";")
def test_executemany_propagate_exceptions(self): def test_executemany_propagate_exceptions(self):
conn = self.conn conn = self.conn
@ -70,58 +70,57 @@ class CursorTests(ConnectingTestCase):
def test_mogrify_unicode(self): def test_mogrify_unicode(self):
conn = self.conn conn = self.conn
cur = conn.cursor() with conn.cursor() as cur:
# test consistency between execute and mogrify.
# test consistency between execute and mogrify. # unicode query containing only ascii data
cur.execute(u"SELECT 'foo';")
self.assertEqual('foo', cur.fetchone()[0])
self.assertEqual(b"SELECT 'foo';", cur.mogrify(u"SELECT 'foo';"))
# unicode query containing only ascii data conn.set_client_encoding('UTF8')
cur.execute(u"SELECT 'foo';") snowman = u"\u2603"
self.assertEqual('foo', cur.fetchone()[0])
self.assertEqual(b"SELECT 'foo';", cur.mogrify(u"SELECT 'foo';"))
conn.set_client_encoding('UTF8') def b(s):
snowman = u"\u2603" if isinstance(s, text_type):
return s.encode('utf8')
else:
return s
def b(s): # unicode query with non-ascii data
if isinstance(s, text_type): cur.execute(u"SELECT '%s';" % snowman)
return s.encode('utf8') self.assertEqual(snowman.encode('utf8'), b(cur.fetchone()[0]))
else: self.assertQuotedEqual(("SELECT '%s';" % snowman).encode('utf8'),
return s cur.mogrify(u"SELECT '%s';" % snowman))
# unicode query with non-ascii data # unicode args
cur.execute(u"SELECT '%s';" % snowman) cur.execute("SELECT %s;", (snowman,))
self.assertEqual(snowman.encode('utf8'), b(cur.fetchone()[0])) self.assertEqual(snowman.encode("utf-8"), b(cur.fetchone()[0]))
self.assertQuotedEqual(("SELECT '%s';" % snowman).encode('utf8'), self.assertQuotedEqual(("SELECT '%s';" % snowman).encode('utf8'),
cur.mogrify(u"SELECT '%s';" % snowman)) cur.mogrify("SELECT %s;", (snowman,)))
# unicode args # unicode query and args
cur.execute("SELECT %s;", (snowman,)) cur.execute(u"SELECT %s;", (snowman,))
self.assertEqual(snowman.encode("utf-8"), b(cur.fetchone()[0])) self.assertEqual(snowman.encode("utf-8"), b(cur.fetchone()[0]))
self.assertQuotedEqual(("SELECT '%s';" % snowman).encode('utf8'), self.assertQuotedEqual(("SELECT '%s';" % snowman).encode('utf8'),
cur.mogrify("SELECT %s;", (snowman,))) cur.mogrify(u"SELECT %s;", (snowman,)))
# unicode query and args
cur.execute(u"SELECT %s;", (snowman,))
self.assertEqual(snowman.encode("utf-8"), b(cur.fetchone()[0]))
self.assertQuotedEqual(("SELECT '%s';" % snowman).encode('utf8'),
cur.mogrify(u"SELECT %s;", (snowman,)))
def test_mogrify_decimal_explodes(self): def test_mogrify_decimal_explodes(self):
conn = self.conn conn = self.conn
cur = conn.cursor() with conn.cursor() as cur:
self.assertEqual(b'SELECT 10.3;', self.assertEqual(b'SELECT 10.3;',
cur.mogrify("SELECT %s;", (Decimal("10.3"),))) cur.mogrify("SELECT %s;", (Decimal("10.3"),)))
@skip_if_no_getrefcount @skip_if_no_getrefcount
def test_mogrify_leak_on_multiple_reference(self): def test_mogrify_leak_on_multiple_reference(self):
# issue #81: reference leak when a parameter value is referenced # issue #81: reference leak when a parameter value is referenced
# more than once from a dict. # more than once from a dict.
cur = self.conn.cursor() with self.conn.cursor() as cur:
foo = (lambda x: x)('foo') * 10 foo = (lambda x: x)('foo') * 10
nref1 = sys.getrefcount(foo) nref1 = sys.getrefcount(foo)
cur.mogrify("select %(foo)s, %(foo)s, %(foo)s", {'foo': foo}) cur.mogrify("select %(foo)s, %(foo)s, %(foo)s", {'foo': foo})
nref2 = sys.getrefcount(foo) nref2 = sys.getrefcount(foo)
self.assertEqual(nref1, nref2) self.assertEqual(nref1, nref2)
def test_modify_closed(self): def test_modify_closed(self):
cur = self.conn.cursor() cur = self.conn.cursor()
@ -130,52 +129,51 @@ class CursorTests(ConnectingTestCase):
self.assertEqual(sql, b"select 10") self.assertEqual(sql, b"select 10")
def test_bad_placeholder(self): def test_bad_placeholder(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
self.assertRaises(psycopg2.ProgrammingError, self.assertRaises(psycopg2.ProgrammingError,
cur.mogrify, "select %(foo", {}) cur.mogrify, "select %(foo", {})
self.assertRaises(psycopg2.ProgrammingError, self.assertRaises(psycopg2.ProgrammingError,
cur.mogrify, "select %(foo", {'foo': 1}) cur.mogrify, "select %(foo", {'foo': 1})
self.assertRaises(psycopg2.ProgrammingError, self.assertRaises(psycopg2.ProgrammingError,
cur.mogrify, "select %(foo, %(bar)", {'foo': 1}) cur.mogrify, "select %(foo, %(bar)", {'foo': 1})
self.assertRaises(psycopg2.ProgrammingError, self.assertRaises(psycopg2.ProgrammingError,
cur.mogrify, "select %(foo, %(bar)", {'foo': 1, 'bar': 2}) cur.mogrify, "select %(foo, %(bar)", {'foo': 1, 'bar': 2})
def test_cast(self): def test_cast(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
self.assertEqual(42, curs.cast(20, '42'))
self.assertAlmostEqual(3.14, curs.cast(700, '3.14'))
self.assertEqual(42, curs.cast(20, '42')) self.assertEqual(Decimal('123.45'), curs.cast(1700, '123.45'))
self.assertAlmostEqual(3.14, curs.cast(700, '3.14'))
self.assertEqual(Decimal('123.45'), curs.cast(1700, '123.45')) self.assertEqual(date(2011, 1, 2), curs.cast(1082, '2011-01-02'))
self.assertEqual("who am i?", curs.cast(705, 'who am i?')) # unknown
self.assertEqual(date(2011, 1, 2), curs.cast(1082, '2011-01-02'))
self.assertEqual("who am i?", curs.cast(705, 'who am i?')) # unknown
def test_cast_specificity(self): def test_cast_specificity(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
self.assertEqual("foo", curs.cast(705, 'foo')) self.assertEqual("foo", curs.cast(705, 'foo'))
D = psycopg2.extensions.new_type((705,), "DOUBLING", lambda v, c: v * 2) D = psycopg2.extensions.new_type((705,), "DOUBLING", lambda v, c: v * 2)
psycopg2.extensions.register_type(D, self.conn) psycopg2.extensions.register_type(D, self.conn)
self.assertEqual("foofoo", curs.cast(705, 'foo')) self.assertEqual("foofoo", curs.cast(705, 'foo'))
T = psycopg2.extensions.new_type((705,), "TREBLING", lambda v, c: v * 3) T = psycopg2.extensions.new_type((705,), "TREBLING", lambda v, c: v * 3)
psycopg2.extensions.register_type(T, curs) psycopg2.extensions.register_type(T, curs)
self.assertEqual("foofoofoo", curs.cast(705, 'foo')) self.assertEqual("foofoofoo", curs.cast(705, 'foo'))
curs2 = self.conn.cursor() with self.conn.cursor() as curs2:
self.assertEqual("foofoo", curs2.cast(705, 'foo')) self.assertEqual("foofoo", curs2.cast(705, 'foo'))
def test_weakref(self): def test_weakref(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
w = ref(curs) w = ref(curs)
del curs del curs
gc.collect() gc.collect()
self.assert_(w() is None) self.assert_(w() is None)
def test_null_name(self): def test_null_name(self):
curs = self.conn.cursor(None) with self.conn.cursor(None) as curs:
self.assertEqual(curs.name, None) self.assertEqual(curs.name, None)
def test_invalid_name(self): def test_invalid_name(self):
curs = self.conn.cursor() curs = self.conn.cursor()
@ -184,9 +182,9 @@ class CursorTests(ConnectingTestCase):
curs.execute("insert into invname values (%s)", (i,)) curs.execute("insert into invname values (%s)", (i,))
curs.close() curs.close()
curs = self.conn.cursor(r'1-2-3 \ "test"') with self.conn.cursor(r'1-2-3 \ "test"') as curs:
curs.execute("select data from invname order by data") curs.execute("select data from invname order by data")
self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)]) self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)])
def _create_withhold_table(self): def _create_withhold_table(self):
curs = self.conn.cursor() curs = self.conn.cursor()
@ -213,15 +211,15 @@ class CursorTests(ConnectingTestCase):
self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)]) self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)])
curs.close() curs.close()
curs = self.conn.cursor("W", withhold=True) with self.conn.cursor("W", withhold=True) as curs:
self.assertEqual(curs.withhold, True) self.assertEqual(curs.withhold, True)
curs.execute("select data from withhold order by data") curs.execute("select data from withhold order by data")
self.conn.commit() self.conn.commit()
self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)]) self.assertEqual(curs.fetchall(), [(10,), (20,), (30,)])
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("drop table withhold") curs.execute("drop table withhold")
self.conn.commit() self.conn.commit()
def test_withhold_no_begin(self): def test_withhold_no_begin(self):
self._create_withhold_table() self._create_withhold_table()
@ -328,110 +326,110 @@ class CursorTests(ConnectingTestCase):
return self.skipTest("can't evaluate non-scrollable cursor") return self.skipTest("can't evaluate non-scrollable cursor")
curs.close() curs.close()
curs = self.conn.cursor("S", scrollable=False) with self.conn.cursor("S", scrollable=False) as curs:
self.assertEqual(curs.scrollable, False) self.assertEqual(curs.scrollable, False)
curs.execute("select * from scrollable") curs.execute("select * from scrollable")
curs.scroll(2) curs.scroll(2)
self.assertRaises(psycopg2.OperationalError, curs.scroll, -1) self.assertRaises(psycopg2.OperationalError, curs.scroll, -1)
@slow @slow
@skip_before_postgres(8, 2) @skip_before_postgres(8, 2)
def test_iter_named_cursor_efficient(self): def test_iter_named_cursor_efficient(self):
curs = self.conn.cursor('tmp') with self.conn.cursor('tmp') as curs:
# if these records are fetched in the same roundtrip their # if these records are fetched in the same roundtrip their
# timestamp will not be influenced by the pause in Python world. # timestamp will not be influenced by the pause in Python world.
curs.execute("""select clock_timestamp() from generate_series(1,2)""") curs.execute("""select clock_timestamp() from generate_series(1,2)""")
i = iter(curs) i = iter(curs)
t1 = next(i)[0] t1 = next(i)[0]
time.sleep(0.2) time.sleep(0.2)
t2 = next(i)[0] t2 = next(i)[0]
self.assert_((t2 - t1).microseconds * 1e-6 < 0.1, self.assert_((t2 - t1).microseconds * 1e-6 < 0.1,
"named cursor records fetched in 2 roundtrips (delta: %s)" "named cursor records fetched in 2 roundtrips (delta: %s)"
% (t2 - t1)) % (t2 - t1))
@skip_before_postgres(8, 0) @skip_before_postgres(8, 0)
def test_iter_named_cursor_default_itersize(self): def test_iter_named_cursor_default_itersize(self):
curs = self.conn.cursor('tmp') with self.conn.cursor('tmp') as curs:
curs.execute('select generate_series(1,50)') curs.execute('select generate_series(1,50)')
rv = [(r[0], curs.rownumber) for r in curs] rv = [(r[0], curs.rownumber) for r in curs]
# everything swallowed in one gulp # everything swallowed in one gulp
self.assertEqual(rv, [(i, i) for i in range(1, 51)]) self.assertEqual(rv, [(i, i) for i in range(1, 51)])
@skip_before_postgres(8, 0) @skip_before_postgres(8, 0)
def test_iter_named_cursor_itersize(self): def test_iter_named_cursor_itersize(self):
curs = self.conn.cursor('tmp') with self.conn.cursor('tmp') as curs:
curs.itersize = 30 curs.itersize = 30
curs.execute('select generate_series(1,50)') curs.execute('select generate_series(1,50)')
rv = [(r[0], curs.rownumber) for r in curs] rv = [(r[0], curs.rownumber) for r in curs]
# everything swallowed in two gulps # everything swallowed in two gulps
self.assertEqual(rv, [(i, ((i - 1) % 30) + 1) for i in range(1, 51)]) self.assertEqual(rv, [(i, ((i - 1) % 30) + 1) for i in range(1, 51)])
@skip_before_postgres(8, 0) @skip_before_postgres(8, 0)
def test_iter_named_cursor_rownumber(self): def test_iter_named_cursor_rownumber(self):
curs = self.conn.cursor('tmp') with self.conn.cursor('tmp') as curs:
# note: this fails if itersize < dataset: internally we check # note: this fails if itersize < dataset: internally we check
# rownumber == rowcount to detect when to read anoter page, so we # rownumber == rowcount to detect when to read anoter page, so we
# would need an extra attribute to have a monotonic rownumber. # would need an extra attribute to have a monotonic rownumber.
curs.itersize = 20 curs.itersize = 20
curs.execute('select generate_series(1,10)') curs.execute('select generate_series(1,10)')
for i, rec in enumerate(curs): for i, rec in enumerate(curs):
self.assertEqual(i + 1, curs.rownumber) self.assertEqual(i + 1, curs.rownumber)
def test_description_attribs(self): def test_description_attribs(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("""select curs.execute("""select
3.14::decimal(10,2) as pi, 3.14::decimal(10,2) as pi,
'hello'::text as hi, 'hello'::text as hi,
'2010-02-18'::date as now; '2010-02-18'::date as now;
""") """)
self.assertEqual(len(curs.description), 3) self.assertEqual(len(curs.description), 3)
for c in curs.description: for c in curs.description:
self.assertEqual(len(c), 7) # DBAPI happy self.assertEqual(len(c), 7) # DBAPI happy
for a in ('name', 'type_code', 'display_size', 'internal_size', for a in ('name', 'type_code', 'display_size', 'internal_size',
'precision', 'scale', 'null_ok'): 'precision', 'scale', 'null_ok'):
self.assert_(hasattr(c, a), a) self.assert_(hasattr(c, a), a)
c = curs.description[0] c = curs.description[0]
self.assertEqual(c.name, 'pi') self.assertEqual(c.name, 'pi')
self.assert_(c.type_code in psycopg2.extensions.DECIMAL.values) self.assert_(c.type_code in psycopg2.extensions.DECIMAL.values)
self.assert_(c.internal_size > 0) self.assert_(c.internal_size > 0)
self.assertEqual(c.precision, 10) self.assertEqual(c.precision, 10)
self.assertEqual(c.scale, 2) self.assertEqual(c.scale, 2)
c = curs.description[1] c = curs.description[1]
self.assertEqual(c.name, 'hi') self.assertEqual(c.name, 'hi')
self.assert_(c.type_code in psycopg2.STRING.values) self.assert_(c.type_code in psycopg2.STRING.values)
self.assert_(c.internal_size < 0) self.assert_(c.internal_size < 0)
self.assertEqual(c.precision, None) self.assertEqual(c.precision, None)
self.assertEqual(c.scale, None) self.assertEqual(c.scale, None)
c = curs.description[2] c = curs.description[2]
self.assertEqual(c.name, 'now') self.assertEqual(c.name, 'now')
self.assert_(c.type_code in psycopg2.extensions.DATE.values) self.assert_(c.type_code in psycopg2.extensions.DATE.values)
self.assert_(c.internal_size > 0) self.assert_(c.internal_size > 0)
self.assertEqual(c.precision, None) self.assertEqual(c.precision, None)
self.assertEqual(c.scale, None) self.assertEqual(c.scale, None)
def test_description_extra_attribs(self): def test_description_extra_attribs(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute(""" curs.execute("""
create table testcol ( create table testcol (
pi decimal(10,2), pi decimal(10,2),
hi text) hi text)
""") """)
curs.execute("select oid from pg_class where relname = %s", ('testcol',)) curs.execute("select oid from pg_class where relname = %s", ('testcol',))
oid = curs.fetchone()[0] oid = curs.fetchone()[0]
curs.execute("insert into testcol values (3.14, 'hello')") curs.execute("insert into testcol values (3.14, 'hello')")
curs.execute("select hi, pi, 42 from testcol") curs.execute("select hi, pi, 42 from testcol")
self.assertEqual(curs.description[0].table_oid, oid) self.assertEqual(curs.description[0].table_oid, oid)
self.assertEqual(curs.description[0].table_column, 2) self.assertEqual(curs.description[0].table_column, 2)
self.assertEqual(curs.description[1].table_oid, oid) self.assertEqual(curs.description[1].table_oid, oid)
self.assertEqual(curs.description[1].table_column, 1) self.assertEqual(curs.description[1].table_column, 1)
self.assertEqual(curs.description[2].table_oid, None) self.assertEqual(curs.description[2].table_oid, None)
self.assertEqual(curs.description[2].table_column, None) self.assertEqual(curs.description[2].table_column, None)
def test_description_slice(self): def test_description_slice(self):
curs = self.conn.cursor() curs = self.conn.cursor()
@ -439,28 +437,28 @@ class CursorTests(ConnectingTestCase):
self.assertEqual(curs.description[0][0:2], ('a', 23)) self.assertEqual(curs.description[0][0:2], ('a', 23))
def test_pickle_description(self): def test_pickle_description(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute('SELECT 1 AS foo') curs.execute('SELECT 1 AS foo')
description = curs.description description = curs.description
pickled = pickle.dumps(description, pickle.HIGHEST_PROTOCOL) pickled = pickle.dumps(description, pickle.HIGHEST_PROTOCOL)
unpickled = pickle.loads(pickled) unpickled = pickle.loads(pickled)
self.assertEqual(description, unpickled) self.assertEqual(description, unpickled)
@skip_before_postgres(8, 0) @skip_before_postgres(8, 0)
def test_named_cursor_stealing(self): def test_named_cursor_stealing(self):
# you can use a named cursor to iterate on a refcursor created # you can use a named cursor to iterate on a refcursor created
# somewhere else # somewhere else
cur1 = self.conn.cursor() with self.conn.cursor() as cur1:
cur1.execute("DECLARE test CURSOR WITHOUT HOLD " cur1.execute("DECLARE test CURSOR WITHOUT HOLD "
" FOR SELECT generate_series(1,7)") " FOR SELECT generate_series(1,7)")
cur2 = self.conn.cursor('test') with self.conn.cursor('test') as cur2:
# can call fetch without execute # can call fetch without execute
self.assertEqual((1,), cur2.fetchone()) self.assertEqual((1,), cur2.fetchone())
self.assertEqual([(2,), (3,), (4,)], cur2.fetchmany(3)) self.assertEqual([(2,), (3,), (4,)], cur2.fetchmany(3))
self.assertEqual([(5,), (6,), (7,)], cur2.fetchall()) self.assertEqual([(5,), (6,), (7,)], cur2.fetchall())
@skip_before_postgres(8, 2) @skip_before_postgres(8, 2)
def test_named_noop_close(self): def test_named_noop_close(self):
@ -469,63 +467,63 @@ class CursorTests(ConnectingTestCase):
@skip_before_postgres(8, 2) @skip_before_postgres(8, 2)
def test_stolen_named_cursor_close(self): def test_stolen_named_cursor_close(self):
cur1 = self.conn.cursor() with self.conn.cursor() as cur1:
cur1.execute("DECLARE test CURSOR WITHOUT HOLD " cur1.execute("DECLARE test CURSOR WITHOUT HOLD "
" FOR SELECT generate_series(1,7)") " FOR SELECT generate_series(1,7)")
cur2 = self.conn.cursor('test') cur2 = self.conn.cursor('test')
cur2.close() cur2.close()
cur1.execute("DECLARE test CURSOR WITHOUT HOLD " cur1.execute("DECLARE test CURSOR WITHOUT HOLD "
" FOR SELECT generate_series(1,7)") " FOR SELECT generate_series(1,7)")
cur2 = self.conn.cursor('test') cur2 = self.conn.cursor('test')
cur2.close() cur2.close()
@skip_before_postgres(8, 0) @skip_before_postgres(8, 0)
def test_scroll(self): def test_scroll(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select generate_series(0,9)") cur.execute("select generate_series(0,9)")
cur.scroll(2) cur.scroll(2)
self.assertEqual(cur.fetchone(), (2,)) self.assertEqual(cur.fetchone(), (2,))
cur.scroll(2) cur.scroll(2)
self.assertEqual(cur.fetchone(), (5,)) self.assertEqual(cur.fetchone(), (5,))
cur.scroll(2, mode='relative') cur.scroll(2, mode='relative')
self.assertEqual(cur.fetchone(), (8,)) self.assertEqual(cur.fetchone(), (8,))
cur.scroll(-1) cur.scroll(-1)
self.assertEqual(cur.fetchone(), (8,)) self.assertEqual(cur.fetchone(), (8,))
cur.scroll(-2) cur.scroll(-2)
self.assertEqual(cur.fetchone(), (7,)) self.assertEqual(cur.fetchone(), (7,))
cur.scroll(2, mode='absolute') cur.scroll(2, mode='absolute')
self.assertEqual(cur.fetchone(), (2,)) self.assertEqual(cur.fetchone(), (2,))
# on the boundary # on the boundary
cur.scroll(0, mode='absolute') cur.scroll(0, mode='absolute')
self.assertEqual(cur.fetchone(), (0,)) self.assertEqual(cur.fetchone(), (0,))
self.assertRaises((IndexError, psycopg2.ProgrammingError), self.assertRaises((IndexError, psycopg2.ProgrammingError),
cur.scroll, -1, mode='absolute') cur.scroll, -1, mode='absolute')
cur.scroll(0, mode='absolute') cur.scroll(0, mode='absolute')
self.assertRaises((IndexError, psycopg2.ProgrammingError), self.assertRaises((IndexError, psycopg2.ProgrammingError),
cur.scroll, -1) cur.scroll, -1)
cur.scroll(9, mode='absolute') cur.scroll(9, mode='absolute')
self.assertEqual(cur.fetchone(), (9,)) self.assertEqual(cur.fetchone(), (9,))
self.assertRaises((IndexError, psycopg2.ProgrammingError), self.assertRaises((IndexError, psycopg2.ProgrammingError),
cur.scroll, 10, mode='absolute') cur.scroll, 10, mode='absolute')
cur.scroll(9, mode='absolute') cur.scroll(9, mode='absolute')
self.assertRaises((IndexError, psycopg2.ProgrammingError), self.assertRaises((IndexError, psycopg2.ProgrammingError),
cur.scroll, 1) cur.scroll, 1)
@skip_before_postgres(8, 0) @skip_before_postgres(8, 0)
def test_scroll_named(self): def test_scroll_named(self):
cur = self.conn.cursor('tmp', scrollable=True) with self.conn.cursor('tmp', scrollable=True) as cur:
cur.execute("select generate_series(0,9)") cur.execute("select generate_series(0,9)")
cur.scroll(2) cur.scroll(2)
self.assertEqual(cur.fetchone(), (2,)) self.assertEqual(cur.fetchone(), (2,))
cur.scroll(2) cur.scroll(2)
self.assertEqual(cur.fetchone(), (5,)) self.assertEqual(cur.fetchone(), (5,))
cur.scroll(2, mode='relative') cur.scroll(2, mode='relative')
self.assertEqual(cur.fetchone(), (8,)) self.assertEqual(cur.fetchone(), (8,))
cur.scroll(9, mode='absolute') cur.scroll(9, mode='absolute')
self.assertEqual(cur.fetchone(), (9,)) self.assertEqual(cur.fetchone(), (9,))
def test_bad_subclass(self): def test_bad_subclass(self):
# check that we get an error message instead of a segfault # check that we get an error message instead of a segfault
@ -536,14 +534,14 @@ class CursorTests(ConnectingTestCase):
# I am stupid so not calling superclass init # I am stupid so not calling superclass init
pass pass
cur = StupidCursor() with StupidCursor() as cur:
self.assertRaises(psycopg2.InterfaceError, cur.execute, 'select 1') self.assertRaises(psycopg2.InterfaceError, cur.execute, 'select 1')
self.assertRaises(psycopg2.InterfaceError, cur.executemany, self.assertRaises(psycopg2.InterfaceError, cur.executemany,
'select 1', []) 'select 1', [])
def test_callproc_badparam(self): def test_callproc_badparam(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
self.assertRaises(TypeError, cur.callproc, 'lower', 42) self.assertRaises(TypeError, cur.callproc, 'lower', 42)
# It would be inappropriate to test callproc's named parameters in the # It would be inappropriate to test callproc's named parameters in the
# DBAPI2.0 test section because they are a psycopg2 extension. # DBAPI2.0 test section because they are a psycopg2 extension.
@ -556,32 +554,31 @@ class CursorTests(ConnectingTestCase):
escaped_paramname = '"%s"' % paramname.replace('"', '""') escaped_paramname = '"%s"' % paramname.replace('"', '""')
procname = 'pg_temp.randall' procname = 'pg_temp.randall'
cur = self.conn.cursor() with self.conn.cursor() as cur:
# Set up the temporary function
cur.execute('''
CREATE FUNCTION %s(%s INT)
RETURNS INT AS
'SELECT $1 * $1'
LANGUAGE SQL
''' % (procname, escaped_paramname))
# Set up the temporary function # Make sure callproc works right
cur.execute(''' cur.callproc(procname, {paramname: 2})
CREATE FUNCTION %s(%s INT) self.assertEquals(cur.fetchone()[0], 4)
RETURNS INT AS
'SELECT $1 * $1'
LANGUAGE SQL
''' % (procname, escaped_paramname))
# Make sure callproc works right # Make sure callproc fails right
cur.callproc(procname, {paramname: 2}) failing_cases = [
self.assertEquals(cur.fetchone()[0], 4) ({paramname: 2, 'foo': 'bar'}, psycopg2.ProgrammingError),
({paramname: '2'}, psycopg2.ProgrammingError),
# Make sure callproc fails right ({paramname: 'two'}, psycopg2.ProgrammingError),
failing_cases = [ ({u'bj\xc3rn': 2}, psycopg2.ProgrammingError),
({paramname: 2, 'foo': 'bar'}, psycopg2.ProgrammingError), ({3: 2}, TypeError),
({paramname: '2'}, psycopg2.ProgrammingError), ({self: 2}, TypeError),
({paramname: 'two'}, psycopg2.ProgrammingError), ]
({u'bj\xc3rn': 2}, psycopg2.ProgrammingError), for parameter_sequence, exception in failing_cases:
({3: 2}, TypeError), self.assertRaises(exception, cur.callproc, procname, parameter_sequence)
({self: 2}, TypeError), self.conn.rollback()
]
for parameter_sequence, exception in failing_cases:
self.assertRaises(exception, cur.callproc, procname, parameter_sequence)
self.conn.rollback()
@skip_if_no_superuser @skip_if_no_superuser
@skip_if_windows @skip_if_windows
@ -643,17 +640,17 @@ class CursorTests(ConnectingTestCase):
@skip_before_postgres(8, 2) @skip_before_postgres(8, 2)
def test_rowcount_on_executemany_returning(self): def test_rowcount_on_executemany_returning(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("create table execmany(id serial primary key, data int)") cur.execute("create table execmany(id serial primary key, data int)")
cur.executemany( cur.executemany(
"insert into execmany (data) values (%s)", "insert into execmany (data) values (%s)",
[(i,) for i in range(4)]) [(i,) for i in range(4)])
self.assertEqual(cur.rowcount, 4) self.assertEqual(cur.rowcount, 4)
cur.executemany( cur.executemany(
"insert into execmany (data) values (%s) returning data", "insert into execmany (data) values (%s) returning data",
[(i,) for i in range(5)]) [(i,) for i in range(5)])
self.assertEqual(cur.rowcount, 5) self.assertEqual(cur.rowcount, 5)
@skip_before_postgres(9) @skip_before_postgres(9)
def test_pgresult_ptr(self): def test_pgresult_ptr(self):

View File

@ -369,22 +369,22 @@ class DatetimeTests(ConnectingTestCase, CommonDatetimeTestsMixin):
self.assertEqual(total_seconds(t), 1e-6) self.assertEqual(total_seconds(t), 1e-6)
def test_interval_overflow(self): def test_interval_overflow(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
# hack a cursor to receive values too extreme to be represented # hack a cursor to receive values too extreme to be represented
# but still I want an error, not a random number # but still I want an error, not a random number
psycopg2.extensions.register_type( psycopg2.extensions.register_type(
psycopg2.extensions.new_type( psycopg2.extensions.new_type(
psycopg2.STRING.values, 'WAT', psycopg2.extensions.INTERVAL), psycopg2.STRING.values, 'WAT', psycopg2.extensions.INTERVAL),
cur) cur)
def f(val): def f(val):
cur.execute("select '%s'::text" % val) cur.execute("select '%s'::text" % val)
return cur.fetchone()[0] return cur.fetchone()[0]
self.assertRaises(OverflowError, f, '100000000000000000:00:00') self.assertRaises(OverflowError, f, '100000000000000000:00:00')
self.assertRaises(OverflowError, f, '00:100000000000000000:00:00') self.assertRaises(OverflowError, f, '00:100000000000000000:00:00')
self.assertRaises(OverflowError, f, '00:00:100000000000000000:00') self.assertRaises(OverflowError, f, '00:00:100000000000000000:00')
self.assertRaises(OverflowError, f, '00:00:00.100000000000000000') self.assertRaises(OverflowError, f, '00:00:00.100000000000000000')
def test_adapt_infinity_tz(self): def test_adapt_infinity_tz(self):
t = self.execute("select 'infinity'::timestamp") t = self.execute("select 'infinity'::timestamp")
@ -405,31 +405,31 @@ class DatetimeTests(ConnectingTestCase, CommonDatetimeTestsMixin):
def test_redshift_day(self): def test_redshift_day(self):
# Redshift is reported returning 1 day interval as microsec (bug #558) # Redshift is reported returning 1 day interval as microsec (bug #558)
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extensions.register_type( psycopg2.extensions.register_type(
psycopg2.extensions.new_type( psycopg2.extensions.new_type(
psycopg2.STRING.values, 'WAT', psycopg2.extensions.INTERVAL), psycopg2.STRING.values, 'WAT', psycopg2.extensions.INTERVAL),
cur) cur)
for s, v in [ for s, v in [
('0', timedelta(0)), ('0', timedelta(0)),
('1', timedelta(microseconds=1)), ('1', timedelta(microseconds=1)),
('-1', timedelta(microseconds=-1)), ('-1', timedelta(microseconds=-1)),
('1000000', timedelta(seconds=1)), ('1000000', timedelta(seconds=1)),
('86400000000', timedelta(days=1)), ('86400000000', timedelta(days=1)),
('-86400000000', timedelta(days=-1)), ('-86400000000', timedelta(days=-1)),
]: ]:
cur.execute("select %s::text", (s,)) cur.execute("select %s::text", (s,))
r = cur.fetchone()[0] r = cur.fetchone()[0]
self.assertEqual(r, v, "%s -> %s != %s" % (s, r, v)) self.assertEqual(r, v, "%s -> %s != %s" % (s, r, v))
@skip_before_postgres(8, 4) @skip_before_postgres(8, 4)
def test_interval_iso_8601_not_supported(self): def test_interval_iso_8601_not_supported(self):
# We may end up supporting, but no pressure for it # We may end up supporting, but no pressure for it
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("set local intervalstyle to iso_8601") cur.execute("set local intervalstyle to iso_8601")
cur.execute("select '1 day 2 hours'::interval") cur.execute("select '1 day 2 hours'::interval")
self.assertRaises(psycopg2.NotSupportedError, cur.fetchone) self.assertRaises(psycopg2.NotSupportedError, cur.fetchone)
@unittest.skipUnless( @unittest.skipUnless(

View File

@ -33,9 +33,9 @@ from .testutils import ConnectingTestCase, skip_before_postgres, \
class _DictCursorBase(ConnectingTestCase): class _DictCursorBase(ConnectingTestCase):
def setUp(self): def setUp(self):
ConnectingTestCase.setUp(self) ConnectingTestCase.setUp(self)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("CREATE TEMPORARY TABLE ExtrasDictCursorTests (foo text)") curs.execute("CREATE TEMPORARY TABLE ExtrasDictCursorTests (foo text)")
curs.execute("INSERT INTO ExtrasDictCursorTests VALUES ('bar')") curs.execute("INSERT INTO ExtrasDictCursorTests VALUES ('bar')")
self.conn.commit() self.conn.commit()
def _testIterRowNumber(self, curs): def _testIterRowNumber(self, curs):
@ -62,17 +62,20 @@ class _DictCursorBase(ConnectingTestCase):
class ExtrasDictCursorTests(_DictCursorBase): class ExtrasDictCursorTests(_DictCursorBase):
"""Test if DictCursor extension class works.""" """Test if DictCursor extension class works."""
@skip_before_postgres(8, 2)
def testDictConnCursorArgs(self): def testDictConnCursorArgs(self):
self.conn.close() self.conn.close()
self.conn = self.connect(connection_factory=psycopg2.extras.DictConnection) self.conn = self.connect(connection_factory=psycopg2.extras.DictConnection)
cur = self.conn.cursor() with self.conn.cursor() as cur:
self.assert_(isinstance(cur, psycopg2.extras.DictCursor)) self.assert_(isinstance(cur, psycopg2.extras.DictCursor))
self.assertEqual(cur.name, None) self.assertEqual(cur.name, None)
# overridable # overridable
cur = self.conn.cursor('foo', with self.conn.cursor(
cursor_factory=psycopg2.extras.NamedTupleCursor) 'foo',
self.assertEqual(cur.name, 'foo') cursor_factory=psycopg2.extras.NamedTupleCursor
self.assert_(isinstance(cur, psycopg2.extras.NamedTupleCursor)) ) as cur:
self.assertEqual(cur.name, 'foo')
self.assert_(isinstance(cur, psycopg2.extras.NamedTupleCursor))
def testDictCursorWithPlainCursorFetchOne(self): def testDictCursorWithPlainCursorFetchOne(self):
self._testWithPlainCursor(lambda curs: curs.fetchone()) self._testWithPlainCursor(lambda curs: curs.fetchone())
@ -100,13 +103,13 @@ class ExtrasDictCursorTests(_DictCursorBase):
@skip_before_postgres(8, 0) @skip_before_postgres(8, 0)
def testDictCursorWithPlainCursorIterRowNumber(self): def testDictCursorWithPlainCursorIterRowNumber(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as curs:
self._testIterRowNumber(curs) self._testIterRowNumber(curs)
def _testWithPlainCursor(self, getter): def _testWithPlainCursor(self, getter):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as curs:
curs.execute("SELECT * FROM ExtrasDictCursorTests") curs.execute("SELECT * FROM ExtrasDictCursorTests")
row = getter(curs) row = getter(curs)
self.failUnless(row['foo'] == 'bar') self.failUnless(row['foo'] == 'bar')
self.failUnless(row[0] == 'bar') self.failUnless(row[0] == 'bar')
return row return row
@ -131,33 +134,42 @@ class ExtrasDictCursorTests(_DictCursorBase):
@skip_before_postgres(8, 2) @skip_before_postgres(8, 2)
def testDictCursorWithNamedCursorNotGreedy(self): def testDictCursorWithNamedCursorNotGreedy(self):
curs = self.conn.cursor('tmp', cursor_factory=psycopg2.extras.DictCursor) with self.conn.cursor(
self._testNamedCursorNotGreedy(curs) 'tmp',
cursor_factory=psycopg2.extras.DictCursor
) as curs:
self._testNamedCursorNotGreedy(curs)
@skip_before_postgres(8, 0) @skip_before_postgres(8, 0)
def testDictCursorWithNamedCursorIterRowNumber(self): def testDictCursorWithNamedCursorIterRowNumber(self):
curs = self.conn.cursor('tmp', cursor_factory=psycopg2.extras.DictCursor) with self.conn.cursor(
self._testIterRowNumber(curs) 'tmp',
cursor_factory=psycopg2.extras.DictCursor
) as curs:
self._testIterRowNumber(curs)
def _testWithNamedCursor(self, getter): def _testWithNamedCursor(self, getter):
curs = self.conn.cursor('aname', cursor_factory=psycopg2.extras.DictCursor) with self.conn.cursor(
curs.execute("SELECT * FROM ExtrasDictCursorTests") 'aname',
row = getter(curs) cursor_factory=psycopg2.extras.DictCursor
self.failUnless(row['foo'] == 'bar') ) as curs:
self.failUnless(row[0] == 'bar') curs.execute("SELECT * FROM ExtrasDictCursorTests")
row = getter(curs)
self.failUnless(row['foo'] == 'bar')
self.failUnless(row[0] == 'bar')
def testPickleDictRow(self): def testPickleDictRow(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as curs:
curs.execute("select 10 as a, 20 as b") curs.execute("select 10 as a, 20 as b")
r = curs.fetchone() r = curs.fetchone()
d = pickle.dumps(r) d = pickle.dumps(r)
r1 = pickle.loads(d) r1 = pickle.loads(d)
self.assertEqual(r, r1) self.assertEqual(r, r1)
self.assertEqual(r[0], r1[0]) self.assertEqual(r[0], r1[0])
self.assertEqual(r[1], r1[1]) self.assertEqual(r[1], r1[1])
self.assertEqual(r['a'], r1['a']) self.assertEqual(r['a'], r1['a'])
self.assertEqual(r['b'], r1['b']) self.assertEqual(r['b'], r1['b'])
self.assertEqual(r._index, r1._index) self.assertEqual(r._index, r1._index)
def test_copy(self): def test_copy(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor)
@ -175,50 +187,50 @@ class ExtrasDictCursorTests(_DictCursorBase):
@skip_from_python(3) @skip_from_python(3)
def test_iter_methods_2(self): def test_iter_methods_2(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as curs:
curs.execute("select 10 as a, 20 as b") curs.execute("select 10 as a, 20 as b")
r = curs.fetchone() r = curs.fetchone()
self.assert_(isinstance(r.keys(), list)) self.assert_(isinstance(r.keys(), list))
self.assertEqual(len(r.keys()), 2) self.assertEqual(len(r.keys()), 2)
self.assert_(isinstance(r.values(), tuple)) # sic? self.assert_(isinstance(r.values(), tuple)) # sic?
self.assertEqual(len(r.values()), 2) self.assertEqual(len(r.values()), 2)
self.assert_(isinstance(r.items(), list)) self.assert_(isinstance(r.items(), list))
self.assertEqual(len(r.items()), 2) self.assertEqual(len(r.items()), 2)
self.assert_(not isinstance(r.iterkeys(), list)) self.assert_(not isinstance(r.iterkeys(), list))
self.assertEqual(len(list(r.iterkeys())), 2) self.assertEqual(len(list(r.iterkeys())), 2)
self.assert_(not isinstance(r.itervalues(), list)) self.assert_(not isinstance(r.itervalues(), list))
self.assertEqual(len(list(r.itervalues())), 2) self.assertEqual(len(list(r.itervalues())), 2)
self.assert_(not isinstance(r.iteritems(), list)) self.assert_(not isinstance(r.iteritems(), list))
self.assertEqual(len(list(r.iteritems())), 2) self.assertEqual(len(list(r.iteritems())), 2)
@skip_before_python(3) @skip_before_python(3)
def test_iter_methods_3(self): def test_iter_methods_3(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as curs:
curs.execute("select 10 as a, 20 as b") curs.execute("select 10 as a, 20 as b")
r = curs.fetchone() r = curs.fetchone()
self.assert_(not isinstance(r.keys(), list)) self.assert_(not isinstance(r.keys(), list))
self.assertEqual(len(list(r.keys())), 2) self.assertEqual(len(list(r.keys())), 2)
self.assert_(not isinstance(r.values(), list)) self.assert_(not isinstance(r.values(), list))
self.assertEqual(len(list(r.values())), 2) self.assertEqual(len(list(r.values())), 2)
self.assert_(not isinstance(r.items(), list)) self.assert_(not isinstance(r.items(), list))
self.assertEqual(len(list(r.items())), 2) self.assertEqual(len(list(r.items())), 2)
def test_order(self): def test_order(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as curs:
curs.execute("select 5 as foo, 4 as bar, 33 as baz, 2 as qux") curs.execute("select 5 as foo, 4 as bar, 33 as baz, 2 as qux")
r = curs.fetchone() r = curs.fetchone()
self.assertEqual(list(r), [5, 4, 33, 2]) self.assertEqual(list(r), [5, 4, 33, 2])
self.assertEqual(list(r.keys()), ['foo', 'bar', 'baz', 'qux']) self.assertEqual(list(r.keys()), ['foo', 'bar', 'baz', 'qux'])
self.assertEqual(list(r.values()), [5, 4, 33, 2]) self.assertEqual(list(r.values()), [5, 4, 33, 2])
self.assertEqual(list(r.items()), self.assertEqual(list(r.items()),
[('foo', 5), ('bar', 4), ('baz', 33), ('qux', 2)]) [('foo', 5), ('bar', 4), ('baz', 33), ('qux', 2)])
r1 = pickle.loads(pickle.dumps(r)) r1 = pickle.loads(pickle.dumps(r))
self.assertEqual(list(r1), list(r)) self.assertEqual(list(r1), list(r))
self.assertEqual(list(r1.keys()), list(r.keys())) self.assertEqual(list(r1.keys()), list(r.keys()))
self.assertEqual(list(r1.values()), list(r.values())) self.assertEqual(list(r1.values()), list(r.values()))
self.assertEqual(list(r1.items()), list(r.items())) self.assertEqual(list(r1.items()), list(r.items()))
@skip_from_python(3) @skip_from_python(3)
def test_order_iter(self): def test_order_iter(self):
@ -238,10 +250,10 @@ class ExtrasDictCursorTests(_DictCursorBase):
class ExtrasDictCursorRealTests(_DictCursorBase): class ExtrasDictCursorRealTests(_DictCursorBase):
def testRealMeansReal(self): def testRealMeansReal(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs:
curs.execute("SELECT * FROM ExtrasDictCursorTests") curs.execute("SELECT * FROM ExtrasDictCursorTests")
row = curs.fetchone() row = curs.fetchone()
self.assert_(isinstance(row, dict)) self.assert_(isinstance(row, dict))
def testDictCursorWithPlainCursorRealFetchOne(self): def testDictCursorWithPlainCursorRealFetchOne(self):
self._testWithPlainCursorReal(lambda curs: curs.fetchone()) self._testWithPlainCursorReal(lambda curs: curs.fetchone())
@ -263,24 +275,24 @@ class ExtrasDictCursorRealTests(_DictCursorBase):
@skip_before_postgres(8, 0) @skip_before_postgres(8, 0)
def testDictCursorWithPlainCursorRealIterRowNumber(self): def testDictCursorWithPlainCursorRealIterRowNumber(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs:
self._testIterRowNumber(curs) self._testIterRowNumber(curs)
def _testWithPlainCursorReal(self, getter): def _testWithPlainCursorReal(self, getter):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs:
curs.execute("SELECT * FROM ExtrasDictCursorTests") curs.execute("SELECT * FROM ExtrasDictCursorTests")
row = getter(curs) row = getter(curs)
self.failUnless(row['foo'] == 'bar') self.failUnless(row['foo'] == 'bar')
def testPickleRealDictRow(self): def testPickleRealDictRow(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs:
curs.execute("select 10 as a, 20 as b") curs.execute("select 10 as a, 20 as b")
r = curs.fetchone() r = curs.fetchone()
d = pickle.dumps(r) d = pickle.dumps(r)
r1 = pickle.loads(d) r1 = pickle.loads(d)
self.assertEqual(r, r1) self.assertEqual(r, r1)
self.assertEqual(r['a'], r1['a']) self.assertEqual(r['a'], r1['a'])
self.assertEqual(r['b'], r1['b']) self.assertEqual(r['b'], r1['b'])
def test_copy(self): def test_copy(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor)
@ -316,26 +328,34 @@ class ExtrasDictCursorRealTests(_DictCursorBase):
@skip_before_postgres(8, 2) @skip_before_postgres(8, 2)
def testDictCursorRealWithNamedCursorNotGreedy(self): def testDictCursorRealWithNamedCursorNotGreedy(self):
curs = self.conn.cursor('tmp', cursor_factory=psycopg2.extras.RealDictCursor) with self.conn.cursor(
self._testNamedCursorNotGreedy(curs) 'tmp',
cursor_factory=psycopg2.extras.RealDictCursor
) as curs:
self._testNamedCursorNotGreedy(curs)
@skip_before_postgres(8, 0) @skip_before_postgres(8, 0)
def testDictCursorRealWithNamedCursorIterRowNumber(self): def testDictCursorRealWithNamedCursorIterRowNumber(self):
curs = self.conn.cursor('tmp', cursor_factory=psycopg2.extras.RealDictCursor) with self.conn.cursor(
self._testIterRowNumber(curs) 'tmp',
cursor_factory=psycopg2.extras.RealDictCursor
) as curs:
self._testIterRowNumber(curs)
def _testWithNamedCursorReal(self, getter): def _testWithNamedCursorReal(self, getter):
curs = self.conn.cursor('aname', with self.conn.cursor(
cursor_factory=psycopg2.extras.RealDictCursor) 'aname',
curs.execute("SELECT * FROM ExtrasDictCursorTests") cursor_factory=psycopg2.extras.RealDictCursor
row = getter(curs) ) as curs:
self.failUnless(row['foo'] == 'bar') curs.execute("SELECT * FROM ExtrasDictCursorTests")
row = getter(curs)
self.failUnless(row['foo'] == 'bar')
@skip_from_python(3) @skip_from_python(3)
def test_iter_methods_2(self): def test_iter_methods_2(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs:
curs.execute("select 10 as a, 20 as b") curs.execute("select 10 as a, 20 as b")
r = curs.fetchone() r = curs.fetchone()
self.assert_(isinstance(r.keys(), list)) self.assert_(isinstance(r.keys(), list))
self.assertEqual(len(r.keys()), 2) self.assertEqual(len(r.keys()), 2)
self.assert_(isinstance(r.values(), list)) self.assert_(isinstance(r.values(), list))
@ -352,9 +372,9 @@ class ExtrasDictCursorRealTests(_DictCursorBase):
@skip_before_python(3) @skip_before_python(3)
def test_iter_methods_3(self): def test_iter_methods_3(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs:
curs.execute("select 10 as a, 20 as b") curs.execute("select 10 as a, 20 as b")
r = curs.fetchone() r = curs.fetchone()
self.assert_(not isinstance(r.keys(), list)) self.assert_(not isinstance(r.keys(), list))
self.assertEqual(len(list(r.keys())), 2) self.assertEqual(len(list(r.keys())), 2)
self.assert_(not isinstance(r.values(), list)) self.assert_(not isinstance(r.values(), list))
@ -363,9 +383,9 @@ class ExtrasDictCursorRealTests(_DictCursorBase):
self.assertEqual(len(list(r.items())), 2) self.assertEqual(len(list(r.items())), 2)
def test_order(self): def test_order(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs:
curs.execute("select 5 as foo, 4 as bar, 33 as baz, 2 as qux") curs.execute("select 5 as foo, 4 as bar, 33 as baz, 2 as qux")
r = curs.fetchone() r = curs.fetchone()
self.assertEqual(list(r), ['foo', 'bar', 'baz', 'qux']) self.assertEqual(list(r), ['foo', 'bar', 'baz', 'qux'])
self.assertEqual(list(r.keys()), ['foo', 'bar', 'baz', 'qux']) self.assertEqual(list(r.keys()), ['foo', 'bar', 'baz', 'qux'])
self.assertEqual(list(r.values()), [5, 4, 33, 2]) self.assertEqual(list(r.values()), [5, 4, 33, 2])
@ -380,9 +400,9 @@ class ExtrasDictCursorRealTests(_DictCursorBase):
@skip_from_python(3) @skip_from_python(3)
def test_order_iter(self): def test_order_iter(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs:
curs.execute("select 5 as foo, 4 as bar, 33 as baz, 2 as qux") curs.execute("select 5 as foo, 4 as bar, 33 as baz, 2 as qux")
r = curs.fetchone() r = curs.fetchone()
self.assertEqual(list(r.iterkeys()), ['foo', 'bar', 'baz', 'qux']) self.assertEqual(list(r.iterkeys()), ['foo', 'bar', 'baz', 'qux'])
self.assertEqual(list(r.itervalues()), [5, 4, 33, 2]) self.assertEqual(list(r.itervalues()), [5, 4, 33, 2])
self.assertEqual(list(r.iteritems()), self.assertEqual(list(r.iteritems()),
@ -394,9 +414,9 @@ class ExtrasDictCursorRealTests(_DictCursorBase):
self.assertEqual(list(r1.iteritems()), list(r.iteritems())) self.assertEqual(list(r1.iteritems()), list(r.iteritems()))
def test_pop(self): def test_pop(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs:
curs.execute("select 1 as a, 2 as b, 3 as c") curs.execute("select 1 as a, 2 as b, 3 as c")
r = curs.fetchone() r = curs.fetchone()
self.assertEqual(r.pop('b'), 2) self.assertEqual(r.pop('b'), 2)
self.assertEqual(list(r), ['a', 'c']) self.assertEqual(list(r), ['a', 'c'])
self.assertEqual(list(r.keys()), ['a', 'c']) self.assertEqual(list(r.keys()), ['a', 'c'])
@ -407,9 +427,9 @@ class ExtrasDictCursorRealTests(_DictCursorBase):
self.assertRaises(KeyError, r.pop, 'b') self.assertRaises(KeyError, r.pop, 'b')
def test_mod(self): def test_mod(self):
curs = self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) with self.conn.cursor(cursor_factory=psycopg2.extras.RealDictCursor) as curs:
curs.execute("select 1 as a, 2 as b, 3 as c") curs.execute("select 1 as a, 2 as b, 3 as c")
r = curs.fetchone() r = curs.fetchone()
r['d'] = 4 r['d'] = 4
self.assertEqual(list(r), ['a', 'b', 'c', 'd']) self.assertEqual(list(r), ['a', 'b', 'c', 'd'])
self.assertEqual(list(r.keys()), ['a', 'b', 'c', 'd']) self.assertEqual(list(r.keys()), ['a', 'b', 'c', 'd'])
@ -428,137 +448,141 @@ class NamedTupleCursorTest(ConnectingTestCase):
ConnectingTestCase.setUp(self) ConnectingTestCase.setUp(self)
self.conn = self.connect(connection_factory=NamedTupleConnection) self.conn = self.connect(connection_factory=NamedTupleConnection)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("CREATE TEMPORARY TABLE nttest (i int, s text)") curs.execute("CREATE TEMPORARY TABLE nttest (i int, s text)")
curs.execute("INSERT INTO nttest VALUES (1, 'foo')") curs.execute("INSERT INTO nttest VALUES (1, 'foo')")
curs.execute("INSERT INTO nttest VALUES (2, 'bar')") curs.execute("INSERT INTO nttest VALUES (2, 'bar')")
curs.execute("INSERT INTO nttest VALUES (3, 'baz')") curs.execute("INSERT INTO nttest VALUES (3, 'baz')")
self.conn.commit() self.conn.commit()
@skip_before_postgres(8, 2)
def test_cursor_args(self): def test_cursor_args(self):
cur = self.conn.cursor('foo', cursor_factory=psycopg2.extras.DictCursor) with self.conn.cursor(
self.assertEqual(cur.name, 'foo') 'foo',
self.assert_(isinstance(cur, psycopg2.extras.DictCursor)) cursor_factory=psycopg2.extras.DictCursor
) as cur:
self.assertEqual(cur.name, 'foo')
self.assert_(isinstance(cur, psycopg2.extras.DictCursor))
def test_fetchone(self): def test_fetchone(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select * from nttest order by 1") curs.execute("select * from nttest order by 1")
t = curs.fetchone() t = curs.fetchone()
self.assertEqual(t[0], 1) self.assertEqual(t[0], 1)
self.assertEqual(t.i, 1) self.assertEqual(t.i, 1)
self.assertEqual(t[1], 'foo') self.assertEqual(t[1], 'foo')
self.assertEqual(t.s, 'foo') self.assertEqual(t.s, 'foo')
self.assertEqual(curs.rownumber, 1) self.assertEqual(curs.rownumber, 1)
self.assertEqual(curs.rowcount, 3) self.assertEqual(curs.rowcount, 3)
def test_fetchmany_noarg(self): def test_fetchmany_noarg(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.arraysize = 2 curs.arraysize = 2
curs.execute("select * from nttest order by 1") curs.execute("select * from nttest order by 1")
res = curs.fetchmany() res = curs.fetchmany()
self.assertEqual(2, len(res)) self.assertEqual(2, len(res))
self.assertEqual(res[0].i, 1) self.assertEqual(res[0].i, 1)
self.assertEqual(res[0].s, 'foo') self.assertEqual(res[0].s, 'foo')
self.assertEqual(res[1].i, 2) self.assertEqual(res[1].i, 2)
self.assertEqual(res[1].s, 'bar') self.assertEqual(res[1].s, 'bar')
self.assertEqual(curs.rownumber, 2) self.assertEqual(curs.rownumber, 2)
self.assertEqual(curs.rowcount, 3) self.assertEqual(curs.rowcount, 3)
def test_fetchmany(self): def test_fetchmany(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select * from nttest order by 1") curs.execute("select * from nttest order by 1")
res = curs.fetchmany(2) res = curs.fetchmany(2)
self.assertEqual(2, len(res)) self.assertEqual(2, len(res))
self.assertEqual(res[0].i, 1) self.assertEqual(res[0].i, 1)
self.assertEqual(res[0].s, 'foo') self.assertEqual(res[0].s, 'foo')
self.assertEqual(res[1].i, 2) self.assertEqual(res[1].i, 2)
self.assertEqual(res[1].s, 'bar') self.assertEqual(res[1].s, 'bar')
self.assertEqual(curs.rownumber, 2) self.assertEqual(curs.rownumber, 2)
self.assertEqual(curs.rowcount, 3) self.assertEqual(curs.rowcount, 3)
def test_fetchall(self): def test_fetchall(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select * from nttest order by 1") curs.execute("select * from nttest order by 1")
res = curs.fetchall() res = curs.fetchall()
self.assertEqual(3, len(res)) self.assertEqual(3, len(res))
self.assertEqual(res[0].i, 1) self.assertEqual(res[0].i, 1)
self.assertEqual(res[0].s, 'foo') self.assertEqual(res[0].s, 'foo')
self.assertEqual(res[1].i, 2) self.assertEqual(res[1].i, 2)
self.assertEqual(res[1].s, 'bar') self.assertEqual(res[1].s, 'bar')
self.assertEqual(res[2].i, 3) self.assertEqual(res[2].i, 3)
self.assertEqual(res[2].s, 'baz') self.assertEqual(res[2].s, 'baz')
self.assertEqual(curs.rownumber, 3) self.assertEqual(curs.rownumber, 3)
self.assertEqual(curs.rowcount, 3) self.assertEqual(curs.rowcount, 3)
def test_executemany(self): def test_executemany(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.executemany("delete from nttest where i = %s", curs.executemany("delete from nttest where i = %s",
[(1,), (2,)]) [(1,), (2,)])
curs.execute("select * from nttest order by 1") curs.execute("select * from nttest order by 1")
res = curs.fetchall() res = curs.fetchall()
self.assertEqual(1, len(res)) self.assertEqual(1, len(res))
self.assertEqual(res[0].i, 3) self.assertEqual(res[0].i, 3)
self.assertEqual(res[0].s, 'baz') self.assertEqual(res[0].s, 'baz')
def test_iter(self): def test_iter(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select * from nttest order by 1") curs.execute("select * from nttest order by 1")
i = iter(curs) i = iter(curs)
self.assertEqual(curs.rownumber, 0) self.assertEqual(curs.rownumber, 0)
t = next(i) t = next(i)
self.assertEqual(t.i, 1) self.assertEqual(t.i, 1)
self.assertEqual(t.s, 'foo') self.assertEqual(t.s, 'foo')
self.assertEqual(curs.rownumber, 1) self.assertEqual(curs.rownumber, 1)
self.assertEqual(curs.rowcount, 3) self.assertEqual(curs.rowcount, 3)
t = next(i) t = next(i)
self.assertEqual(t.i, 2) self.assertEqual(t.i, 2)
self.assertEqual(t.s, 'bar') self.assertEqual(t.s, 'bar')
self.assertEqual(curs.rownumber, 2) self.assertEqual(curs.rownumber, 2)
self.assertEqual(curs.rowcount, 3) self.assertEqual(curs.rowcount, 3)
t = next(i) t = next(i)
self.assertEqual(t.i, 3) self.assertEqual(t.i, 3)
self.assertEqual(t.s, 'baz') self.assertEqual(t.s, 'baz')
self.assertRaises(StopIteration, next, i) self.assertRaises(StopIteration, next, i)
self.assertEqual(curs.rownumber, 3) self.assertEqual(curs.rownumber, 3)
self.assertEqual(curs.rowcount, 3) self.assertEqual(curs.rowcount, 3)
def test_record_updated(self): def test_record_updated(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select 1 as foo;") curs.execute("select 1 as foo;")
r = curs.fetchone() r = curs.fetchone()
self.assertEqual(r.foo, 1) self.assertEqual(r.foo, 1)
curs.execute("select 2 as bar;") curs.execute("select 2 as bar;")
r = curs.fetchone() r = curs.fetchone()
self.assertEqual(r.bar, 2) self.assertEqual(r.bar, 2)
self.assertRaises(AttributeError, getattr, r, 'foo') self.assertRaises(AttributeError, getattr, r, 'foo')
def test_no_result_no_surprise(self): def test_no_result_no_surprise(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("update nttest set s = s") curs.execute("update nttest set s = s")
self.assertRaises(psycopg2.ProgrammingError, curs.fetchone) self.assertRaises(psycopg2.ProgrammingError, curs.fetchone)
curs.execute("update nttest set s = s") curs.execute("update nttest set s = s")
self.assertRaises(psycopg2.ProgrammingError, curs.fetchall) self.assertRaises(psycopg2.ProgrammingError, curs.fetchall)
def test_bad_col_names(self): def test_bad_col_names(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute('select 1 as "foo.bar_baz", 2 as "?column?", 3 as "3"') curs.execute('select 1 as "foo.bar_baz", 2 as "?column?", 3 as "3"')
rv = curs.fetchone() rv = curs.fetchone()
self.assertEqual(rv.foo_bar_baz, 1) self.assertEqual(rv.foo_bar_baz, 1)
self.assertEqual(rv.f_column_, 2) self.assertEqual(rv.f_column_, 2)
self.assertEqual(rv.f3, 3) self.assertEqual(rv.f3, 3)
@skip_before_python(3) @skip_before_python(3)
@skip_before_postgres(8) @skip_before_postgres(8)
def test_nonascii_name(self): def test_nonascii_name(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute('select 1 as \xe5h\xe9') curs.execute('select 1 as \xe5h\xe9')
rv = curs.fetchone() rv = curs.fetchone()
self.assertEqual(getattr(rv, '\xe5h\xe9'), 1) self.assertEqual(getattr(rv, '\xe5h\xe9'), 1)
def test_minimal_generation(self): def test_minimal_generation(self):
# Instrument the class to verify it gets called the minimum number of times. # Instrument the class to verify it gets called the minimum number of times.
@ -572,91 +596,92 @@ class NamedTupleCursorTest(ConnectingTestCase):
NamedTupleCursor._make_nt = f_patched NamedTupleCursor._make_nt = f_patched
try: try:
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select * from nttest order by 1") curs.execute("select * from nttest order by 1")
curs.fetchone() curs.fetchone()
curs.fetchone() curs.fetchone()
curs.fetchone() curs.fetchone()
self.assertEqual(1, calls[0]) self.assertEqual(1, calls[0])
curs.execute("select * from nttest order by 1") curs.execute("select * from nttest order by 1")
curs.fetchone() curs.fetchone()
curs.fetchall() curs.fetchall()
self.assertEqual(2, calls[0]) self.assertEqual(2, calls[0])
curs.execute("select * from nttest order by 1") curs.execute("select * from nttest order by 1")
curs.fetchone() curs.fetchone()
curs.fetchmany(1) curs.fetchmany(1)
self.assertEqual(3, calls[0]) self.assertEqual(3, calls[0])
finally: finally:
NamedTupleCursor._make_nt = f_orig NamedTupleCursor._make_nt = f_orig
@skip_before_postgres(8, 0) @skip_before_postgres(8, 0)
def test_named(self): def test_named(self):
curs = self.conn.cursor('tmp') with self.conn.cursor('tmp') as curs:
curs.execute("""select i from generate_series(0,9) i""") curs.execute("""select i from generate_series(0,9) i""")
recs = [] recs = []
recs.extend(curs.fetchmany(5)) recs.extend(curs.fetchmany(5))
recs.append(curs.fetchone()) recs.append(curs.fetchone())
recs.extend(curs.fetchall()) recs.extend(curs.fetchall())
self.assertEqual(list(range(10)), [t.i for t in recs]) self.assertEqual(list(range(10)), [t.i for t in recs])
def test_named_fetchone(self): def test_named_fetchone(self):
curs = self.conn.cursor('tmp') with self.conn.cursor('tmp') as curs:
curs.execute("""select 42 as i""") curs.execute("""select 42 as i""")
t = curs.fetchone() t = curs.fetchone()
self.assertEqual(t.i, 42) self.assertEqual(t.i, 42)
def test_named_fetchmany(self): def test_named_fetchmany(self):
curs = self.conn.cursor('tmp') with self.conn.cursor('tmp') as curs:
curs.execute("""select 42 as i""") curs.execute("""select 42 as i""")
recs = curs.fetchmany(10) recs = curs.fetchmany(10)
self.assertEqual(recs[0].i, 42) self.assertEqual(recs[0].i, 42)
def test_named_fetchall(self): def test_named_fetchall(self):
curs = self.conn.cursor('tmp') with self.conn.cursor('tmp') as curs:
curs.execute("""select 42 as i""") curs.execute("""select 42 as i""")
recs = curs.fetchall() recs = curs.fetchall()
self.assertEqual(recs[0].i, 42) self.assertEqual(recs[0].i, 42)
@skip_before_postgres(8, 2) @skip_before_postgres(8, 2)
def test_not_greedy(self): def test_not_greedy(self):
curs = self.conn.cursor('tmp') with self.conn.cursor('tmp') as curs:
curs.itersize = 2 curs.itersize = 2
curs.execute("""select clock_timestamp() as ts from generate_series(1,3)""") curs.execute(
recs = [] """select clock_timestamp() as ts from generate_series(1,3)""")
for t in curs: recs = []
time.sleep(0.01) for t in curs:
recs.append(t) time.sleep(0.01)
recs.append(t)
# check that the dataset was not fetched in a single gulp # check that the dataset was not fetched in a single gulp
self.assert_(recs[1].ts - recs[0].ts < timedelta(seconds=0.005)) self.assert_(recs[1].ts - recs[0].ts < timedelta(seconds=0.005))
self.assert_(recs[2].ts - recs[1].ts > timedelta(seconds=0.0099)) self.assert_(recs[2].ts - recs[1].ts > timedelta(seconds=0.0099))
@skip_before_postgres(8, 0) @skip_before_postgres(8, 0)
def test_named_rownumber(self): def test_named_rownumber(self):
curs = self.conn.cursor('tmp') with self.conn.cursor('tmp') as curs:
# Only checking for dataset < itersize: # Only checking for dataset < itersize:
# see CursorTests.test_iter_named_cursor_rownumber # see CursorTests.test_iter_named_cursor_rownumber
curs.itersize = 4 curs.itersize = 4
curs.execute("""select * from generate_series(1,3)""") curs.execute("""select * from generate_series(1,3)""")
for i, t in enumerate(curs): for i, t in enumerate(curs):
self.assertEqual(i + 1, curs.rownumber) self.assertEqual(i + 1, curs.rownumber)
def test_cache(self): def test_cache(self):
NamedTupleCursor._cached_make_nt.cache_clear() NamedTupleCursor._cached_make_nt.cache_clear()
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select 10 as a, 20 as b") curs.execute("select 10 as a, 20 as b")
r1 = curs.fetchone() r1 = curs.fetchone()
curs.execute("select 10 as a, 20 as c") curs.execute("select 10 as a, 20 as c")
r2 = curs.fetchone() r2 = curs.fetchone()
# Get a new cursor to check that the cache works across multiple ones # Get a new cursor to check that the cache works across multiple ones
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select 10 as a, 30 as b") curs.execute("select 10 as a, 30 as b")
r3 = curs.fetchone() r3 = curs.fetchone()
self.assert_(type(r1) is type(r3)) self.assert_(type(r1) is type(r3))
self.assert_(type(r1) is not type(r2)) self.assert_(type(r1) is not type(r2))
@ -672,20 +697,20 @@ class NamedTupleCursorTest(ConnectingTestCase):
lru_cache(8)(NamedTupleCursor._cached_make_nt.__wrapped__) lru_cache(8)(NamedTupleCursor._cached_make_nt.__wrapped__)
try: try:
recs = [] recs = []
curs = self.conn.cursor() with self.conn.cursor() as curs:
for i in range(10): for i in range(10):
curs.execute("select 1 as f%s" % i) curs.execute("select 1 as f%s" % i)
recs.append(curs.fetchone()) recs.append(curs.fetchone())
# Still in cache # Still in cache
curs.execute("select 1 as f9") curs.execute("select 1 as f9")
rec = curs.fetchone() rec = curs.fetchone()
self.assert_(any(type(r) is type(rec) for r in recs)) self.assert_(any(type(r) is type(rec) for r in recs))
# Gone from cache # Gone from cache
curs.execute("select 1 as f0") curs.execute("select 1 as f0")
rec = curs.fetchone() rec = curs.fetchone()
self.assert_(all(type(r) is not type(rec) for r in recs)) self.assert_(all(type(r) is not type(rec) for r in recs))
finally: finally:
NamedTupleCursor._cached_make_nt = old_func NamedTupleCursor._cached_make_nt = old_func

View File

@ -46,219 +46,219 @@ class TestPaginate(unittest.TestCase):
class FastExecuteTestMixin(object): class FastExecuteTestMixin(object):
def setUp(self): def setUp(self):
super(FastExecuteTestMixin, self).setUp() super(FastExecuteTestMixin, self).setUp()
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("""create table testfast ( cur.execute("""create table testfast (
id serial primary key, date date, val int, data text)""") id serial primary key, date date, val int, data text)""")
class TestExecuteBatch(FastExecuteTestMixin, testutils.ConnectingTestCase): class TestExecuteBatch(FastExecuteTestMixin, testutils.ConnectingTestCase):
def test_empty(self): def test_empty(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.execute_batch(cur, psycopg2.extras.execute_batch(cur,
"insert into testfast (id, val) values (%s, %s)", "insert into testfast (id, val) values (%s, %s)",
[]) [])
cur.execute("select * from testfast order by id") cur.execute("select * from testfast order by id")
self.assertEqual(cur.fetchall(), []) self.assertEqual(cur.fetchall(), [])
def test_one(self): def test_one(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.execute_batch(cur, psycopg2.extras.execute_batch(cur,
"insert into testfast (id, val) values (%s, %s)", "insert into testfast (id, val) values (%s, %s)",
iter([(1, 10)])) iter([(1, 10)]))
cur.execute("select id, val from testfast order by id") cur.execute("select id, val from testfast order by id")
self.assertEqual(cur.fetchall(), [(1, 10)]) self.assertEqual(cur.fetchall(), [(1, 10)])
def test_tuples(self): def test_tuples(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.execute_batch(cur, psycopg2.extras.execute_batch(cur,
"insert into testfast (id, date, val) values (%s, %s, %s)", "insert into testfast (id, date, val) values (%s, %s, %s)",
((i, date(2017, 1, i + 1), i * 10) for i in range(10))) ((i, date(2017, 1, i + 1), i * 10) for i in range(10)))
cur.execute("select id, date, val from testfast order by id") cur.execute("select id, date, val from testfast order by id")
self.assertEqual(cur.fetchall(), self.assertEqual(cur.fetchall(),
[(i, date(2017, 1, i + 1), i * 10) for i in range(10)]) [(i, date(2017, 1, i + 1), i * 10) for i in range(10)])
def test_many(self): def test_many(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.execute_batch(cur, psycopg2.extras.execute_batch(cur,
"insert into testfast (id, val) values (%s, %s)", "insert into testfast (id, val) values (%s, %s)",
((i, i * 10) for i in range(1000))) ((i, i * 10) for i in range(1000)))
cur.execute("select id, val from testfast order by id") cur.execute("select id, val from testfast order by id")
self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)]) self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)])
def test_composed(self): def test_composed(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.execute_batch(cur, psycopg2.extras.execute_batch(cur,
sql.SQL("insert into {0} (id, val) values (%s, %s)") sql.SQL("insert into {0} (id, val) values (%s, %s)")
.format(sql.Identifier('testfast')), .format(sql.Identifier('testfast')),
((i, i * 10) for i in range(1000))) ((i, i * 10) for i in range(1000)))
cur.execute("select id, val from testfast order by id") cur.execute("select id, val from testfast order by id")
self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)]) self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)])
def test_pages(self): def test_pages(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.execute_batch(cur, psycopg2.extras.execute_batch(cur,
"insert into testfast (id, val) values (%s, %s)", "insert into testfast (id, val) values (%s, %s)",
((i, i * 10) for i in range(25)), ((i, i * 10) for i in range(25)),
page_size=10) page_size=10)
# last command was 5 statements # last command was 5 statements
self.assertEqual(sum(c == u';' for c in cur.query.decode('ascii')), 4) self.assertEqual(sum(c == u';' for c in cur.query.decode('ascii')), 4)
cur.execute("select id, val from testfast order by id") cur.execute("select id, val from testfast order by id")
self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)]) self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)])
@testutils.skip_before_postgres(8, 0) @testutils.skip_before_postgres(8, 0)
def test_unicode(self): def test_unicode(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
ext.register_type(ext.UNICODE, cur) ext.register_type(ext.UNICODE, cur)
snowman = u"\u2603" snowman = u"\u2603"
# unicode in statement # unicode in statement
psycopg2.extras.execute_batch(cur, psycopg2.extras.execute_batch(cur,
"insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman, "insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman,
[(1, 'x')]) [(1, 'x')])
cur.execute("select id, data from testfast where id = 1") cur.execute("select id, data from testfast where id = 1")
self.assertEqual(cur.fetchone(), (1, 'x')) self.assertEqual(cur.fetchone(), (1, 'x'))
# unicode in data # unicode in data
psycopg2.extras.execute_batch(cur, psycopg2.extras.execute_batch(cur,
"insert into testfast (id, data) values (%s, %s)", "insert into testfast (id, data) values (%s, %s)",
[(2, snowman)]) [(2, snowman)])
cur.execute("select id, data from testfast where id = 2") cur.execute("select id, data from testfast where id = 2")
self.assertEqual(cur.fetchone(), (2, snowman)) self.assertEqual(cur.fetchone(), (2, snowman))
# unicode in both # unicode in both
psycopg2.extras.execute_batch(cur, psycopg2.extras.execute_batch(cur,
"insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman, "insert into testfast (id, data) values (%%s, %%s) -- %s" % snowman,
[(3, snowman)]) [(3, snowman)])
cur.execute("select id, data from testfast where id = 3") cur.execute("select id, data from testfast where id = 3")
self.assertEqual(cur.fetchone(), (3, snowman)) self.assertEqual(cur.fetchone(), (3, snowman))
@testutils.skip_before_postgres(8, 2) @testutils.skip_before_postgres(8, 2)
class TestExecuteValues(FastExecuteTestMixin, testutils.ConnectingTestCase): class TestExecuteValues(FastExecuteTestMixin, testutils.ConnectingTestCase):
def test_empty(self): def test_empty(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.execute_values(cur, psycopg2.extras.execute_values(cur,
"insert into testfast (id, val) values %s", "insert into testfast (id, val) values %s",
[]) [])
cur.execute("select * from testfast order by id") cur.execute("select * from testfast order by id")
self.assertEqual(cur.fetchall(), []) self.assertEqual(cur.fetchall(), [])
def test_one(self): def test_one(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.execute_values(cur, psycopg2.extras.execute_values(cur,
"insert into testfast (id, val) values %s", "insert into testfast (id, val) values %s",
iter([(1, 10)])) iter([(1, 10)]))
cur.execute("select id, val from testfast order by id") cur.execute("select id, val from testfast order by id")
self.assertEqual(cur.fetchall(), [(1, 10)]) self.assertEqual(cur.fetchall(), [(1, 10)])
def test_tuples(self): def test_tuples(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.execute_values(cur, psycopg2.extras.execute_values(cur,
"insert into testfast (id, date, val) values %s", "insert into testfast (id, date, val) values %s",
((i, date(2017, 1, i + 1), i * 10) for i in range(10))) ((i, date(2017, 1, i + 1), i * 10) for i in range(10)))
cur.execute("select id, date, val from testfast order by id") cur.execute("select id, date, val from testfast order by id")
self.assertEqual(cur.fetchall(), self.assertEqual(cur.fetchall(),
[(i, date(2017, 1, i + 1), i * 10) for i in range(10)]) [(i, date(2017, 1, i + 1), i * 10) for i in range(10)])
def test_dicts(self): def test_dicts(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.execute_values(cur, psycopg2.extras.execute_values(cur,
"insert into testfast (id, date, val) values %s", "insert into testfast (id, date, val) values %s",
(dict(id=i, date=date(2017, 1, i + 1), val=i * 10, foo="bar") (dict(id=i, date=date(2017, 1, i + 1), val=i * 10, foo="bar")
for i in range(10)), for i in range(10)),
template='(%(id)s, %(date)s, %(val)s)') template='(%(id)s, %(date)s, %(val)s)')
cur.execute("select id, date, val from testfast order by id") cur.execute("select id, date, val from testfast order by id")
self.assertEqual(cur.fetchall(), self.assertEqual(cur.fetchall(),
[(i, date(2017, 1, i + 1), i * 10) for i in range(10)]) [(i, date(2017, 1, i + 1), i * 10) for i in range(10)])
def test_many(self): def test_many(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.execute_values(cur, psycopg2.extras.execute_values(cur,
"insert into testfast (id, val) values %s", "insert into testfast (id, val) values %s",
((i, i * 10) for i in range(1000))) ((i, i * 10) for i in range(1000)))
cur.execute("select id, val from testfast order by id") cur.execute("select id, val from testfast order by id")
self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)]) self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)])
def test_composed(self): def test_composed(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.execute_values(cur, psycopg2.extras.execute_values(cur,
sql.SQL("insert into {0} (id, val) values %s") sql.SQL("insert into {0} (id, val) values %s")
.format(sql.Identifier('testfast')), .format(sql.Identifier('testfast')),
((i, i * 10) for i in range(1000))) ((i, i * 10) for i in range(1000)))
cur.execute("select id, val from testfast order by id") cur.execute("select id, val from testfast order by id")
self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)]) self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(1000)])
def test_pages(self): def test_pages(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.execute_values(cur, psycopg2.extras.execute_values(cur,
"insert into testfast (id, val) values %s", "insert into testfast (id, val) values %s",
((i, i * 10) for i in range(25)), ((i, i * 10) for i in range(25)),
page_size=10) page_size=10)
# last statement was 5 tuples (one parens is for the fields list) # last statement was 5 tuples (one parens is for the fields list)
self.assertEqual(sum(c == '(' for c in cur.query.decode('ascii')), 6) self.assertEqual(sum(c == '(' for c in cur.query.decode('ascii')), 6)
cur.execute("select id, val from testfast order by id") cur.execute("select id, val from testfast order by id")
self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)]) self.assertEqual(cur.fetchall(), [(i, i * 10) for i in range(25)])
def test_unicode(self): def test_unicode(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
ext.register_type(ext.UNICODE, cur) ext.register_type(ext.UNICODE, cur)
snowman = u"\u2603" snowman = u"\u2603"
# unicode in statement # unicode in statement
psycopg2.extras.execute_values(cur, psycopg2.extras.execute_values(cur,
"insert into testfast (id, data) values %%s -- %s" % snowman, "insert into testfast (id, data) values %%s -- %s" % snowman,
[(1, 'x')]) [(1, 'x')])
cur.execute("select id, data from testfast where id = 1") cur.execute("select id, data from testfast where id = 1")
self.assertEqual(cur.fetchone(), (1, 'x')) self.assertEqual(cur.fetchone(), (1, 'x'))
# unicode in data # unicode in data
psycopg2.extras.execute_values(cur, psycopg2.extras.execute_values(cur,
"insert into testfast (id, data) values %s", "insert into testfast (id, data) values %s",
[(2, snowman)]) [(2, snowman)])
cur.execute("select id, data from testfast where id = 2") cur.execute("select id, data from testfast where id = 2")
self.assertEqual(cur.fetchone(), (2, snowman)) self.assertEqual(cur.fetchone(), (2, snowman))
# unicode in both # unicode in both
psycopg2.extras.execute_values(cur, psycopg2.extras.execute_values(cur,
"insert into testfast (id, data) values %%s -- %s" % snowman, "insert into testfast (id, data) values %%s -- %s" % snowman,
[(3, snowman)]) [(3, snowman)])
cur.execute("select id, data from testfast where id = 3") cur.execute("select id, data from testfast where id = 3")
self.assertEqual(cur.fetchone(), (3, snowman)) self.assertEqual(cur.fetchone(), (3, snowman))
def test_returning(self): def test_returning(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
result = psycopg2.extras.execute_values(cur, result = psycopg2.extras.execute_values(cur,
"insert into testfast (id, val) values %s returning id", "insert into testfast (id, val) values %s returning id",
((i, i * 10) for i in range(25)), ((i, i * 10) for i in range(25)),
page_size=10, fetch=True) page_size=10, fetch=True)
# result contains all returned pages # result contains all returned pages
self.assertEqual([r[0] for r in result], list(range(25))) self.assertEqual([r[0] for r in result], list(range(25)))
def test_invalid_sql(self): def test_invalid_sql(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
self.assertRaises(ValueError, psycopg2.extras.execute_values, cur, self.assertRaises(ValueError, psycopg2.extras.execute_values, cur,
"insert", []) "insert", [])
self.assertRaises(ValueError, psycopg2.extras.execute_values, cur, self.assertRaises(ValueError, psycopg2.extras.execute_values, cur,
"insert %s and %s", []) "insert %s and %s", [])
self.assertRaises(ValueError, psycopg2.extras.execute_values, cur, self.assertRaises(ValueError, psycopg2.extras.execute_values, cur,
"insert %f", []) "insert %f", [])
self.assertRaises(ValueError, psycopg2.extras.execute_values, cur, self.assertRaises(ValueError, psycopg2.extras.execute_values, cur,
"insert %f %s", []) "insert %f %s", [])
def test_percent_escape(self): def test_percent_escape(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.execute_values(cur, psycopg2.extras.execute_values(cur,
"insert into testfast (id, data) values %s -- a%%b", "insert into testfast (id, data) values %s -- a%%b",
[(1, 'hi')]) [(1, 'hi')])
self.assert_(b'a%%b' not in cur.query) self.assert_(b'a%%b' not in cur.query)
self.assert_(b'a%b' in cur.query) self.assert_(b'a%b' in cur.query)
cur.execute("select id, data from testfast") cur.execute("select id, data from testfast")
self.assertEqual(cur.fetchall(), [(1, 'hi')]) self.assertEqual(cur.fetchall(), [(1, 'hi')])
def test_suite(): def test_suite():

View File

@ -71,14 +71,14 @@ class GreenTestCase(ConnectingTestCase):
# a very large query requires a flush loop to be sent to the backend # a very large query requires a flush loop to be sent to the backend
conn = self.conn conn = self.conn
stub = self.set_stub_wait_callback(conn) stub = self.set_stub_wait_callback(conn)
curs = conn.cursor() with conn.cursor() as curs:
for mb in 1, 5, 10, 20, 50: for mb in 1, 5, 10, 20, 50:
size = mb * 1024 * 1024 size = mb * 1024 * 1024
del stub.polls[:] del stub.polls[:]
curs.execute("select %s;", ('x' * size,)) curs.execute("select %s;", ('x' * size,))
self.assertEqual(size, len(curs.fetchone()[0])) self.assertEqual(size, len(curs.fetchone()[0]))
if stub.polls.count(psycopg2.extensions.POLL_WRITE) > 1: if stub.polls.count(psycopg2.extensions.POLL_WRITE) > 1:
return return
# This is more a testing glitch than an error: it happens # This is more a testing glitch than an error: it happens
# on high load on linux: probably because the kernel has more # on high load on linux: probably because the kernel has more
@ -105,21 +105,21 @@ class GreenTestCase(ConnectingTestCase):
# if there is an error in a green query, don't freak out and close # if there is an error in a green query, don't freak out and close
# the connection # the connection
conn = self.conn conn = self.conn
curs = conn.cursor() with conn.cursor() as curs:
self.assertRaises(psycopg2.ProgrammingError, self.assertRaises(psycopg2.ProgrammingError,
curs.execute, "select the unselectable") curs.execute, "select the unselectable")
# check that the connection is left in an usable state # check that the connection is left in an usable state
self.assert_(not conn.closed) self.assert_(not conn.closed)
conn.rollback() conn.rollback()
curs.execute("select 1") curs.execute("select 1")
self.assertEqual(curs.fetchone()[0], 1) self.assertEqual(curs.fetchone()[0], 1)
@skip_before_postgres(8, 2) @skip_before_postgres(8, 2)
def test_copy_no_hang(self): def test_copy_no_hang(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
self.assertRaises(psycopg2.ProgrammingError, self.assertRaises(psycopg2.ProgrammingError,
cur.execute, "copy (select 1) to stdout") cur.execute, "copy (select 1) to stdout")
@slow @slow
@skip_before_postgres(9, 0) @skip_before_postgres(9, 0)
@ -137,19 +137,19 @@ class GreenTestCase(ConnectingTestCase):
raise conn.OperationalError("bad state from poll: %s" % state) raise conn.OperationalError("bad state from poll: %s" % state)
stub = self.set_stub_wait_callback(self.conn, wait) stub = self.set_stub_wait_callback(self.conn, wait)
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute(""" cur.execute("""
select 1; select 1;
do $$ do $$
begin begin
raise notice 'hello'; raise notice 'hello';
end end
$$ language plpgsql; $$ language plpgsql;
select pg_sleep(1); select pg_sleep(1);
""") """)
polls = stub.polls.count(POLL_READ) polls = stub.polls.count(POLL_READ)
self.assert_(polls > 8, polls) self.assert_(polls > 8, polls)
class CallbackErrorTestCase(ConnectingTestCase): class CallbackErrorTestCase(ConnectingTestCase):
@ -203,16 +203,16 @@ class CallbackErrorTestCase(ConnectingTestCase):
for i in range(100): for i in range(100):
self.to_error = None self.to_error = None
cnn = self.connect() cnn = self.connect()
cur = cnn.cursor() with cnn.cursor() as cur:
self.to_error = i self.to_error = i
try: try:
cur.execute("select 1") cur.execute("select 1")
cur.fetchone() cur.fetchone()
except ZeroDivisionError: except ZeroDivisionError:
pass pass
else: else:
# The query completed # The query completed
return return
self.fail("you should have had a success or an error by now") self.fail("you should have had a success or an error by now")
@ -220,16 +220,19 @@ class CallbackErrorTestCase(ConnectingTestCase):
for i in range(100): for i in range(100):
self.to_error = None self.to_error = None
cnn = self.connect() cnn = self.connect()
cur = cnn.cursor('foo') with cnn.cursor('foo') as cur:
self.to_error = i self.to_error = i
try: try:
cur.execute("select 1") cur.execute("select 1")
cur.fetchone() cur.fetchone()
except ZeroDivisionError: except ZeroDivisionError:
pass pass
else: else:
# The query completed # The query completed
return return
finally:
# Don't raise an exception in the cursor context manager.
self.to_error = None
self.fail("you should have had a success or an error by now") self.fail("you should have had a success or an error by now")

View File

@ -33,82 +33,82 @@ except ImportError:
@unittest.skipIf(ip is None, "'ipaddress' module not available") @unittest.skipIf(ip is None, "'ipaddress' module not available")
class NetworkingTestCase(testutils.ConnectingTestCase): class NetworkingTestCase(testutils.ConnectingTestCase):
def test_inet_cast(self): def test_inet_cast(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.register_ipaddress(cur) psycopg2.extras.register_ipaddress(cur)
cur.execute("select null::inet") cur.execute("select null::inet")
self.assert_(cur.fetchone()[0] is None) self.assert_(cur.fetchone()[0] is None)
cur.execute("select '127.0.0.1/24'::inet") cur.execute("select '127.0.0.1/24'::inet")
obj = cur.fetchone()[0] obj = cur.fetchone()[0]
self.assert_(isinstance(obj, ip.IPv4Interface), repr(obj)) self.assert_(isinstance(obj, ip.IPv4Interface), repr(obj))
self.assertEquals(obj, ip.ip_interface('127.0.0.1/24')) self.assertEquals(obj, ip.ip_interface('127.0.0.1/24'))
cur.execute("select '::ffff:102:300/128'::inet") cur.execute("select '::ffff:102:300/128'::inet")
obj = cur.fetchone()[0] obj = cur.fetchone()[0]
self.assert_(isinstance(obj, ip.IPv6Interface), repr(obj)) self.assert_(isinstance(obj, ip.IPv6Interface), repr(obj))
self.assertEquals(obj, ip.ip_interface('::ffff:102:300/128')) self.assertEquals(obj, ip.ip_interface('::ffff:102:300/128'))
@testutils.skip_before_postgres(8, 2) @testutils.skip_before_postgres(8, 2)
def test_inet_array_cast(self): def test_inet_array_cast(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.register_ipaddress(cur) psycopg2.extras.register_ipaddress(cur)
cur.execute("select '{NULL,127.0.0.1,::ffff:102:300/128}'::inet[]") cur.execute("select '{NULL,127.0.0.1,::ffff:102:300/128}'::inet[]")
l = cur.fetchone()[0] l = cur.fetchone()[0]
self.assert_(l[0] is None) self.assert_(l[0] is None)
self.assertEquals(l[1], ip.ip_interface('127.0.0.1')) self.assertEquals(l[1], ip.ip_interface('127.0.0.1'))
self.assertEquals(l[2], ip.ip_interface('::ffff:102:300/128')) self.assertEquals(l[2], ip.ip_interface('::ffff:102:300/128'))
self.assert_(isinstance(l[1], ip.IPv4Interface), l) self.assert_(isinstance(l[1], ip.IPv4Interface), l)
self.assert_(isinstance(l[2], ip.IPv6Interface), l) self.assert_(isinstance(l[2], ip.IPv6Interface), l)
def test_inet_adapt(self): def test_inet_adapt(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.register_ipaddress(cur) psycopg2.extras.register_ipaddress(cur)
cur.execute("select %s", [ip.ip_interface('127.0.0.1/24')]) cur.execute("select %s", [ip.ip_interface('127.0.0.1/24')])
self.assertEquals(cur.fetchone()[0], '127.0.0.1/24') self.assertEquals(cur.fetchone()[0], '127.0.0.1/24')
cur.execute("select %s", [ip.ip_interface('::ffff:102:300/128')]) cur.execute("select %s", [ip.ip_interface('::ffff:102:300/128')])
self.assertEquals(cur.fetchone()[0], '::ffff:102:300/128') self.assertEquals(cur.fetchone()[0], '::ffff:102:300/128')
def test_cidr_cast(self): def test_cidr_cast(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.register_ipaddress(cur) psycopg2.extras.register_ipaddress(cur)
cur.execute("select null::cidr") cur.execute("select null::cidr")
self.assert_(cur.fetchone()[0] is None) self.assert_(cur.fetchone()[0] is None)
cur.execute("select '127.0.0.0/24'::cidr") cur.execute("select '127.0.0.0/24'::cidr")
obj = cur.fetchone()[0] obj = cur.fetchone()[0]
self.assert_(isinstance(obj, ip.IPv4Network), repr(obj)) self.assert_(isinstance(obj, ip.IPv4Network), repr(obj))
self.assertEquals(obj, ip.ip_network('127.0.0.0/24')) self.assertEquals(obj, ip.ip_network('127.0.0.0/24'))
cur.execute("select '::ffff:102:300/128'::cidr") cur.execute("select '::ffff:102:300/128'::cidr")
obj = cur.fetchone()[0] obj = cur.fetchone()[0]
self.assert_(isinstance(obj, ip.IPv6Network), repr(obj)) self.assert_(isinstance(obj, ip.IPv6Network), repr(obj))
self.assertEquals(obj, ip.ip_network('::ffff:102:300/128')) self.assertEquals(obj, ip.ip_network('::ffff:102:300/128'))
@testutils.skip_before_postgres(8, 2) @testutils.skip_before_postgres(8, 2)
def test_cidr_array_cast(self): def test_cidr_array_cast(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.register_ipaddress(cur) psycopg2.extras.register_ipaddress(cur)
cur.execute("select '{NULL,127.0.0.1,::ffff:102:300/128}'::cidr[]") cur.execute("select '{NULL,127.0.0.1,::ffff:102:300/128}'::cidr[]")
l = cur.fetchone()[0] l = cur.fetchone()[0]
self.assert_(l[0] is None) self.assert_(l[0] is None)
self.assertEquals(l[1], ip.ip_network('127.0.0.1')) self.assertEquals(l[1], ip.ip_network('127.0.0.1'))
self.assertEquals(l[2], ip.ip_network('::ffff:102:300/128')) self.assertEquals(l[2], ip.ip_network('::ffff:102:300/128'))
self.assert_(isinstance(l[1], ip.IPv4Network), l) self.assert_(isinstance(l[1], ip.IPv4Network), l)
self.assert_(isinstance(l[2], ip.IPv6Network), l) self.assert_(isinstance(l[2], ip.IPv6Network), l)
def test_cidr_adapt(self): def test_cidr_adapt(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
psycopg2.extras.register_ipaddress(cur) psycopg2.extras.register_ipaddress(cur)
cur.execute("select %s", [ip.ip_network('127.0.0.0/24')]) cur.execute("select %s", [ip.ip_network('127.0.0.0/24')])
self.assertEquals(cur.fetchone()[0], '127.0.0.0/24') self.assertEquals(cur.fetchone()[0], '127.0.0.0/24')
cur.execute("select %s", [ip.ip_network('::ffff:102:300/128')]) cur.execute("select %s", [ip.ip_network('::ffff:102:300/128')])
self.assertEquals(cur.fetchone()[0], '::ffff:102:300/128') self.assertEquals(cur.fetchone()[0], '::ffff:102:300/128')
def test_suite(): def test_suite():

View File

@ -154,22 +154,22 @@ class ConnectTestCase(unittest.TestCase):
class ExceptionsTestCase(ConnectingTestCase): class ExceptionsTestCase(ConnectingTestCase):
def test_attributes(self): def test_attributes(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
try: try:
cur.execute("select * from nonexist") cur.execute("select * from nonexist")
except psycopg2.Error as exc: except psycopg2.Error as exc:
e = exc e = exc
self.assertEqual(e.pgcode, '42P01') self.assertEqual(e.pgcode, '42P01')
self.assert_(e.pgerror) self.assert_(e.pgerror)
self.assert_(e.cursor is cur) self.assert_(e.cursor is cur)
def test_diagnostics_attributes(self): def test_diagnostics_attributes(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
try: try:
cur.execute("select * from nonexist") cur.execute("select * from nonexist")
except psycopg2.Error as exc: except psycopg2.Error as exc:
e = exc e = exc
diag = e.diag diag = e.diag
self.assert_(isinstance(diag, psycopg2.extensions.Diagnostics)) self.assert_(isinstance(diag, psycopg2.extensions.Diagnostics))
@ -195,11 +195,11 @@ class ExceptionsTestCase(ConnectingTestCase):
def test_diagnostics_life(self): def test_diagnostics_life(self):
def tmp(): def tmp():
cur = self.conn.cursor() with self.conn.cursor() as cur:
try: try:
cur.execute("select * from nonexist") cur.execute("select * from nonexist")
except psycopg2.Error as exc: except psycopg2.Error as exc:
return cur, exc return cur, exc
cur, e = tmp() cur, e = tmp()
diag = e.diag diag = e.diag

View File

@ -120,10 +120,11 @@ conn.close()
self.listen('foo') self.listen('foo')
pid = int(self.notify('foo').communicate()[0]) pid = int(self.notify('foo').communicate()[0])
self.assertEqual(0, len(self.conn.notifies)) self.assertEqual(0, len(self.conn.notifies))
self.conn.cursor().execute('select 1;') with self.conn.cursor() as cur:
self.assertEqual(1, len(self.conn.notifies)) cur.execute('select 1;')
self.assertEqual(pid, self.conn.notifies[0][0]) self.assertEqual(1, len(self.conn.notifies))
self.assertEqual('foo', self.conn.notifies[0][1]) self.assertEqual(pid, self.conn.notifies[0][0])
self.assertEqual('foo', self.conn.notifies[0][1])
@slow @slow
def test_notify_object(self): def test_notify_object(self):

View File

@ -56,24 +56,24 @@ class QuotingTestCase(ConnectingTestCase):
""" """
data += "".join(map(chr, range(1, 127))) data += "".join(map(chr, range(1, 127)))
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("SELECT %s;", (data,)) curs.execute("SELECT %s;", (data,))
res = curs.fetchone()[0] res = curs.fetchone()[0]
self.assertEqual(res, data) self.assertEqual(res, data)
self.assert_(not self.conn.notices) self.assert_(not self.conn.notices)
def test_string_null_terminator(self): def test_string_null_terminator(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
data = 'abcd\x01\x00cdefg' data = 'abcd\x01\x00cdefg'
try: try:
curs.execute("SELECT %s", (data,)) curs.execute("SELECT %s", (data,))
except ValueError as e: except ValueError as e:
self.assertEquals(str(e), self.assertEquals(str(e),
'A string literal cannot contain NUL (0x00) characters.') 'A string literal cannot contain NUL (0x00) characters.')
else: else:
self.fail("ValueError not raised") self.fail("ValueError not raised")
def test_binary(self): def test_binary(self):
data = b"""some data with \000\013 binary data = b"""some data with \000\013 binary
@ -84,12 +84,12 @@ class QuotingTestCase(ConnectingTestCase):
else: else:
data += bytes(list(range(256))) data += bytes(list(range(256)))
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("SELECT %s::bytea;", (psycopg2.Binary(data),)) curs.execute("SELECT %s::bytea;", (psycopg2.Binary(data),))
if PY2: if PY2:
res = str(curs.fetchone()[0]) res = str(curs.fetchone()[0])
else: else:
res = curs.fetchone()[0].tobytes() res = curs.fetchone()[0].tobytes()
if res[0] in (b'x', ord(b'x')) and self.conn.info.server_version >= 90000: if res[0] in (b'x', ord(b'x')) and self.conn.info.server_version >= 90000:
return self.skipTest( return self.skipTest(
@ -99,86 +99,87 @@ class QuotingTestCase(ConnectingTestCase):
self.assert_(not self.conn.notices) self.assert_(not self.conn.notices)
def test_unicode(self): def test_unicode(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("SHOW server_encoding") curs.execute("SHOW server_encoding")
server_encoding = curs.fetchone()[0] server_encoding = curs.fetchone()[0]
if server_encoding != "UTF8": if server_encoding != "UTF8":
return self.skipTest( return self.skipTest(
"Unicode test skipped since server encoding is %s" "Unicode test skipped since server encoding is %s"
% server_encoding) % server_encoding)
data = u"""some data with \t chars data = u"""some data with \t chars
to escape into, 'quotes', \u20ac euro sign and \\ a backslash too. to escape into, 'quotes', \u20ac euro sign and \\ a backslash too.
""" """
data += u"".join(map(unichr, [u for u in range(1, 65536) data += u"".join(map(unichr, [u for u in range(1, 65536)
if not 0xD800 <= u <= 0xDFFF])) # surrogate area if not 0xD800 <= u <= 0xDFFF])) # surrogate area
self.conn.set_client_encoding('UNICODE') self.conn.set_client_encoding('UNICODE')
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE, self.conn) psycopg2.extensions.register_type(psycopg2.extensions.UNICODE, self.conn)
curs.execute("SELECT %s::text;", (data,)) curs.execute("SELECT %s::text;", (data,))
res = curs.fetchone()[0] res = curs.fetchone()[0]
self.assertEqual(res, data) self.assertEqual(res, data)
self.assert_(not self.conn.notices) self.assert_(not self.conn.notices)
def test_latin1(self): def test_latin1(self):
self.conn.set_client_encoding('LATIN1') self.conn.set_client_encoding('LATIN1')
curs = self.conn.cursor() with self.conn.cursor() as curs:
if PY2: if PY2:
data = ''.join(map(chr, range(32, 127) + range(160, 256))) data = ''.join(map(chr, range(32, 127) + range(160, 256)))
else: else:
data = bytes(list(range(32, 127)) data = bytes(list(range(32, 127))
+ list(range(160, 256))).decode('latin1') + list(range(160, 256))).decode('latin1')
# as string
curs.execute("SELECT %s::text;", (data,))
res = curs.fetchone()[0]
self.assertEqual(res, data)
self.assert_(not self.conn.notices)
# as unicode
if PY2:
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE, self.conn)
data = data.decode('latin1')
# as string
curs.execute("SELECT %s::text;", (data,)) curs.execute("SELECT %s::text;", (data,))
res = curs.fetchone()[0] res = curs.fetchone()[0]
self.assertEqual(res, data) self.assertEqual(res, data)
self.assert_(not self.conn.notices) self.assert_(not self.conn.notices)
# as unicode
if PY2:
psycopg2.extensions.register_type(
psycopg2.extensions.UNICODE, self.conn)
data = data.decode('latin1')
curs.execute("SELECT %s::text;", (data,))
res = curs.fetchone()[0]
self.assertEqual(res, data)
self.assert_(not self.conn.notices)
def test_koi8(self): def test_koi8(self):
self.conn.set_client_encoding('KOI8') self.conn.set_client_encoding('KOI8')
curs = self.conn.cursor() with self.conn.cursor() as curs:
if PY2: if PY2:
data = ''.join(map(chr, range(32, 127) + range(128, 256))) data = ''.join(map(chr, range(32, 127) + range(128, 256)))
else: else:
data = bytes(list(range(32, 127)) data = bytes(list(range(32, 127))
+ list(range(128, 256))).decode('koi8_r') + list(range(128, 256))).decode('koi8_r')
# as string
curs.execute("SELECT %s::text;", (data,))
res = curs.fetchone()[0]
self.assertEqual(res, data)
self.assert_(not self.conn.notices)
# as unicode
if PY2:
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE, self.conn)
data = data.decode('koi8_r')
# as string
curs.execute("SELECT %s::text;", (data,)) curs.execute("SELECT %s::text;", (data,))
res = curs.fetchone()[0] res = curs.fetchone()[0]
self.assertEqual(res, data) self.assertEqual(res, data)
self.assert_(not self.conn.notices) self.assert_(not self.conn.notices)
# as unicode
if PY2:
psycopg2.extensions.register_type(psycopg2.extensions.UNICODE, self.conn)
data = data.decode('koi8_r')
curs.execute("SELECT %s::text;", (data,))
res = curs.fetchone()[0]
self.assertEqual(res, data)
self.assert_(not self.conn.notices)
def test_bytes(self): def test_bytes(self):
snowman = u"\u2603" snowman = u"\u2603"
conn = self.connect() conn = self.connect()
conn.set_client_encoding('UNICODE') conn.set_client_encoding('UNICODE')
psycopg2.extensions.register_type(psycopg2.extensions.BYTES, conn) psycopg2.extensions.register_type(psycopg2.extensions.BYTES, conn)
curs = conn.cursor() with conn.cursor() as curs:
curs.execute("select %s::text", (snowman,)) curs.execute("select %s::text", (snowman,))
x = curs.fetchone()[0] x = curs.fetchone()[0]
self.assert_(isinstance(x, bytes)) self.assert_(isinstance(x, bytes))
self.assertEqual(x, snowman.encode('utf8')) self.assertEqual(x, snowman.encode('utf8'))

View File

@ -73,14 +73,13 @@ class ReplicationTestCase(ConnectingTestCase):
conn = self.connect() conn = self.connect()
if conn is None: if conn is None:
return return
cur = conn.cursor() with conn.cursor() as cur:
try:
try: cur.execute("DROP TABLE dummy1")
cur.execute("DROP TABLE dummy1") except psycopg2.ProgrammingError:
except psycopg2.ProgrammingError: conn.rollback()
conn.rollback() cur.execute(
cur.execute( "CREATE TABLE dummy1 AS SELECT * FROM generate_series(1, 5) AS id")
"CREATE TABLE dummy1 AS SELECT * FROM generate_series(1, 5) AS id")
conn.commit() conn.commit()
@ -90,9 +89,9 @@ class ReplicationTest(ReplicationTestCase):
conn = self.repl_connect(connection_factory=PhysicalReplicationConnection) conn = self.repl_connect(connection_factory=PhysicalReplicationConnection)
if conn is None: if conn is None:
return return
cur = conn.cursor() with conn.cursor() as cur:
cur.execute("IDENTIFY_SYSTEM") cur.execute("IDENTIFY_SYSTEM")
cur.fetchall() cur.fetchall()
@skip_before_postgres(9, 0) @skip_before_postgres(9, 0)
def test_datestyle(self): def test_datestyle(self):
@ -104,29 +103,28 @@ class ReplicationTest(ReplicationTestCase):
connection_factory=PhysicalReplicationConnection) connection_factory=PhysicalReplicationConnection)
if conn is None: if conn is None:
return return
cur = conn.cursor() with conn.cursor() as cur:
cur.execute("IDENTIFY_SYSTEM") cur.execute("IDENTIFY_SYSTEM")
cur.fetchall() cur.fetchall()
@skip_before_postgres(9, 4) @skip_before_postgres(9, 4)
def test_logical_replication_connection(self): def test_logical_replication_connection(self):
conn = self.repl_connect(connection_factory=LogicalReplicationConnection) conn = self.repl_connect(connection_factory=LogicalReplicationConnection)
if conn is None: if conn is None:
return return
cur = conn.cursor() with conn.cursor() as cur:
cur.execute("IDENTIFY_SYSTEM") cur.execute("IDENTIFY_SYSTEM")
cur.fetchall() cur.fetchall()
@skip_before_postgres(9, 4) # slots require 9.4 @skip_before_postgres(9, 4) # slots require 9.4
def test_create_replication_slot(self): def test_create_replication_slot(self):
conn = self.repl_connect(connection_factory=PhysicalReplicationConnection) conn = self.repl_connect(connection_factory=PhysicalReplicationConnection)
if conn is None: if conn is None:
return return
cur = conn.cursor() with conn.cursor() as cur:
self.create_replication_slot(cur)
self.create_replication_slot(cur) self.assertRaises(
self.assertRaises( psycopg2.ProgrammingError, self.create_replication_slot, cur)
psycopg2.ProgrammingError, self.create_replication_slot, cur)
@skip_before_postgres(9, 4) # slots require 9.4 @skip_before_postgres(9, 4) # slots require 9.4
@skip_repl_if_green @skip_repl_if_green
@ -134,13 +132,12 @@ class ReplicationTest(ReplicationTestCase):
conn = self.repl_connect(connection_factory=PhysicalReplicationConnection) conn = self.repl_connect(connection_factory=PhysicalReplicationConnection)
if conn is None: if conn is None:
return return
cur = conn.cursor() with conn.cursor() as cur:
self.assertRaises(psycopg2.ProgrammingError,
cur.start_replication, self.slot)
self.assertRaises(psycopg2.ProgrammingError, self.create_replication_slot(cur)
cur.start_replication, self.slot) cur.start_replication(self.slot)
self.create_replication_slot(cur)
cur.start_replication(self.slot)
@skip_before_postgres(9, 4) # slots require 9.4 @skip_before_postgres(9, 4) # slots require 9.4
@skip_repl_if_green @skip_repl_if_green
@ -148,12 +145,11 @@ class ReplicationTest(ReplicationTestCase):
conn = self.repl_connect(connection_factory=LogicalReplicationConnection) conn = self.repl_connect(connection_factory=LogicalReplicationConnection)
if conn is None: if conn is None:
return return
cur = conn.cursor() with conn.cursor() as cur:
self.create_replication_slot(cur, output_plugin='test_decoding')
self.create_replication_slot(cur, output_plugin='test_decoding') cur.start_replication_expert(
cur.start_replication_expert( sql.SQL("START_REPLICATION SLOT {slot} LOGICAL 0/00000000").format(
sql.SQL("START_REPLICATION SLOT {slot} LOGICAL 0/00000000").format( slot=sql.Identifier(self.slot)))
slot=sql.Identifier(self.slot)))
@skip_before_postgres(9, 4) # slots require 9.4 @skip_before_postgres(9, 4) # slots require 9.4
@skip_repl_if_green @skip_repl_if_green
@ -161,23 +157,22 @@ class ReplicationTest(ReplicationTestCase):
conn = self.repl_connect(connection_factory=LogicalReplicationConnection) conn = self.repl_connect(connection_factory=LogicalReplicationConnection)
if conn is None: if conn is None:
return return
cur = conn.cursor() with conn.cursor() as cur:
self.create_replication_slot(cur, output_plugin='test_decoding')
self.make_replication_events()
self.create_replication_slot(cur, output_plugin='test_decoding') def consume(msg):
self.make_replication_events() raise StopReplication()
def consume(msg): with self.assertRaises(psycopg2.DataError):
raise StopReplication() # try with invalid options
cur.start_replication(
slot_name=self.slot, options={'invalid_param': 'value'})
cur.consume_stream(consume)
with self.assertRaises(psycopg2.DataError): # try with correct command
# try with invalid options cur.start_replication(slot_name=self.slot)
cur.start_replication( self.assertRaises(StopReplication, cur.consume_stream, consume)
slot_name=self.slot, options={'invalid_param': 'value'})
cur.consume_stream(consume)
# try with correct command
cur.start_replication(slot_name=self.slot)
self.assertRaises(StopReplication, cur.consume_stream, consume)
@skip_before_postgres(9, 4) # slots require 9.4 @skip_before_postgres(9, 4) # slots require 9.4
@skip_repl_if_green @skip_repl_if_green
@ -208,17 +203,16 @@ class ReplicationTest(ReplicationTestCase):
conn = self.repl_connect(connection_factory=LogicalReplicationConnection) conn = self.repl_connect(connection_factory=LogicalReplicationConnection)
if conn is None: if conn is None:
return return
cur = conn.cursor() with conn.cursor() as cur:
self.create_replication_slot(cur, output_plugin='test_decoding')
self.create_replication_slot(cur, output_plugin='test_decoding') self.make_replication_events()
self.make_replication_events() cur.start_replication(self.slot)
cur.start_replication(self.slot) def consume(msg):
raise StopReplication()
def consume(msg): self.assertRaises(StopReplication, cur.consume_stream, consume)
raise StopReplication()
self.assertRaises(StopReplication, cur.consume_stream, consume)
class AsyncReplicationTest(ReplicationTestCase): class AsyncReplicationTest(ReplicationTestCase):
@ -230,42 +224,41 @@ class AsyncReplicationTest(ReplicationTestCase):
if conn is None: if conn is None:
return return
cur = conn.cursor() with conn.cursor() as cur:
self.create_replication_slot(cur, output_plugin='test_decoding')
self.wait(cur)
self.create_replication_slot(cur, output_plugin='test_decoding') cur.start_replication(self.slot)
self.wait(cur) self.wait(cur)
cur.start_replication(self.slot) self.make_replication_events()
self.wait(cur)
self.make_replication_events() self.msg_count = 0
self.msg_count = 0 def consume(msg):
# just check the methods
"%s: %s" % (cur.io_timestamp, repr(msg))
"%s: %s" % (cur.feedback_timestamp, repr(msg))
"%s: %s" % (cur.wal_end, repr(msg))
def consume(msg): self.msg_count += 1
# just check the methods if self.msg_count > 3:
"%s: %s" % (cur.io_timestamp, repr(msg)) cur.send_feedback(reply=True)
"%s: %s" % (cur.feedback_timestamp, repr(msg)) raise StopReplication()
"%s: %s" % (cur.wal_end, repr(msg))
self.msg_count += 1 cur.send_feedback(flush_lsn=msg.data_start)
if self.msg_count > 3:
cur.send_feedback(reply=True)
raise StopReplication()
cur.send_feedback(flush_lsn=msg.data_start) # cannot be used in asynchronous mode
self.assertRaises(psycopg2.ProgrammingError, cur.consume_stream, consume)
# cannot be used in asynchronous mode def process_stream():
self.assertRaises(psycopg2.ProgrammingError, cur.consume_stream, consume) while True:
msg = cur.read_message()
def process_stream(): if msg:
while True: consume(msg)
msg = cur.read_message() else:
if msg: select([cur], [], [])
consume(msg) self.assertRaises(StopReplication, process_stream)
else:
select([cur], [], [])
self.assertRaises(StopReplication, process_stream)
def test_suite(): def test_suite():

View File

@ -117,61 +117,61 @@ class SqlFormatTests(ConnectingTestCase):
sql.SQL("select {0};").format(sql.Literal(Foo())).as_string, self.conn) sql.SQL("select {0};").format(sql.Literal(Foo())).as_string, self.conn)
def test_execute(self): def test_execute(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute(""" cur.execute("""
create table test_compose ( create table test_compose (
id serial primary key, id serial primary key,
foo text, bar text, "ba'z" text) foo text, bar text, "ba'z" text)
""") """)
cur.execute( cur.execute(
sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format( sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format(
sql.Identifier('test_compose'), sql.Identifier('test_compose'),
sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])), sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])),
(sql.Placeholder() * 3).join(', ')), (sql.Placeholder() * 3).join(', ')),
(10, 'a', 'b', 'c')) (10, 'a', 'b', 'c'))
cur.execute("select * from test_compose") cur.execute("select * from test_compose")
self.assertEqual(cur.fetchall(), [(10, 'a', 'b', 'c')]) self.assertEqual(cur.fetchall(), [(10, 'a', 'b', 'c')])
def test_executemany(self): def test_executemany(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute(""" cur.execute("""
create table test_compose ( create table test_compose (
id serial primary key, id serial primary key,
foo text, bar text, "ba'z" text) foo text, bar text, "ba'z" text)
""") """)
cur.executemany( cur.executemany(
sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format( sql.SQL("insert into {0} (id, {1}) values (%s, {2})").format(
sql.Identifier('test_compose'), sql.Identifier('test_compose'),
sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])), sql.SQL(', ').join(map(sql.Identifier, ['foo', 'bar', "ba'z"])),
(sql.Placeholder() * 3).join(', ')), (sql.Placeholder() * 3).join(', ')),
[(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')]) [(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')])
cur.execute("select * from test_compose") cur.execute("select * from test_compose")
self.assertEqual(cur.fetchall(), self.assertEqual(cur.fetchall(),
[(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')]) [(10, 'a', 'b', 'c'), (20, 'd', 'e', 'f')])
@skip_copy_if_green @skip_copy_if_green
@skip_before_postgres(8, 2) @skip_before_postgres(8, 2)
def test_copy(self): def test_copy(self):
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute(""" cur.execute("""
create table test_compose ( create table test_compose (
id serial primary key, id serial primary key,
foo text, bar text, "ba'z" text) foo text, bar text, "ba'z" text)
""") """)
s = StringIO("10\ta\tb\tc\n20\td\te\tf\n") s = StringIO("10\ta\tb\tc\n20\td\te\tf\n")
cur.copy_expert( cur.copy_expert(
sql.SQL("copy {t} (id, foo, bar, {f}) from stdin").format( sql.SQL("copy {t} (id, foo, bar, {f}) from stdin").format(
t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z")), s) t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z")), s)
s1 = StringIO() s1 = StringIO()
cur.copy_expert( cur.copy_expert(
sql.SQL("copy (select {f} from {t} order by id) to stdout").format( sql.SQL("copy (select {f} from {t} order by id) to stdout").format(
t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z")), s1) t=sql.Identifier("test_compose"), f=sql.Identifier("ba'z")), s1)
s1.seek(0) s1.seek(0)
self.assertEqual(s1.read(), 'c\nf\n') self.assertEqual(s1.read(), 'c\nf\n')
class IdentifierTests(ConnectingTestCase): class IdentifierTests(ConnectingTestCase):

View File

@ -37,59 +37,59 @@ class TransactionTests(ConnectingTestCase):
def setUp(self): def setUp(self):
ConnectingTestCase.setUp(self) ConnectingTestCase.setUp(self)
self.conn.set_isolation_level(ISOLATION_LEVEL_SERIALIZABLE) self.conn.set_isolation_level(ISOLATION_LEVEL_SERIALIZABLE)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute(''' curs.execute('''
CREATE TEMPORARY TABLE table1 ( CREATE TEMPORARY TABLE table1 (
id int PRIMARY KEY id int PRIMARY KEY
)''') )''')
# The constraint is set to deferrable for the commit_failed test # The constraint is set to deferrable for the commit_failed test
curs.execute(''' curs.execute('''
CREATE TEMPORARY TABLE table2 ( CREATE TEMPORARY TABLE table2 (
id int PRIMARY KEY, id int PRIMARY KEY,
table1_id int, table1_id int,
CONSTRAINT table2__table1_id__fk CONSTRAINT table2__table1_id__fk
FOREIGN KEY (table1_id) REFERENCES table1(id) DEFERRABLE)''') FOREIGN KEY (table1_id) REFERENCES table1(id) DEFERRABLE)''')
curs.execute('INSERT INTO table1 VALUES (1)') curs.execute('INSERT INTO table1 VALUES (1)')
curs.execute('INSERT INTO table2 VALUES (1, 1)') curs.execute('INSERT INTO table2 VALUES (1, 1)')
self.conn.commit() self.conn.commit()
def test_rollback(self): def test_rollback(self):
# Test that rollback undoes changes # Test that rollback undoes changes
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute('INSERT INTO table2 VALUES (2, 1)') curs.execute('INSERT INTO table2 VALUES (2, 1)')
# Rollback takes us from BEGIN state to READY state # Rollback takes us from BEGIN state to READY state
self.assertEqual(self.conn.status, STATUS_BEGIN) self.assertEqual(self.conn.status, STATUS_BEGIN)
self.conn.rollback() self.conn.rollback()
self.assertEqual(self.conn.status, STATUS_READY) self.assertEqual(self.conn.status, STATUS_READY)
curs.execute('SELECT id, table1_id FROM table2 WHERE id = 2') curs.execute('SELECT id, table1_id FROM table2 WHERE id = 2')
self.assertEqual(curs.fetchall(), []) self.assertEqual(curs.fetchall(), [])
def test_commit(self): def test_commit(self):
# Test that commit stores changes # Test that commit stores changes
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute('INSERT INTO table2 VALUES (2, 1)') curs.execute('INSERT INTO table2 VALUES (2, 1)')
# Rollback takes us from BEGIN state to READY state # Rollback takes us from BEGIN state to READY state
self.assertEqual(self.conn.status, STATUS_BEGIN) self.assertEqual(self.conn.status, STATUS_BEGIN)
self.conn.commit() self.conn.commit()
self.assertEqual(self.conn.status, STATUS_READY) self.assertEqual(self.conn.status, STATUS_READY)
# Now rollback and show that the new record is still there: # Now rollback and show that the new record is still there:
self.conn.rollback() self.conn.rollback()
curs.execute('SELECT id, table1_id FROM table2 WHERE id = 2') curs.execute('SELECT id, table1_id FROM table2 WHERE id = 2')
self.assertEqual(curs.fetchall(), [(2, 1)]) self.assertEqual(curs.fetchall(), [(2, 1)])
def test_failed_commit(self): def test_failed_commit(self):
# Test that we can recover from a failed commit. # Test that we can recover from a failed commit.
# We use a deferred constraint to cause a failure on commit. # We use a deferred constraint to cause a failure on commit.
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute('SET CONSTRAINTS table2__table1_id__fk DEFERRED') curs.execute('SET CONSTRAINTS table2__table1_id__fk DEFERRED')
curs.execute('INSERT INTO table2 VALUES (2, 42)') curs.execute('INSERT INTO table2 VALUES (2, 42)')
# The commit should fail, and move the cursor back to READY state # The commit should fail, and move the cursor back to READY state
self.assertEqual(self.conn.status, STATUS_BEGIN) self.assertEqual(self.conn.status, STATUS_BEGIN)
self.assertRaises(psycopg2.IntegrityError, self.conn.commit) self.assertRaises(psycopg2.IntegrityError, self.conn.commit)
self.assertEqual(self.conn.status, STATUS_READY) self.assertEqual(self.conn.status, STATUS_READY)
# The connection should be ready to use for the next transaction: # The connection should be ready to use for the next transaction:
curs.execute('SELECT 1') curs.execute('SELECT 1')
self.assertEqual(curs.fetchone()[0], 1) self.assertEqual(curs.fetchone()[0], 1)
class DeadlockSerializationTests(ConnectingTestCase): class DeadlockSerializationTests(ConnectingTestCase):
@ -103,32 +103,32 @@ class DeadlockSerializationTests(ConnectingTestCase):
def setUp(self): def setUp(self):
ConnectingTestCase.setUp(self) ConnectingTestCase.setUp(self)
curs = self.conn.cursor() with self.conn.cursor() as curs:
# Drop table if it already exists # Drop table if it already exists
try: try:
curs.execute("DROP TABLE table1") curs.execute("DROP TABLE table1")
self.conn.commit() self.conn.commit()
except psycopg2.DatabaseError: except psycopg2.DatabaseError:
self.conn.rollback() self.conn.rollback()
try: try:
curs.execute("DROP TABLE table2") curs.execute("DROP TABLE table2")
self.conn.commit() self.conn.commit()
except psycopg2.DatabaseError: except psycopg2.DatabaseError:
self.conn.rollback() self.conn.rollback()
# Create sample data # Create sample data
curs.execute(""" curs.execute("""
CREATE TABLE table1 ( CREATE TABLE table1 (
id int PRIMARY KEY, id int PRIMARY KEY,
name text) name text)
""") """)
curs.execute("INSERT INTO table1 VALUES (1, 'hello')") curs.execute("INSERT INTO table1 VALUES (1, 'hello')")
curs.execute("CREATE TABLE table2 (id int PRIMARY KEY)") curs.execute("CREATE TABLE table2 (id int PRIMARY KEY)")
self.conn.commit() self.conn.commit()
def tearDown(self): def tearDown(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("DROP TABLE table1") curs.execute("DROP TABLE table1")
curs.execute("DROP TABLE table2") curs.execute("DROP TABLE table2")
self.conn.commit() self.conn.commit()
ConnectingTestCase.tearDown(self) ConnectingTestCase.tearDown(self)
@ -142,11 +142,11 @@ class DeadlockSerializationTests(ConnectingTestCase):
def task1(): def task1():
try: try:
conn = self.connect() conn = self.connect()
curs = conn.cursor() with conn.cursor() as curs:
curs.execute("LOCK table1 IN ACCESS EXCLUSIVE MODE") curs.execute("LOCK table1 IN ACCESS EXCLUSIVE MODE")
step1.set() step1.set()
step2.wait() step2.wait()
curs.execute("LOCK table2 IN ACCESS EXCLUSIVE MODE") curs.execute("LOCK table2 IN ACCESS EXCLUSIVE MODE")
except psycopg2.DatabaseError as exc: except psycopg2.DatabaseError as exc:
self.thread1_error = exc self.thread1_error = exc
step1.set() step1.set()
@ -155,11 +155,11 @@ class DeadlockSerializationTests(ConnectingTestCase):
def task2(): def task2():
try: try:
conn = self.connect() conn = self.connect()
curs = conn.cursor() with conn.cursor() as curs:
step1.wait() step1.wait()
curs.execute("LOCK table2 IN ACCESS EXCLUSIVE MODE") curs.execute("LOCK table2 IN ACCESS EXCLUSIVE MODE")
step2.set() step2.set()
curs.execute("LOCK table1 IN ACCESS EXCLUSIVE MODE") curs.execute("LOCK table1 IN ACCESS EXCLUSIVE MODE")
except psycopg2.DatabaseError as exc: except psycopg2.DatabaseError as exc:
self.thread2_error = exc self.thread2_error = exc
step2.set() step2.set()
@ -190,12 +190,12 @@ class DeadlockSerializationTests(ConnectingTestCase):
def task1(): def task1():
try: try:
conn = self.connect() conn = self.connect()
curs = conn.cursor() with conn.cursor() as curs:
curs.execute("SELECT name FROM table1 WHERE id = 1") curs.execute("SELECT name FROM table1 WHERE id = 1")
curs.fetchall() curs.fetchall()
step1.set() step1.set()
step2.wait() step2.wait()
curs.execute("UPDATE table1 SET name='task1' WHERE id = 1") curs.execute("UPDATE table1 SET name='task1' WHERE id = 1")
conn.commit() conn.commit()
except psycopg2.DatabaseError as exc: except psycopg2.DatabaseError as exc:
self.thread1_error = exc self.thread1_error = exc
@ -205,9 +205,9 @@ class DeadlockSerializationTests(ConnectingTestCase):
def task2(): def task2():
try: try:
conn = self.connect() conn = self.connect()
curs = conn.cursor() with conn.cursor() as curs:
step1.wait() step1.wait()
curs.execute("UPDATE table1 SET name='task2' WHERE id = 1") curs.execute("UPDATE table1 SET name='task2' WHERE id = 1")
conn.commit() conn.commit()
except psycopg2.DatabaseError as exc: except psycopg2.DatabaseError as exc:
self.thread2_error = exc self.thread2_error = exc
@ -240,11 +240,11 @@ class QueryCancellationTests(ConnectingTestCase):
@skip_before_postgres(8, 2) @skip_before_postgres(8, 2)
def test_statement_timeout(self): def test_statement_timeout(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
# Set a low statement timeout, then sleep for a longer period. # Set a low statement timeout, then sleep for a longer period.
curs.execute('SET statement_timeout TO 10') curs.execute('SET statement_timeout TO 10')
self.assertRaises(psycopg2.extensions.QueryCanceledError, self.assertRaises(psycopg2.extensions.QueryCanceledError,
curs.execute, 'SELECT pg_sleep(50)') curs.execute, 'SELECT pg_sleep(50)')
def test_suite(): def test_suite():

View File

@ -41,9 +41,9 @@ class TypesBasicTests(ConnectingTestCase):
"""Test that all type conversions are working.""" """Test that all type conversions are working."""
def execute(self, *args): def execute(self, *args):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute(*args) curs.execute(*args)
return curs.fetchone()[0] return curs.fetchone()[0]
def testQuoting(self): def testQuoting(self):
s = "Quote'this\\! ''ok?''" s = "Quote'this\\! ''ok?''"
@ -156,26 +156,27 @@ class TypesBasicTests(ConnectingTestCase):
def testEmptyArrayRegression(self): def testEmptyArrayRegression(self):
# ticket #42 # ticket #42
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute( curs.execute(
"create table array_test " "create table array_test "
"(id integer, col timestamp without time zone[])") "(id integer, col timestamp without time zone[])")
curs.execute("insert into array_test values (%s, %s)", curs.execute("insert into array_test values (%s, %s)",
(1, [datetime.date(2011, 2, 14)])) (1, [datetime.date(2011, 2, 14)]))
curs.execute("select col from array_test where id = 1") curs.execute("select col from array_test where id = 1")
self.assertEqual(curs.fetchone()[0], [datetime.datetime(2011, 2, 14, 0, 0)]) self.assertEqual(
curs.fetchone()[0], [datetime.datetime(2011, 2, 14, 0, 0)])
curs.execute("insert into array_test values (%s, %s)", (2, [])) curs.execute("insert into array_test values (%s, %s)", (2, []))
curs.execute("select col from array_test where id = 2") curs.execute("select col from array_test where id = 2")
self.assertEqual(curs.fetchone()[0], []) self.assertEqual(curs.fetchone()[0], [])
@testutils.skip_before_postgres(8, 4) @testutils.skip_before_postgres(8, 4)
def testNestedEmptyArray(self): def testNestedEmptyArray(self):
# issue #788 # issue #788
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select 10 = any(%s::int[])", ([[]], )) curs.execute("select 10 = any(%s::int[])", ([[]], ))
self.assertFalse(curs.fetchone()[0]) self.assertFalse(curs.fetchone()[0])
def testEmptyArrayNoCast(self): def testEmptyArrayNoCast(self):
s = self.execute("SELECT '{}' AS foo") s = self.execute("SELECT '{}' AS foo")
@ -204,86 +205,86 @@ class TypesBasicTests(ConnectingTestCase):
self.failUnlessEqual(ss, r) self.failUnlessEqual(ss, r)
def testArrayMalformed(self): def testArrayMalformed(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
ss = ['', '{', '{}}', '{' * 20 + '}' * 20] ss = ['', '{', '{}}', '{' * 20 + '}' * 20]
for s in ss: for s in ss:
self.assertRaises(psycopg2.DataError, self.assertRaises(psycopg2.DataError,
psycopg2.extensions.STRINGARRAY, s.encode('utf8'), curs) psycopg2.extensions.STRINGARRAY, s.encode('utf8'), curs)
def testTextArray(self): def testTextArray(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select '{a,b,c}'::text[]") curs.execute("select '{a,b,c}'::text[]")
x = curs.fetchone()[0] x = curs.fetchone()[0]
self.assert_(isinstance(x[0], str)) self.assert_(isinstance(x[0], str))
self.assertEqual(x, ['a', 'b', 'c']) self.assertEqual(x, ['a', 'b', 'c'])
def testUnicodeArray(self): def testUnicodeArray(self):
psycopg2.extensions.register_type( psycopg2.extensions.register_type(
psycopg2.extensions.UNICODEARRAY, self.conn) psycopg2.extensions.UNICODEARRAY, self.conn)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select '{a,b,c}'::text[]") curs.execute("select '{a,b,c}'::text[]")
x = curs.fetchone()[0] x = curs.fetchone()[0]
self.assert_(isinstance(x[0], text_type)) self.assert_(isinstance(x[0], text_type))
self.assertEqual(x, [u'a', u'b', u'c']) self.assertEqual(x, [u'a', u'b', u'c'])
def testBytesArray(self): def testBytesArray(self):
psycopg2.extensions.register_type( psycopg2.extensions.register_type(
psycopg2.extensions.BYTESARRAY, self.conn) psycopg2.extensions.BYTESARRAY, self.conn)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select '{a,b,c}'::text[]") curs.execute("select '{a,b,c}'::text[]")
x = curs.fetchone()[0] x = curs.fetchone()[0]
self.assert_(isinstance(x[0], bytes)) self.assert_(isinstance(x[0], bytes))
self.assertEqual(x, [b'a', b'b', b'c']) self.assertEqual(x, [b'a', b'b', b'c'])
@testutils.skip_before_postgres(8, 2) @testutils.skip_before_postgres(8, 2)
def testArrayOfNulls(self): def testArrayOfNulls(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute(""" curs.execute("""
create table na ( create table na (
texta text[], texta text[],
inta int[], inta int[],
boola boolean[], boola boolean[],
textaa text[][], textaa text[][],
intaa int[][], intaa int[][],
boolaa boolean[][] boolaa boolean[][]
)""") )""")
curs.execute("insert into na (texta) values (%s)", ([None],)) curs.execute("insert into na (texta) values (%s)", ([None],))
curs.execute("insert into na (texta) values (%s)", (['a', None],)) curs.execute("insert into na (texta) values (%s)", (['a', None],))
curs.execute("insert into na (texta) values (%s)", ([None, None],)) curs.execute("insert into na (texta) values (%s)", ([None, None],))
curs.execute("insert into na (inta) values (%s)", ([None],)) curs.execute("insert into na (inta) values (%s)", ([None],))
curs.execute("insert into na (inta) values (%s)", ([42, None],)) curs.execute("insert into na (inta) values (%s)", ([42, None],))
curs.execute("insert into na (inta) values (%s)", ([None, None],)) curs.execute("insert into na (inta) values (%s)", ([None, None],))
curs.execute("insert into na (boola) values (%s)", ([None],)) curs.execute("insert into na (boola) values (%s)", ([None],))
curs.execute("insert into na (boola) values (%s)", ([True, None],)) curs.execute("insert into na (boola) values (%s)", ([True, None],))
curs.execute("insert into na (boola) values (%s)", ([None, None],)) curs.execute("insert into na (boola) values (%s)", ([None, None],))
curs.execute("insert into na (textaa) values (%s)", ([[None]],)) curs.execute("insert into na (textaa) values (%s)", ([[None]],))
curs.execute("insert into na (textaa) values (%s)", ([['a', None]],)) curs.execute("insert into na (textaa) values (%s)", ([['a', None]],))
curs.execute("insert into na (textaa) values (%s)", ([[None, None]],)) curs.execute("insert into na (textaa) values (%s)", ([[None, None]],))
curs.execute("insert into na (intaa) values (%s)", ([[None]],)) curs.execute("insert into na (intaa) values (%s)", ([[None]],))
curs.execute("insert into na (intaa) values (%s)", ([[42, None]],)) curs.execute("insert into na (intaa) values (%s)", ([[42, None]],))
curs.execute("insert into na (intaa) values (%s)", ([[None, None]],)) curs.execute("insert into na (intaa) values (%s)", ([[None, None]],))
curs.execute("insert into na (boolaa) values (%s)", ([[None]],)) curs.execute("insert into na (boolaa) values (%s)", ([[None]],))
curs.execute("insert into na (boolaa) values (%s)", ([[True, None]],)) curs.execute("insert into na (boolaa) values (%s)", ([[True, None]],))
curs.execute("insert into na (boolaa) values (%s)", ([[None, None]],)) curs.execute("insert into na (boolaa) values (%s)", ([[None, None]],))
@testutils.skip_before_postgres(8, 2) @testutils.skip_before_postgres(8, 2)
def testNestedArrays(self): def testNestedArrays(self):
curs = self.conn.cursor() with self.conn.cursor() as curs:
for a in [ for a in [
[[1]], [[1]],
[[None]], [[None]],
[[None, None, None]], [[None, None, None]],
[[None, None], [1, None]], [[None, None], [1, None]],
[[None, None], [None, None]], [[None, None], [None, None]],
[[[None, None], [None, None]]], [[[None, None], [None, None]]],
]: ]:
curs.execute("select %s::int[]", (a,)) curs.execute("select %s::int[]", (a,))
self.assertEqual(curs.fetchone()[0], a) self.assertEqual(curs.fetchone()[0], a)
@testutils.skip_from_python(3) @testutils.skip_from_python(3)
def testTypeRoundtripBuffer(self): def testTypeRoundtripBuffer(self):

File diff suppressed because it is too large Load Diff

116
tests/test_warnings.py Normal file
View File

@ -0,0 +1,116 @@
import re
import subprocess
import sys
import unittest
import warnings
import psycopg2
from .testconfig import dsn
from .testutils import skip_before_python
class WarningsTest(unittest.TestCase):
@skip_before_python(3)
def test_connection_not_closed(self):
def f():
psycopg2.connect(dsn)
msg = (
"^unclosed connection <connection object at 0x[0-9a-fA-F]+; dsn: '.*', "
"closed: 0>$"
)
with self.assertWarnsRegex(ResourceWarning, msg):
f()
@skip_before_python(3)
def test_cursor_not_closed(self):
def f():
conn = psycopg2.connect(dsn)
try:
conn.cursor()
finally:
conn.close()
msg = (
"^unclosed cursor <cursor object at 0x[0-9a-fA-F]+; closed: 0> for "
"connection <connection object at 0x[0-9a-fA-F]+; dsn: '.*', closed: 0>$"
)
with self.assertWarnsRegex(ResourceWarning, msg):
f()
def test_cursor_factory_returns_non_cursor(self):
def bad_factory(*args, **kwargs):
return object()
def f():
conn = psycopg2.connect(dsn)
try:
conn.cursor(cursor_factory=bad_factory)
finally:
conn.close()
with warnings.catch_warnings(record=True) as cm:
with self.assertRaises(TypeError):
f()
# No warning as no cursor was instantiated.
self.assertEquals(cm, [])
@skip_before_python(3)
def test_broken_close(self):
script = """
import psycopg2
class MyException(Exception):
pass
class MyCurs(psycopg2.extensions.cursor):
def close(self):
raise MyException
def f():
conn = psycopg2.connect(%(dsn)r)
try:
conn.cursor(cursor_factory=MyCurs, scrollable=True)
finally:
conn.close()
f()
""" % {"dsn": dsn}
p = subprocess.Popen(
[sys.executable, "-Walways", "-c", script],
stdout=subprocess.PIPE,
stderr=subprocess.STDOUT,
)
output, _ = p.communicate()
output = output.decode()
# Normalize line endings.
output = "\n".join(output.splitlines())
self.assertRegex(
output,
re.compile(
r"^Exception ignored in: "
r"<cursor object at 0x[0-9a-fA-F]+; closed: 0>$",
re.M,
),
)
self.assertIn("\n__main__.MyException: \n", output)
self.assertRegex(
output,
re.compile(
r"ResourceWarning: unclosed cursor "
r"<cursor object at 0x[0-9a-fA-F]+; closed: 0> "
r"for connection "
r"<connection object at 0x[0-9a-fA-F]+; dsn: '.*', closed: 0>$",
re.M,
),
)
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)
if __name__ == "__main__":
unittest.main()

View File

@ -33,15 +33,15 @@ from .testutils import ConnectingTestCase, skip_before_postgres
class WithTestCase(ConnectingTestCase): class WithTestCase(ConnectingTestCase):
def setUp(self): def setUp(self):
ConnectingTestCase.setUp(self) ConnectingTestCase.setUp(self)
curs = self.conn.cursor() with self.conn.cursor() as curs:
try: try:
curs.execute("delete from test_with") curs.execute("delete from test_with")
self.conn.commit() self.conn.commit()
except psycopg2.ProgrammingError: except psycopg2.ProgrammingError:
# assume table doesn't exist # assume table doesn't exist
self.conn.rollback() self.conn.rollback()
curs.execute("create table test_with (id integer primary key)") curs.execute("create table test_with (id integer primary key)")
self.conn.commit() self.conn.commit()
class WithConnectionTestCase(WithTestCase): class WithConnectionTestCase(WithTestCase):
@ -49,59 +49,59 @@ class WithConnectionTestCase(WithTestCase):
with self.conn as conn: with self.conn as conn:
self.assert_(self.conn is conn) self.assert_(self.conn is conn)
self.assertEqual(conn.status, ext.STATUS_READY) self.assertEqual(conn.status, ext.STATUS_READY)
curs = conn.cursor() with conn.cursor() as curs:
curs.execute("insert into test_with values (1)") curs.execute("insert into test_with values (1)")
self.assertEqual(conn.status, ext.STATUS_BEGIN) self.assertEqual(conn.status, ext.STATUS_BEGIN)
self.assertEqual(self.conn.status, ext.STATUS_READY) self.assertEqual(self.conn.status, ext.STATUS_READY)
self.assert_(not self.conn.closed) self.assert_(not self.conn.closed)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select * from test_with") curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), [(1,)]) self.assertEqual(curs.fetchall(), [(1,)])
def test_with_connect_idiom(self): def test_with_connect_idiom(self):
with self.connect() as conn: with self.connect() as conn:
self.assertEqual(conn.status, ext.STATUS_READY) self.assertEqual(conn.status, ext.STATUS_READY)
curs = conn.cursor() with conn.cursor() as curs:
curs.execute("insert into test_with values (2)") curs.execute("insert into test_with values (2)")
self.assertEqual(conn.status, ext.STATUS_BEGIN) self.assertEqual(conn.status, ext.STATUS_BEGIN)
self.assertEqual(self.conn.status, ext.STATUS_READY) self.assertEqual(self.conn.status, ext.STATUS_READY)
self.assert_(not self.conn.closed) self.assert_(not self.conn.closed)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select * from test_with") curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), [(2,)]) self.assertEqual(curs.fetchall(), [(2,)])
def test_with_error_db(self): def test_with_error_db(self):
def f(): def f():
with self.conn as conn: with self.conn as conn:
curs = conn.cursor() with conn.cursor() as curs:
curs.execute("insert into test_with values ('a')") curs.execute("insert into test_with values ('a')")
self.assertRaises(psycopg2.DataError, f) self.assertRaises(psycopg2.DataError, f)
self.assertEqual(self.conn.status, ext.STATUS_READY) self.assertEqual(self.conn.status, ext.STATUS_READY)
self.assert_(not self.conn.closed) self.assert_(not self.conn.closed)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select * from test_with") curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), []) self.assertEqual(curs.fetchall(), [])
def test_with_error_python(self): def test_with_error_python(self):
def f(): def f():
with self.conn as conn: with self.conn as conn:
curs = conn.cursor() with conn.cursor() as curs:
curs.execute("insert into test_with values (3)") curs.execute("insert into test_with values (3)")
1 / 0 1 / 0
self.assertRaises(ZeroDivisionError, f) self.assertRaises(ZeroDivisionError, f)
self.assertEqual(self.conn.status, ext.STATUS_READY) self.assertEqual(self.conn.status, ext.STATUS_READY)
self.assert_(not self.conn.closed) self.assert_(not self.conn.closed)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select * from test_with") curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), []) self.assertEqual(curs.fetchall(), [])
def test_with_closed(self): def test_with_closed(self):
def f(): def f():
@ -120,15 +120,15 @@ class WithConnectionTestCase(WithTestCase):
super(MyConn, self).commit() super(MyConn, self).commit()
with self.connect(connection_factory=MyConn) as conn: with self.connect(connection_factory=MyConn) as conn:
curs = conn.cursor() with conn.cursor() as curs:
curs.execute("insert into test_with values (10)") curs.execute("insert into test_with values (10)")
self.assertEqual(conn.status, ext.STATUS_READY) self.assertEqual(conn.status, ext.STATUS_READY)
self.assert_(commits) self.assert_(commits)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select * from test_with") curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), [(10,)]) self.assertEqual(curs.fetchall(), [(10,)])
def test_subclass_rollback(self): def test_subclass_rollback(self):
rollbacks = [] rollbacks = []
@ -140,9 +140,9 @@ class WithConnectionTestCase(WithTestCase):
try: try:
with self.connect(connection_factory=MyConn) as conn: with self.connect(connection_factory=MyConn) as conn:
curs = conn.cursor() with conn.cursor() as curs:
curs.execute("insert into test_with values (11)") curs.execute("insert into test_with values (11)")
1 / 0 1 / 0
except ZeroDivisionError: except ZeroDivisionError:
pass pass
else: else:
@ -151,9 +151,9 @@ class WithConnectionTestCase(WithTestCase):
self.assertEqual(conn.status, ext.STATUS_READY) self.assertEqual(conn.status, ext.STATUS_READY)
self.assert_(rollbacks) self.assert_(rollbacks)
curs = conn.cursor() with conn.cursor() as curs:
curs.execute("select * from test_with") curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), []) self.assertEqual(curs.fetchall(), [])
class WithCursorTestCase(WithTestCase): class WithCursorTestCase(WithTestCase):
@ -168,9 +168,9 @@ class WithCursorTestCase(WithTestCase):
self.assertEqual(self.conn.status, ext.STATUS_READY) self.assertEqual(self.conn.status, ext.STATUS_READY)
self.assert_(not self.conn.closed) self.assert_(not self.conn.closed)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select * from test_with") curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), [(4,)]) self.assertEqual(curs.fetchall(), [(4,)])
def test_with_error(self): def test_with_error(self):
try: try:
@ -185,9 +185,9 @@ class WithCursorTestCase(WithTestCase):
self.assert_(not self.conn.closed) self.assert_(not self.conn.closed)
self.assert_(curs.closed) self.assert_(curs.closed)
curs = self.conn.cursor() with self.conn.cursor() as curs:
curs.execute("select * from test_with") curs.execute("select * from test_with")
self.assertEqual(curs.fetchall(), []) self.assertEqual(curs.fetchall(), [])
def test_subclass(self): def test_subclass(self):
closes = [] closes = []

View File

@ -228,9 +228,9 @@ def skip_if_no_uuid(f):
@wraps(f) @wraps(f)
def skip_if_no_uuid_(self): def skip_if_no_uuid_(self):
try: try:
cur = self.conn.cursor() with self.conn.cursor() as cur:
cur.execute("select typname from pg_type where typname = 'uuid'") cur.execute("select typname from pg_type where typname = 'uuid'")
has = cur.fetchone() has = cur.fetchone()
finally: finally:
self.conn.rollback() self.conn.rollback()
@ -248,15 +248,17 @@ def skip_if_tpc_disabled(f):
@wraps(f) @wraps(f)
def skip_if_tpc_disabled_(self): def skip_if_tpc_disabled_(self):
cnn = self.connect() cnn = self.connect()
cur = cnn.cursor()
try: try:
cur.execute("SHOW max_prepared_transactions;") with cnn.cursor() as cur:
except psycopg2.ProgrammingError: try:
return self.skipTest( cur.execute("SHOW max_prepared_transactions;")
"server too old: two phase transactions not supported.") except psycopg2.ProgrammingError:
else: return self.skipTest(
mtp = int(cur.fetchone()[0]) "server too old: two phase transactions not supported.")
cnn.close() else:
mtp = int(cur.fetchone()[0])
finally:
cnn.close()
if not mtp: if not mtp:
return self.skipTest( return self.skipTest(