diff --git a/doc/src/cursor.rst b/doc/src/cursor.rst index 73bb5375..9df65865 100644 --- a/doc/src/cursor.rst +++ b/doc/src/cursor.rst @@ -201,12 +201,17 @@ The ``cursor`` class Call a stored database procedure with the given name. The sequence of parameters must contain one entry for each argument that the procedure - expects. The result of the call is returned as modified copy of the - input sequence. Input parameters are left untouched, output and - input/output parameters replaced with possibly new values. - - The procedure may also provide a result set as output. This must then - be made available through the standard |fetch*|_ methods. + expects. Overloaded procedures are supported. Named parameters can be + used with a PostgreSQL 9.0+ client by supplying the sequence of + parameters as a Dict. + + This function is, at present, not DBAPI-compliant. The return value is + supposed to consist of the sequence of parameters with modified output + and input/output parameters. In future versions, the DBAPI-compliant + return value may be implemented, but for now the function returns None. + + The procedure may provide a result set as output. This is then made + available through the standard |fetch*|_ methods. .. method:: mogrify(operation [, parameters]) diff --git a/psycopg/cursor_type.c b/psycopg/cursor_type.c index 3dbdbdc4..8214d359 100644 --- a/psycopg/cursor_type.c +++ b/psycopg/cursor_type.c @@ -1028,6 +1028,7 @@ psyco_curs_callproc(cursorObject *self, PyObject *args) int using_dict; #if PG_VERSION_HEX >= 0x090000 PyObject *pname = NULL; + PyObject *spname = NULL; PyObject *bpname = NULL; PyObject *pnames = NULL; char *cpname = NULL; @@ -1055,7 +1056,7 @@ psyco_curs_callproc(cursorObject *self, PyObject *args) using_dict = nparameters > 0 && PyDict_Check(parameters); - /* A Dict is complicated. The parameter names go into the query */ + /* a Dict is complicated; the parameter names go into the query */ if (using_dict) { #if PG_VERSION_HEX >= 0x090000 if (!(pnames = PyDict_Keys(parameters))) { @@ -1073,33 +1074,46 @@ psyco_curs_callproc(cursorObject *self, PyObject *args) memset(scpnames, 0, sizeof(char *) * nparameters); - /* Each parameter has to be processed. It's a few steps. */ + /* each parameter has to be processed; it's a few steps. */ for (i = 0; i < nparameters; i++) { + /* all errors are RuntimeErrors as they should never occur */ + if (!(pname = PyList_GetItem(pnames, i))) { PyErr_SetString(PyExc_RuntimeError, "built-in 'values' did not return List!"); goto exit; } - if (!(bpname = psycopg_ensure_bytes(pname))) { - PyErr_SetString(PyExc_TypeError, - "argument 2 must have only string keys if Dict"); + if (!(spname = PyObject_Str(pname))) { + PyErr_SetString(PyExc_RuntimeError, + "built-in 'str' failed!"); + goto exit; + } + + /* this is the only function here that returns a new reference */ + if (!(bpname = psycopg_ensure_bytes(spname))) { + PyErr_SetString(PyExc_RuntimeError, + "failed to get Bytes from text!"); goto exit; } if (!(cpname = Bytes_AsString(bpname))) { + Py_XDECREF(bpname); PyErr_SetString(PyExc_RuntimeError, - "failed to get Bytes from String!"); + "failed to get cstr from Bytes!"); goto exit; } if (!(scpnames[i] = PQescapeIdentifier(self->conn->pgconn, cpname, strlen(cpname)))) { + Py_XDECREF(bpname); PyErr_SetString(PyExc_RuntimeError, "libpq failed to escape identifier!"); goto exit; } + Py_XDECREF(bpname); + sl += strlen(scpnames[i]); } @@ -1153,11 +1167,11 @@ psyco_curs_callproc(cursorObject *self, PyObject *args) if (0 <= _psyco_curs_execute(self, operation, parameters, self->conn->async, 0)) { - /* In the dict case, the parameters are already a new reference */ - if (!using_dict) { - Py_INCREF(parameters); - } - res = parameters; + if (using_dict) { + Py_DECREF(parameters); + } + /* return None from this until it's DBAPI compliant... */ + res = Py_None; } exit: diff --git a/tests/test_cursor.py b/tests/test_cursor.py index 970cc37d..00d19dfb 100755 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -490,6 +490,43 @@ class CursorTests(ConnectingTestCase): cur = self.conn.cursor() self.assertRaises(TypeError, cur.callproc, 'lower', 42) + # It would be inappropriate to test callproc's named parameters in the + # DBAPI2.0 test section because they are a psycopg2 extension. + @skip_before_postgres(9, 0) + def test_callproc_dict(self): + # This parameter name tests for injection and quote escaping + paramname = ''' + Robert'); drop table "students" -- + '''.strip() + escaped_paramname = '"%s"' % paramname.replace('"', '""') + procname = 'pg_temp.randall' + + cur = self.conn.cursor() + + # Set up the temporary function + cur.execute(''' + CREATE FUNCTION %s(%s INT) + RETURNS INT AS + 'SELECT $1 * $1' + LANGUAGE SQL + ''' % (procname, escaped_paramname)); + + # Make sure callproc works right + cur.callproc(procname, { paramname: 2 }) + self.assertEquals(cur.fetchone()[0], 4) + + # Make sure callproc fails right + failing_cases = [ + ({ paramname: 2, 'foo': 'bar' }, psycopg2.ProgrammingError), + ({ paramname: '2' }, psycopg2.ProgrammingError), + ({ paramname: 'two' }, psycopg2.ProgrammingError), + ({ 'bjørn': 2 }, psycopg2.ProgrammingError), + ({ 3: 2 }, psycopg2.ProgrammingError), + ({ self: 2 }, psycopg2.ProgrammingError), + ] + for parameter_sequence, exception in failing_cases: + self.assertRaises(exception, cur.callproc, procname, parameter_sequence) + self.conn.rollback() def test_suite(): return unittest.TestLoader().loadTestsFromName(__name__) diff --git a/tests/test_psycopg2_dbapi20.py b/tests/test_psycopg2_dbapi20.py index 744d3224..28ea6690 100755 --- a/tests/test_psycopg2_dbapi20.py +++ b/tests/test_psycopg2_dbapi20.py @@ -36,6 +36,29 @@ class Psycopg2Tests(dbapi20.DatabaseAPI20Test): connect_kw_args = {'dsn': dsn} lower_func = 'lower' # For stored procedure test + def test_callproc(self): + # Until DBAPI 2.0 compliance, callproc should return None or it's just + # misleading. Therefore, we will skip the return value test for + # callproc and only perform the fetch test. + # + # For what it's worth, the DBAPI2.0 test_callproc doesn't actually + # test for DBAPI2.0 compliance! It doesn't check for modified OUT and + # IN/OUT parameters in the return values! + con = self._connect() + try: + cur = con.cursor() + if self.lower_func and hasattr(cur,'callproc'): + cur.callproc(self.lower_func,('FOO',)) + r = cur.fetchall() + self.assertEqual(len(r),1,'callproc produced no result set') + self.assertEqual(len(r[0]),1, + 'callproc produced invalid result set' + ) + self.assertEqual(r[0][0],'foo', + 'callproc produced invalid results' + ) + finally: + con.close() def test_setoutputsize(self): # psycopg2's setoutputsize() is a no-op