diff --git a/psycopg/adapter_qstring.c b/psycopg/adapter_qstring.c index 2e3ab0ae..1e256cf0 100644 --- a/psycopg/adapter_qstring.c +++ b/psycopg/adapter_qstring.c @@ -36,28 +36,43 @@ static const char *default_encoding = "latin1"; /* qstring_quote - do the quote process on plain and unicode strings */ +const char * +_qstring_get_encoding(qstringObject *self) +{ + /* if the wrapped object is an unicode object we can encode it to match + conn->encoding but if the encoding is not specified we don't know what + to do and we raise an exception */ + if (self->conn) { + return self->conn->codec; + } + else { + return self->encoding ? self->encoding : default_encoding; + } +} + static PyObject * qstring_quote(qstringObject *self) { PyObject *str = NULL; char *s, *buffer = NULL; Py_ssize_t len, qlen; - const char *encoding = default_encoding; + const char *encoding; PyObject *rv = NULL; - /* if the wrapped object is an unicode object we can encode it to match - conn->encoding but if the encoding is not specified we don't know what - to do and we raise an exception */ - if (self->conn) { - encoding = self->conn->codec; - } - + encoding = _qstring_get_encoding(self); Dprintf("qstring_quote: encoding to %s", encoding); - if (PyUnicode_Check(self->wrapped) && encoding) { - str = PyUnicode_AsEncodedString(self->wrapped, encoding, NULL); - Dprintf("qstring_quote: got encoded object at %p", str); - if (str == NULL) goto exit; + if (PyUnicode_Check(self->wrapped)) { + if (encoding) { + str = PyUnicode_AsEncodedString(self->wrapped, encoding, NULL); + Dprintf("qstring_quote: got encoded object at %p", str); + if (str == NULL) goto exit; + } + else { + PyErr_SetString(PyExc_TypeError, + "missing encoding to encode unicode object"); + goto exit; + } } #if PY_MAJOR_VERSION < 3 @@ -150,15 +165,34 @@ qstring_conform(qstringObject *self, PyObject *args) static PyObject * qstring_get_encoding(qstringObject *self) { - const char *encoding = default_encoding; - - if (self->conn) { - encoding = self->conn->codec; - } - + const char *encoding; + encoding = _qstring_get_encoding(self); return Text_FromUTF8(encoding); } +static int +qstring_set_encoding(qstringObject *self, PyObject *pyenc) +{ + int rv = -1; + const char *tmp; + char *cenc; + + /* get a C copy of the encoding (which may come from unicode) */ + Py_INCREF(pyenc); + if (!(pyenc = psycopg_ensure_bytes(pyenc))) { goto exit; } + if (!(tmp = Bytes_AsString(pyenc))) { goto exit; } + if (0 > psycopg_strdup(&cenc, tmp, 0)) { goto exit; } + + Dprintf("qstring_set_encoding: encoding set to %s", cenc); + PyMem_Free((void *)self->encoding); + self->encoding = cenc; + rv = 0; + +exit: + Py_XDECREF(pyenc); + return rv; +} + /** the QuotedString object **/ /* object member list */ @@ -183,7 +217,7 @@ static PyMethodDef qstringObject_methods[] = { static PyGetSetDef qstringObject_getsets[] = { { "encoding", (getter)qstring_get_encoding, - (setter)NULL, + (setter)qstring_set_encoding, "current encoding of the adapter" }, {NULL} }; @@ -216,6 +250,7 @@ qstring_dealloc(PyObject* obj) Py_CLEAR(self->wrapped); Py_CLEAR(self->buffer); Py_CLEAR(self->conn); + PyMem_Free((void *)self->encoding); Dprintf("qstring_dealloc: deleted qstring object at %p, refcnt = " FORMAT_CODE_PY_SSIZE_T, diff --git a/psycopg/adapter_qstring.h b/psycopg/adapter_qstring.h index b7b086f3..8abdc5f2 100644 --- a/psycopg/adapter_qstring.h +++ b/psycopg/adapter_qstring.h @@ -39,6 +39,9 @@ typedef struct { PyObject *buffer; connectionObject *conn; + + const char *encoding; + } qstringObject; #ifdef __cplusplus diff --git a/tests/test_types_basic.py b/tests/test_types_basic.py index 4923d820..baa80c01 100755 --- a/tests/test_types_basic.py +++ b/tests/test_types_basic.py @@ -95,11 +95,11 @@ class TypesBasicTests(ConnectingTestCase): except ValueError: return self.skipTest("inf not available on this platform") s = self.execute("SELECT %s AS foo", (float("inf"),)) - self.failUnless(str(s) == "inf", "wrong float quoting: " + str(s)) + self.failUnless(str(s) == "inf", "wrong float quoting: " + str(s)) self.failUnless(type(s) == float, "wrong float conversion: " + repr(s)) s = self.execute("SELECT %s AS foo", (float("-inf"),)) - self.failUnless(str(s) == "-inf", "wrong float quoting: " + str(s)) + self.failUnless(str(s) == "-inf", "wrong float quoting: " + str(s)) def testBinary(self): if sys.version_info[0] < 3: @@ -344,6 +344,43 @@ class TypesBasicTests(ConnectingTestCase): self.assertEqual(a, [2,4,'nada']) +class TestStringAdapter(ConnectingTestCase): + def test_encoding_default(self): + from psycopg2.extensions import adapt + a = adapt("hello") + self.assertEqual(a.encoding, 'latin1') + self.assertEqual(a.getquoted(), "'hello'") + + egrave = u'\xe8' + self.assertEqual(adapt(egrave).getquoted(), "'\xe8'") + + def test_encoding_error(self): + from psycopg2.extensions import adapt + snowman = u"\u2603" + a = adapt(snowman) + self.assertRaises(UnicodeEncodeError, a.getquoted) + + def test_set_encoding(self): + from psycopg2.extensions import adapt + snowman = u"\u2603" + a = adapt(snowman) + a.encoding = 'utf8' + self.assertEqual(a.encoding, 'utf8') + self.assertEqual(a.getquoted(), "'\xe2\x98\x83'") + + def test_connection_wins_anyway(self): + from psycopg2.extensions import adapt + snowman = u"\u2603" + a = adapt(snowman) + a.encoding = 'latin9' + + self.conn.set_client_encoding('utf8') + a.prepare(self.conn) + + self.assertEqual(a.encoding, 'utf_8') + self.assertEqual(a.getquoted(), "'\xe2\x98\x83'") + + class AdaptSubclassTest(unittest.TestCase): def test_adapt_subtype(self): from psycopg2.extensions import adapt @@ -364,8 +401,8 @@ class AdaptSubclassTest(unittest.TestCase): try: self.assertEqual(b('b'), adapt(C()).getquoted()) finally: - del psycopg2.extensions.adapters[A, psycopg2.extensions.ISQLQuote] - del psycopg2.extensions.adapters[B, psycopg2.extensions.ISQLQuote] + del psycopg2.extensions.adapters[A, psycopg2.extensions.ISQLQuote] + del psycopg2.extensions.adapters[B, psycopg2.extensions.ISQLQuote] @testutils.skip_from_python(3) def test_no_mro_no_joy(self): @@ -378,8 +415,7 @@ class AdaptSubclassTest(unittest.TestCase): try: self.assertRaises(psycopg2.ProgrammingError, adapt, B()) finally: - del psycopg2.extensions.adapters[A, psycopg2.extensions.ISQLQuote] - + del psycopg2.extensions.adapters[A, psycopg2.extensions.ISQLQuote] @testutils.skip_before_python(3) def test_adapt_subtype_3(self): @@ -392,7 +428,7 @@ class AdaptSubclassTest(unittest.TestCase): try: self.assertEqual(b("a"), adapt(B()).getquoted()) finally: - del psycopg2.extensions.adapters[A, psycopg2.extensions.ISQLQuote] + del psycopg2.extensions.adapters[A, psycopg2.extensions.ISQLQuote] class ByteaParserTest(unittest.TestCase): @@ -480,6 +516,7 @@ class ByteaParserTest(unittest.TestCase): self.assertEqual(rv, tgt) + def skip_if_cant_cast(f): @wraps(f) def skip_if_cant_cast_(self, *args, **kwargs): @@ -499,4 +536,3 @@ def test_suite(): if __name__ == "__main__": unittest.main() -