Fixed sql stuff in Py3

This commit is contained in:
Daniele Varrazzo 2017-01-01 05:59:21 +01:00
parent c4a67fc1c1
commit 600416aafc
3 changed files with 38 additions and 17 deletions

View File

@ -24,7 +24,9 @@
# License for more details. # License for more details.
import re import re
import sys
import collections import collections
from psycopg2 import extensions as ext from psycopg2 import extensions as ext
@ -146,19 +148,23 @@ class Literal(Composible):
return "sql.Literal(%r)" % (self._wrapped,) return "sql.Literal(%r)" % (self._wrapped,)
def as_string(self, conn_or_curs): def as_string(self, conn_or_curs):
# is it a connection or cursor?
if isinstance(conn_or_curs, ext.connection):
conn = conn_or_curs
elif isinstance(conn_or_curs, ext.cursor):
conn = conn_or_curs.connection
else:
raise TypeError("conn_or_curs must be a connection or a cursor")
a = ext.adapt(self._wrapped) a = ext.adapt(self._wrapped)
if hasattr(a, 'prepare'): if hasattr(a, 'prepare'):
# is it a connection or cursor?
if isinstance(conn_or_curs, ext.connection):
conn = conn_or_curs
elif isinstance(conn_or_curs, ext.cursor):
conn = conn_or_curs.connection
else:
raise TypeError("conn_or_curs must be a connection or a cursor")
a.prepare(conn) a.prepare(conn)
return a.getquoted() rv = a.getquoted()
if sys.version_info[0] >= 3 and isinstance(rv, bytes):
rv = rv.decode(ext.encodings[conn.encoding])
return rv
def __mul__(self, n): def __mul__(self, n):
return Composed([self] * n) return Composed([self] * n)

View File

@ -293,7 +293,8 @@ static PyObject *_psyco_curs_validate_sql_basic(
) )
{ {
PyObject *rv = NULL; PyObject *rv = NULL;
int comp; PyObject *comp = NULL;
int iscomp;
/* Performs very basic validation on an incoming SQL string. /* Performs very basic validation on an incoming SQL string.
Returns a new reference to a str instance on success; NULL on failure, Returns a new reference to a str instance on success; NULL on failure,
@ -313,20 +314,36 @@ static PyObject *_psyco_curs_validate_sql_basic(
else if (PyUnicode_Check(sql)) { else if (PyUnicode_Check(sql)) {
if (!(rv = conn_encode(self->conn, sql))) { goto exit; } if (!(rv = conn_encode(self->conn, sql))) { goto exit; }
} }
else if (0 != (comp = _curs_is_composible(sql))) { else if (0 != (iscomp = _curs_is_composible(sql))) {
if (comp < 0) { goto exit; } if (iscomp < 0) { goto exit; }
if (!(rv = PyObject_CallMethod(sql, "as_string", "O", self->conn))) { if (!(comp = PyObject_CallMethod(sql, "as_string", "O", self->conn))) {
goto exit;
}
if (Bytes_Check(comp)) {
rv = comp;
comp = NULL;
}
else if (PyUnicode_Check(comp)) {
if (!(rv = conn_encode(self->conn, comp))) { goto exit; }
}
else {
PyErr_Format(PyExc_TypeError,
"as_string() should return a string: got %s instead",
Py_TYPE(comp)->tp_name);
goto exit; goto exit;
} }
} }
else { else {
/* the is not unicode or string, raise an error */ /* the is not unicode or string, raise an error */
PyErr_SetString(PyExc_TypeError, PyErr_Format(PyExc_TypeError,
"argument 1 must be a string or unicode object"); "argument 1 must be a string or unicode object: got %s instead",
Py_TYPE(sql)->tp_name);
goto exit; goto exit;
} }
exit: exit:
Py_XDECREF(comp);
return rv; return rv;
} }

View File

@ -107,7 +107,6 @@ class IdentifierTests(ConnectingTestCase):
def test_init(self): def test_init(self):
self.assert_(isinstance(sql.Identifier('foo'), sql.Identifier)) self.assert_(isinstance(sql.Identifier('foo'), sql.Identifier))
self.assert_(isinstance(sql.Identifier(u'foo'), sql.Identifier)) self.assert_(isinstance(sql.Identifier(u'foo'), sql.Identifier))
self.assert_(isinstance(sql.Identifier(b'foo'), sql.Identifier))
self.assertRaises(TypeError, sql.Identifier, 10) self.assertRaises(TypeError, sql.Identifier, 10)
self.assertRaises(TypeError, sql.Identifier, dt.date(2016, 12, 31)) self.assertRaises(TypeError, sql.Identifier, dt.date(2016, 12, 31))
@ -155,7 +154,6 @@ class SQLTests(ConnectingTestCase):
def test_init(self): def test_init(self):
self.assert_(isinstance(sql.SQL('foo'), sql.SQL)) self.assert_(isinstance(sql.SQL('foo'), sql.SQL))
self.assert_(isinstance(sql.SQL(u'foo'), sql.SQL)) self.assert_(isinstance(sql.SQL(u'foo'), sql.SQL))
self.assert_(isinstance(sql.SQL(b'foo'), sql.SQL))
self.assertRaises(TypeError, sql.SQL, 10) self.assertRaises(TypeError, sql.SQL, 10)
self.assertRaises(TypeError, sql.SQL, dt.date(2016, 12, 31)) self.assertRaises(TypeError, sql.SQL, dt.date(2016, 12, 31))