From 600416aafc58ac6aaaa4aeb3629518c7c1f788e0 Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Sun, 1 Jan 2017 05:59:21 +0100 Subject: [PATCH] Fixed sql stuff in Py3 --- lib/sql.py | 24 +++++++++++++++--------- psycopg/cursor_type.c | 29 +++++++++++++++++++++++------ tests/test_sql.py | 2 -- 3 files changed, 38 insertions(+), 17 deletions(-) diff --git a/lib/sql.py b/lib/sql.py index 0bd1e31b..21ec9cc7 100644 --- a/lib/sql.py +++ b/lib/sql.py @@ -24,7 +24,9 @@ # License for more details. import re +import sys import collections + from psycopg2 import extensions as ext @@ -146,19 +148,23 @@ class Literal(Composible): return "sql.Literal(%r)" % (self._wrapped,) def as_string(self, conn_or_curs): + # is it a connection or cursor? + if isinstance(conn_or_curs, ext.connection): + conn = conn_or_curs + elif isinstance(conn_or_curs, ext.cursor): + conn = conn_or_curs.connection + else: + raise TypeError("conn_or_curs must be a connection or a cursor") + a = ext.adapt(self._wrapped) if hasattr(a, 'prepare'): - # is it a connection or cursor? - if isinstance(conn_or_curs, ext.connection): - conn = conn_or_curs - elif isinstance(conn_or_curs, ext.cursor): - conn = conn_or_curs.connection - else: - raise TypeError("conn_or_curs must be a connection or a cursor") - a.prepare(conn) - return a.getquoted() + rv = a.getquoted() + if sys.version_info[0] >= 3 and isinstance(rv, bytes): + rv = rv.decode(ext.encodings[conn.encoding]) + + return rv def __mul__(self, n): return Composed([self] * n) diff --git a/psycopg/cursor_type.c b/psycopg/cursor_type.c index 485c69ec..adb3d0f3 100644 --- a/psycopg/cursor_type.c +++ b/psycopg/cursor_type.c @@ -293,7 +293,8 @@ static PyObject *_psyco_curs_validate_sql_basic( ) { PyObject *rv = NULL; - int comp; + PyObject *comp = NULL; + int iscomp; /* Performs very basic validation on an incoming SQL string. Returns a new reference to a str instance on success; NULL on failure, @@ -313,20 +314,36 @@ static PyObject *_psyco_curs_validate_sql_basic( else if (PyUnicode_Check(sql)) { if (!(rv = conn_encode(self->conn, sql))) { goto exit; } } - else if (0 != (comp = _curs_is_composible(sql))) { - if (comp < 0) { goto exit; } - if (!(rv = PyObject_CallMethod(sql, "as_string", "O", self->conn))) { + else if (0 != (iscomp = _curs_is_composible(sql))) { + if (iscomp < 0) { goto exit; } + if (!(comp = PyObject_CallMethod(sql, "as_string", "O", self->conn))) { + goto exit; + } + + if (Bytes_Check(comp)) { + rv = comp; + comp = NULL; + } + else if (PyUnicode_Check(comp)) { + if (!(rv = conn_encode(self->conn, comp))) { goto exit; } + } + else { + PyErr_Format(PyExc_TypeError, + "as_string() should return a string: got %s instead", + Py_TYPE(comp)->tp_name); goto exit; } } else { /* the is not unicode or string, raise an error */ - PyErr_SetString(PyExc_TypeError, - "argument 1 must be a string or unicode object"); + PyErr_Format(PyExc_TypeError, + "argument 1 must be a string or unicode object: got %s instead", + Py_TYPE(sql)->tp_name); goto exit; } exit: + Py_XDECREF(comp); return rv; } diff --git a/tests/test_sql.py b/tests/test_sql.py index 510b545f..4d930cd4 100755 --- a/tests/test_sql.py +++ b/tests/test_sql.py @@ -107,7 +107,6 @@ class IdentifierTests(ConnectingTestCase): def test_init(self): self.assert_(isinstance(sql.Identifier('foo'), sql.Identifier)) self.assert_(isinstance(sql.Identifier(u'foo'), sql.Identifier)) - self.assert_(isinstance(sql.Identifier(b'foo'), sql.Identifier)) self.assertRaises(TypeError, sql.Identifier, 10) self.assertRaises(TypeError, sql.Identifier, dt.date(2016, 12, 31)) @@ -155,7 +154,6 @@ class SQLTests(ConnectingTestCase): def test_init(self): self.assert_(isinstance(sql.SQL('foo'), sql.SQL)) self.assert_(isinstance(sql.SQL(u'foo'), sql.SQL)) - self.assert_(isinstance(sql.SQL(b'foo'), sql.SQL)) self.assertRaises(TypeError, sql.SQL, 10) self.assertRaises(TypeError, sql.SQL, dt.date(2016, 12, 31))