Fixed code flow in encrypt_password()

Fixed several shortcomings highlighted in #576 and not fixed as
requested.

Also fixed broken behaviour of ignoring the algorithm if the connection
is missing.
This commit is contained in:
Daniele Varrazzo 2018-05-20 19:18:42 +01:00
parent 0161d54dbb
commit a3063900ee
2 changed files with 73 additions and 74 deletions

View File

@ -409,100 +409,99 @@ psyco_libpq_version(PyObject *self)
/* encrypt_password - Prepare the encrypted password form */ /* encrypt_password - Prepare the encrypted password form */
#define psyco_encrypt_password_doc \ #define psyco_encrypt_password_doc \
"encrypt_password(password, user, [conn_or_curs], [algorithm]) -- Prepares the encrypted form of a PostgreSQL password.\n\n" "encrypt_password(password, user, [scope], [algorithm]) -- Prepares the encrypted form of a PostgreSQL password.\n\n"
static PyObject * static PyObject *
psyco_encrypt_password(PyObject *self, PyObject *args) psyco_encrypt_password(PyObject *self, PyObject *args, PyObject *kwargs)
{ {
char *encrypted = NULL; char *encrypted = NULL;
PyObject *password = NULL, *user = NULL;
PyObject *scope = Py_None, *algorithm = Py_None;
PyObject *res = NULL;
PyObject *obj = NULL, static char *kwlist[] = {"password", "user", "scope", "algorithm", NULL};
*res = Py_None,
*password = NULL,
*user = NULL,
*algorithm = NULL;
connectionObject *conn = NULL; connectionObject *conn = NULL;
if (!PyArg_ParseTuple(args, "OO|OO", if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OO|OO", kwlist,
&password, &user, &obj, &algorithm)) { &password, &user, &scope, &algorithm)) {
return NULL; return NULL;
} }
if (obj != NULL && obj != Py_None) { if (scope != Py_None) {
if (PyObject_TypeCheck(obj, &cursorType)) { if (PyObject_TypeCheck(scope, &cursorType)) {
conn = ((cursorObject*)obj)->conn; conn = ((cursorObject*)scope)->conn;
} }
else if (PyObject_TypeCheck(obj, &connectionType)) { else if (PyObject_TypeCheck(scope, &connectionType)) {
conn = (connectionObject*)obj; conn = (connectionObject*)scope;
} }
else { else {
PyErr_SetString(PyExc_TypeError, PyErr_SetString(PyExc_TypeError,
"argument 3 must be a connection or a cursor"); "the scope must be a connection or a cursor");
return NULL; goto exit;
} }
} }
/* for ensure_bytes */ /* for ensure_bytes */
Py_INCREF(user); Py_INCREF(user);
Py_INCREF(password); Py_INCREF(password);
if (algorithm) {
Py_INCREF(algorithm); Py_INCREF(algorithm);
}
if (!(user = psycopg_ensure_bytes(user))) { goto exit; } if (!(user = psycopg_ensure_bytes(user))) { goto exit; }
if (!(password = psycopg_ensure_bytes(password))) { goto exit; } if (!(password = psycopg_ensure_bytes(password))) { goto exit; }
if (algorithm && !(algorithm = psycopg_ensure_bytes(algorithm))) { if (algorithm != Py_None) {
if (!(algorithm = psycopg_ensure_bytes(algorithm))) {
goto exit; goto exit;
} }
/* Use the libpq API 'PQencryptPassword', when no connection object is
available, or the algorithm is set to as 'md5', or the database server
version < 10 */
if (conn == NULL || conn->server_version < 100000 ||
(algorithm != NULL && algorithm != Py_None &&
strcmp(Bytes_AS_STRING(algorithm), "md5") == 0)) {
encrypted = PQencryptPassword(Bytes_AS_STRING(password),
Bytes_AS_STRING(user));
if (encrypted != NULL) {
res = Text_FromUTF8(encrypted);
PQfreemem(encrypted);
}
goto exit;
} }
/* If we have to encrypt md5 we can use the libpq < 10 API */
if (algorithm != Py_None &&
strcmp(Bytes_AS_STRING(algorithm), "md5") == 0) {
encrypted = PQencryptPassword(
Bytes_AS_STRING(password), Bytes_AS_STRING(user));
}
/* If the algorithm is not md5 we have to use the API available from
* libpq 10. */
else {
#if PG_VERSION_NUM >= 100000 #if PG_VERSION_NUM >= 100000
encrypted = PQencryptPasswordConn(conn->pgconn, Bytes_AS_STRING(password), if (!conn) {
Bytes_AS_STRING(user), PyErr_SetString(ProgrammingError,
algorithm ? Bytes_AS_STRING(algorithm) : NULL); "password encryption (other than 'md5' algorithm)"
" requires a connection or cursor");
if (!encrypted) {
const char *msg = PQerrorMessage(conn->pgconn);
if (msg && *msg) {
PyErr_Format(ProgrammingError, "%s", msg);
res = NULL;
goto exit; goto exit;
} }
/* TODO: algo = will block: forbid on async/green conn? */
encrypted = PQencryptPasswordConn(conn->pgconn,
Bytes_AS_STRING(password), Bytes_AS_STRING(user),
algorithm != Py_None ? Bytes_AS_STRING(algorithm) : NULL);
#else
PyErr_SetString(NotSupportedError,
"password encryption (other than 'md5' algorithm)"
" requires libpq 10");
goto exit;
#endif
}
if (encrypted) {
res = Text_FromUTF8(encrypted);
} }
else { else {
res = Text_FromUTF8(encrypted); const char *msg = PQerrorMessage(conn->pgconn);
PQfreemem(encrypted); PyErr_Format(ProgrammingError,
"password encryption failed: %s", msg ? msg : "no reason given");
goto exit;
} }
#else
PyErr_SetString(
NotSupportedError,
"Password encryption (other than 'md5' algorithm) is not supported for the server version >= 10 in libpq < 10"
);
res = NULL;
#endif
exit: exit:
if (encrypted) {
PQfreemem(encrypted);
}
Py_XDECREF(user); Py_XDECREF(user);
Py_XDECREF(password); Py_XDECREF(password);
if (algorithm) {
Py_XDECREF(algorithm); Py_XDECREF(algorithm);
}
return res; return res;
} }

