mirror of
https://github.com/psycopg/psycopg2.git
synced 2024-11-25 18:33:44 +03:00
Fixed sql stuff in Py3
This commit is contained in:
parent
c4a67fc1c1
commit
600416aafc
24
lib/sql.py
24
lib/sql.py
|
@ -24,7 +24,9 @@
|
||||||
# License for more details.
|
# License for more details.
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
import sys
|
||||||
import collections
|
import collections
|
||||||
|
|
||||||
from psycopg2 import extensions as ext
|
from psycopg2 import extensions as ext
|
||||||
|
|
||||||
|
|
||||||
|
@ -146,19 +148,23 @@ class Literal(Composible):
|
||||||
return "sql.Literal(%r)" % (self._wrapped,)
|
return "sql.Literal(%r)" % (self._wrapped,)
|
||||||
|
|
||||||
def as_string(self, conn_or_curs):
|
def as_string(self, conn_or_curs):
|
||||||
|
# is it a connection or cursor?
|
||||||
|
if isinstance(conn_or_curs, ext.connection):
|
||||||
|
conn = conn_or_curs
|
||||||
|
elif isinstance(conn_or_curs, ext.cursor):
|
||||||
|
conn = conn_or_curs.connection
|
||||||
|
else:
|
||||||
|
raise TypeError("conn_or_curs must be a connection or a cursor")
|
||||||
|
|
||||||
a = ext.adapt(self._wrapped)
|
a = ext.adapt(self._wrapped)
|
||||||
if hasattr(a, 'prepare'):
|
if hasattr(a, 'prepare'):
|
||||||
# is it a connection or cursor?
|
|
||||||
if isinstance(conn_or_curs, ext.connection):
|
|
||||||
conn = conn_or_curs
|
|
||||||
elif isinstance(conn_or_curs, ext.cursor):
|
|
||||||
conn = conn_or_curs.connection
|
|
||||||
else:
|
|
||||||
raise TypeError("conn_or_curs must be a connection or a cursor")
|
|
||||||
|
|
||||||
a.prepare(conn)
|
a.prepare(conn)
|
||||||
|
|
||||||
return a.getquoted()
|
rv = a.getquoted()
|
||||||
|
if sys.version_info[0] >= 3 and isinstance(rv, bytes):
|
||||||
|
rv = rv.decode(ext.encodings[conn.encoding])
|
||||||
|
|
||||||
|
return rv
|
||||||
|
|
||||||
def __mul__(self, n):
|
def __mul__(self, n):
|
||||||
return Composed([self] * n)
|
return Composed([self] * n)
|
||||||
|
|
|
@ -293,7 +293,8 @@ static PyObject *_psyco_curs_validate_sql_basic(
|
||||||
)
|
)
|
||||||
{
|
{
|
||||||
PyObject *rv = NULL;
|
PyObject *rv = NULL;
|
||||||
int comp;
|
PyObject *comp = NULL;
|
||||||
|
int iscomp;
|
||||||
|
|
||||||
/* Performs very basic validation on an incoming SQL string.
|
/* Performs very basic validation on an incoming SQL string.
|
||||||
Returns a new reference to a str instance on success; NULL on failure,
|
Returns a new reference to a str instance on success; NULL on failure,
|
||||||
|
@ -313,20 +314,36 @@ static PyObject *_psyco_curs_validate_sql_basic(
|
||||||
else if (PyUnicode_Check(sql)) {
|
else if (PyUnicode_Check(sql)) {
|
||||||
if (!(rv = conn_encode(self->conn, sql))) { goto exit; }
|
if (!(rv = conn_encode(self->conn, sql))) { goto exit; }
|
||||||
}
|
}
|
||||||
else if (0 != (comp = _curs_is_composible(sql))) {
|
else if (0 != (iscomp = _curs_is_composible(sql))) {
|
||||||
if (comp < 0) { goto exit; }
|
if (iscomp < 0) { goto exit; }
|
||||||
if (!(rv = PyObject_CallMethod(sql, "as_string", "O", self->conn))) {
|
if (!(comp = PyObject_CallMethod(sql, "as_string", "O", self->conn))) {
|
||||||
|
goto exit;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (Bytes_Check(comp)) {
|
||||||
|
rv = comp;
|
||||||
|
comp = NULL;
|
||||||
|
}
|
||||||
|
else if (PyUnicode_Check(comp)) {
|
||||||
|
if (!(rv = conn_encode(self->conn, comp))) { goto exit; }
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
PyErr_Format(PyExc_TypeError,
|
||||||
|
"as_string() should return a string: got %s instead",
|
||||||
|
Py_TYPE(comp)->tp_name);
|
||||||
goto exit;
|
goto exit;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
/* the is not unicode or string, raise an error */
|
/* the is not unicode or string, raise an error */
|
||||||
PyErr_SetString(PyExc_TypeError,
|
PyErr_Format(PyExc_TypeError,
|
||||||
"argument 1 must be a string or unicode object");
|
"argument 1 must be a string or unicode object: got %s instead",
|
||||||
|
Py_TYPE(sql)->tp_name);
|
||||||
goto exit;
|
goto exit;
|
||||||
}
|
}
|
||||||
|
|
||||||
exit:
|
exit:
|
||||||
|
Py_XDECREF(comp);
|
||||||
return rv;
|
return rv;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -107,7 +107,6 @@ class IdentifierTests(ConnectingTestCase):
|
||||||
def test_init(self):
|
def test_init(self):
|
||||||
self.assert_(isinstance(sql.Identifier('foo'), sql.Identifier))
|
self.assert_(isinstance(sql.Identifier('foo'), sql.Identifier))
|
||||||
self.assert_(isinstance(sql.Identifier(u'foo'), sql.Identifier))
|
self.assert_(isinstance(sql.Identifier(u'foo'), sql.Identifier))
|
||||||
self.assert_(isinstance(sql.Identifier(b'foo'), sql.Identifier))
|
|
||||||
self.assertRaises(TypeError, sql.Identifier, 10)
|
self.assertRaises(TypeError, sql.Identifier, 10)
|
||||||
self.assertRaises(TypeError, sql.Identifier, dt.date(2016, 12, 31))
|
self.assertRaises(TypeError, sql.Identifier, dt.date(2016, 12, 31))
|
||||||
|
|
||||||
|
@ -155,7 +154,6 @@ class SQLTests(ConnectingTestCase):
|
||||||
def test_init(self):
|
def test_init(self):
|
||||||
self.assert_(isinstance(sql.SQL('foo'), sql.SQL))
|
self.assert_(isinstance(sql.SQL('foo'), sql.SQL))
|
||||||
self.assert_(isinstance(sql.SQL(u'foo'), sql.SQL))
|
self.assert_(isinstance(sql.SQL(u'foo'), sql.SQL))
|
||||||
self.assert_(isinstance(sql.SQL(b'foo'), sql.SQL))
|
|
||||||
self.assertRaises(TypeError, sql.SQL, 10)
|
self.assertRaises(TypeError, sql.SQL, 10)
|
||||||
self.assertRaises(TypeError, sql.SQL, dt.date(2016, 12, 31))
|
self.assertRaises(TypeError, sql.SQL, dt.date(2016, 12, 31))
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user