diff --git a/psycopg/psycopgmodule.c b/psycopg/psycopgmodule.c index 1b0567b3..23e648d2 100644 --- a/psycopg/psycopgmodule.c +++ b/psycopg/psycopgmodule.c @@ -418,16 +418,20 @@ psyco_encrypt_password(PyObject *self, PyObject *args, PyObject *kwargs) PyObject *password = NULL, *user = NULL; PyObject *scope = Py_None, *algorithm = Py_None; PyObject *res = NULL; + connectionObject *conn = NULL; static char *kwlist[] = {"password", "user", "scope", "algorithm", NULL}; - connectionObject *conn = NULL; - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OO|OO", kwlist, &password, &user, &scope, &algorithm)) { return NULL; } + /* for ensure_bytes */ + Py_INCREF(user); + Py_INCREF(password); + Py_INCREF(algorithm); + if (scope != Py_None) { if (PyObject_TypeCheck(scope, &cursorType)) { conn = ((cursorObject*)scope)->conn; @@ -437,16 +441,11 @@ psyco_encrypt_password(PyObject *self, PyObject *args, PyObject *kwargs) } else { PyErr_SetString(PyExc_TypeError, - "the scope must be a connection or a cursor"); + "the scope must be a connection or a cursor"); goto exit; } } - /* for ensure_bytes */ - Py_INCREF(user); - Py_INCREF(password); - Py_INCREF(algorithm); - if (!(user = psycopg_ensure_bytes(user))) { goto exit; } if (!(password = psycopg_ensure_bytes(password))) { goto exit; } if (algorithm != Py_None) { @@ -473,7 +472,7 @@ psyco_encrypt_password(PyObject *self, PyObject *args, PyObject *kwargs) goto exit; } - /* TODO: algo = will block: forbid on async/green conn? */ + /* TODO: algo = None will block: forbid on async/green conn? */ encrypted = PQencryptPasswordConn(conn->pgconn, Bytes_AS_STRING(password), Bytes_AS_STRING(user), algorithm != Py_None ? Bytes_AS_STRING(algorithm) : NULL); diff --git a/tests/test_connection.py b/tests/test_connection.py index 7861ab71..13635f1f 100755 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -1482,6 +1482,17 @@ class TestEncryptPassword(ConnectingTestCase): password='psycopg2', user='ashesh', scope=self.conn, algorithm='scram-sha-256') + def test_bad_types(self): + self.assertRaises(TypeError, ext.encrypt_password) + self.assertRaises(TypeError, ext.encrypt_password, + 'password', 42, self.conn, 'md5') + self.assertRaises(TypeError, ext.encrypt_password, + 42, 'user', self.conn, 'md5') + self.assertRaises(TypeError, ext.encrypt_password, + 42, 'user', 'wat', 'abc') + self.assertRaises(TypeError, ext.encrypt_password, + 'password', 'user', 'wat', 42) + class AutocommitTests(ConnectingTestCase): def test_closed(self):