Proper unicode handling in quote_ident.

This commit is contained in:
Oleksandr Shulgin 2015-10-15 11:52:18 +02:00
parent 9295bce154
commit 89bb6b0711
2 changed files with 41 additions and 10 deletions

View File

@ -166,17 +166,25 @@ exit:
} }
#define psyco_quote_ident_doc "quote_ident(str, conn_or_curs) -> str" #define psyco_quote_ident_doc \
"quote_ident(str, conn_or_curs) -> str -- wrapper around PQescapeIdentifier\n\n" \
":Parameters:\n" \
" * `str`: A bytes or unicode object\n" \
" * `conn_or_curs`: A connection or cursor, required"
static PyObject * static PyObject *
psyco_quote_ident(PyObject *self, PyObject *args) psyco_quote_ident(PyObject *self, PyObject *args, PyObject *kwargs)
{ {
const char *str = NULL; #if PG_VERSION_NUM >= 90000
char *quoted; PyObject *ident = NULL, *obj = NULL, *result = NULL;
PyObject *obj, *result;
connectionObject *conn; connectionObject *conn;
const char *str;
char *quoted = NULL;
if (!PyArg_ParseTuple(args, "sO", &str, &obj)) return NULL; static char *kwlist[] = {"ident", "scope", NULL};
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "OO", kwlist, &ident, &obj)) {
return NULL;
}
if (PyObject_TypeCheck(obj, &cursorType)) { if (PyObject_TypeCheck(obj, &cursorType)) {
conn = ((cursorObject*)obj)->conn; conn = ((cursorObject*)obj)->conn;
@ -190,15 +198,27 @@ psyco_quote_ident(PyObject *self, PyObject *args)
return NULL; return NULL;
} }
Py_INCREF(ident); /* for ensure_bytes */
if (!(ident = psycopg_ensure_bytes(ident))) { goto exit; }
str = Bytes_AS_STRING(ident);
quoted = PQescapeIdentifier(conn->pgconn, str, strlen(str)); quoted = PQescapeIdentifier(conn->pgconn, str, strlen(str));
if (!quoted) { if (!quoted) {
PyErr_NoMemory(); PyErr_NoMemory();
return NULL; goto exit;
} }
result = conn_text_from_chars(conn, quoted); result = conn_text_from_chars(conn, quoted);
exit:
PQfreemem(quoted); PQfreemem(quoted);
Py_XDECREF(ident);
return result; return result;
#else
PyErr_SetString(NotSupportedError, "PQescapeIdentifier not available in libpq < 9.0");
return NULL;
#endif
} }
/** type registration **/ /** type registration **/
@ -802,10 +822,10 @@ static PyMethodDef psycopgMethods[] = {
METH_VARARGS|METH_KEYWORDS, psyco_connect_doc}, METH_VARARGS|METH_KEYWORDS, psyco_connect_doc},
{"parse_dsn", (PyCFunction)psyco_parse_dsn, {"parse_dsn", (PyCFunction)psyco_parse_dsn,
METH_VARARGS|METH_KEYWORDS, psyco_parse_dsn_doc}, METH_VARARGS|METH_KEYWORDS, psyco_parse_dsn_doc},
{"quote_ident", (PyCFunction)psyco_quote_ident,
METH_VARARGS|METH_KEYWORDS, psyco_quote_ident_doc},
{"adapt", (PyCFunction)psyco_microprotocols_adapt, {"adapt", (PyCFunction)psyco_microprotocols_adapt,
METH_VARARGS, psyco_microprotocols_adapt_doc}, METH_VARARGS, psyco_microprotocols_adapt_doc},
{"quote_ident", (PyCFunction)psyco_quote_ident,
METH_VARARGS, psyco_quote_ident_doc},
{"register_type", (PyCFunction)psyco_register_type, {"register_type", (PyCFunction)psyco_register_type,
METH_VARARGS, psyco_register_type_doc}, METH_VARARGS, psyco_register_type_doc},

View File

@ -23,7 +23,7 @@
# License for more details. # License for more details.
import sys import sys
from testutils import unittest, ConnectingTestCase from testutils import unittest, ConnectingTestCase, skip_before_libpq
import psycopg2 import psycopg2
import psycopg2.extensions import psycopg2.extensions
@ -166,11 +166,22 @@ class TestQuotedString(ConnectingTestCase):
class TestQuotedIdentifier(ConnectingTestCase): class TestQuotedIdentifier(ConnectingTestCase):
@skip_before_libpq(9, 0)
def test_identifier(self): def test_identifier(self):
from psycopg2.extensions import quote_ident from psycopg2.extensions import quote_ident
self.assertEqual(quote_ident('blah-blah', self.conn), '"blah-blah"') self.assertEqual(quote_ident('blah-blah', self.conn), '"blah-blah"')
self.assertEqual(quote_ident('quote"inside', self.conn), '"quote""inside"') self.assertEqual(quote_ident('quote"inside', self.conn), '"quote""inside"')
@skip_before_libpq(9, 0)
def test_unicode_ident(self):
from psycopg2.extensions import quote_ident
snowman = u"\u2603"
quoted = '"' + snowman + '"'
if sys.version_info[0] < 3:
self.assertEqual(quote_ident(snowman, self.conn), quoted.encode('utf8'))
else:
self.assertEqual(quote_ident(snowman, self.conn), quoted)
def test_suite(): def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__) return unittest.TestLoader().loadTestsFromName(__name__)