View File

@ -1397,6 +1397,18 @@ class TransactionControlTests(ConnectingTestCase):
cur.execute("SHOW default_transaction_read_only;") cur.execute("SHOW default_transaction_read_only;")
self.assertEqual(cur.fetchone()[0], 'off') self.assertEqual(cur.fetchone()[0], 'off')
def test_idempotence_check(self):
self.conn.autocommit = False
self.conn.readonly = True
self.conn.autocommit = True
self.conn.readonly = True
cur = self.conn.cursor()
cur.execute("SHOW transaction_read_only")
self.assertEqual(cur.fetchone()[0], 'on')
class TestEncryptPassword(ConnectingTestCase):
@skip_before_postgres(10) @skip_before_postgres(10)
def test_encrypt_password_post_9_6(self): def test_encrypt_password_post_9_6(self):
cur = self.conn.cursor() cur = self.conn.cursor()
@ -1441,20 +1453,8 @@ class TransactionControlTests(ConnectingTestCase):
# Encryption algorithm will be ignored for postgres version < 10, it # Encryption algorithm will be ignored for postgres version < 10, it
# will always use MD5. # will always use MD5.
self.assertEqual( self.assertRaises(psycopg2.ProgrammingError,
ext.encrypt_password('psycopg2', 'ashesh', self.conn, 'abc'), ext.encrypt_password, 'psycopg2', 'ashesh', self.conn, 'abc')
'md594839d658c28a357126f105b9cb14cfc'
)
def test_idempotence_check(self):
self.conn.autocommit = False
self.conn.readonly = True
self.conn.autocommit = True
self.conn.readonly = True
cur = self.conn.cursor()
cur.execute("SHOW transaction_read_only")
self.assertEqual(cur.fetchone()[0], 'on')
class AutocommitTests(ConnectingTestCase): class AutocommitTests(ConnectingTestCase):