COPY sends unicode to a file if it derives from io.TextIOBase

Fixes ticket #36.
This commit is contained in:
Daniele Varrazzo 2011-02-05 15:12:37 +01:00
parent d40b394c50
commit b544354db2
6 changed files with 129 additions and 4 deletions

View File

@ -9,6 +9,7 @@ What's new in psycopg 2.3.3
- Added 'b' and 't' mode to large objects: write can deal with both bytes
strings and unicode; read can return either bytes strings or decoded
unicode.
- COPY sends Unicode data to files implementing io.TextIOBase.
- The build script refuses to guess values if pg_config is not found.
- Improved PostgreSQL-Python encodings mapping. Added a few
missing encodings: EUC_CN, EUC_JIS_2004, ISO885910, ISO885916,

View File

@ -1163,7 +1163,9 @@ static int
_pq_copy_out_v3(cursorObject *curs)
{
PyObject *tmp = NULL, *func;
PyObject *obj = NULL;
int ret = -1;
int is_text;
char *buffer;
Py_ssize_t len;
@ -1173,14 +1175,28 @@ _pq_copy_out_v3(cursorObject *curs)
goto exit;
}
/* if the file is text we must pass it unicode. */
if (-1 == (is_text = psycopg_is_text_file(curs->copyfile))) {
goto exit;
}
while (1) {
Py_BEGIN_ALLOW_THREADS;
len = PQgetCopyData(curs->conn->pgconn, &buffer, 0);
Py_END_ALLOW_THREADS;
if (len > 0 && buffer) {
tmp = PyObject_CallFunction(func, "s#", buffer, len);
if (is_text) {
obj = PyUnicode_Decode(buffer, len, curs->conn->codec, NULL);
} else {
obj = Bytes_FromStringAndSize(buffer, len);
}
PQfreemem(buffer);
if (!obj) { goto exit; }
tmp = PyObject_CallFunctionObjArgs(func, obj, NULL);
Py_DECREF(obj);
if (tmp == NULL) {
goto exit;
} else {

View File

@ -123,6 +123,7 @@ HIDDEN char *psycopg_escape_string(PyObject *conn,
HIDDEN char *psycopg_strdup(const char *from, Py_ssize_t len);
HIDDEN PyObject * psycopg_ensure_bytes(PyObject *obj);
HIDDEN PyObject * psycopg_ensure_text(PyObject *obj);
HIDDEN int psycopg_is_text_file(PyObject *f);
/* Exceptions docstrings */
#define Error_doc \

View File

@ -149,3 +149,43 @@ psycopg_ensure_text(PyObject *obj)
}
#endif
}
/* Tell whether a file object is an instance of io.TextIOBase.
 *
 * Return 1 if it is, 0 if it is not (or if io.TextIOBase is not available),
 * -1 if the instance check itself fails.
 */
int
psycopg_is_text_file(PyObject *f)
{
    /* Cached result of the lookup: NULL until the first call, afterwards
     * io.TextIOBase if it was found, Py_None otherwise. */
    static PyObject *base;

    /* First call only: try to import io.TextIOBase */
    if (!base) {
        PyObject *iomod;

        Dprintf("psycopg_is_text_file: importing io.TextIOBase");
        iomod = PyImport_ImportModule("io");
        if (!iomod) {
            Dprintf("psycopg_is_text_file: io module not found");
        }
        else {
            base = PyObject_GetAttrString(iomod, "TextIOBase");
            if (!base) {
                Dprintf("psycopg_is_text_file: io.TextIOBase not found");
            }
        }

        /* Either lookup failed: forget the error and remember, using
         * None as a marker, that there is no text base class to check. */
        if (!base) {
            PyErr_Clear();
            Py_INCREF(Py_None);
            base = Py_None;
        }

        Py_XDECREF(iomod);
    }

    /* Without io.TextIOBase no file can be a text file. */
    if (base == Py_None) {
        return 0;
    }

    return PyObject_IsInstance(f, base);
}

View File

@ -23,8 +23,9 @@
# License for more details.
import os
import sys
import string
from testutils import unittest, decorate_all_tests
from testutils import unittest, decorate_all_tests, skip_if_no_iobase
from cStringIO import StringIO
from itertools import cycle, izip
@ -42,7 +43,12 @@ def skip_if_green(f):
return skip_if_green_
class MinimalRead(object):
if sys.version_info[0] < 3:
_base = object
else:
from io import TextIOBase as _base
class MinimalRead(_base):
"""A file wrapper exposing the minimal interface to copy from."""
def __init__(self, f):
self.f = f
@ -53,7 +59,7 @@ class MinimalRead(object):
def readline(self):
return self.f.readline()
class MinimalWrite(object):
class MinimalWrite(_base):
"""A file wrapper exposing the minimal interface to copy to."""
def __init__(self, f):
self.f = f
@ -66,6 +72,9 @@ class CopyTests(unittest.TestCase):
# Test fixture: open a connection to the test database, then create the
# temporary table through the _create_temp_table() helper.
def setUp(self):
self.conn = psycopg2.connect(dsn)
self._create_temp_table()
def _create_temp_table(self):
curs = self.conn.cursor()
curs.execute('''
CREATE TEMPORARY TABLE tcopy (
@ -126,6 +135,51 @@ class CopyTests(unittest.TestCase):
finally:
curs.close()
@skip_if_no_iobase
def test_copy_text(self):
    """copy_to() into an io.StringIO must write unicode text."""
    self.conn.set_client_encoding('latin1')
    self._create_temp_table()  # the above call closed the transaction

    # Test data: the latin1 code points 32-126 and 160-255.
    # 'about' is the expected COPY output, where backslashes are doubled.
    if sys.version_info[0] < 3:
        abin = ''.join(map(chr, range(32, 127) + range(160, 256)))
        about = abin.decode('latin1').replace('\\', '\\\\')
    else:
        # On Python 3 range objects cannot be concatenated with '+':
        # convert them to lists before joining the two intervals.
        abin = bytes(
            list(range(32, 127)) + list(range(160, 256))).decode('latin1')
        about = abin.replace('\\', '\\\\')

    curs = self.conn.cursor()
    curs.execute('insert into tcopy values (%s, %s)',
        (42, abin))

    import io
    f = io.StringIO()
    curs.copy_to(f, 'tcopy', columns=('data',))
    f.seek(0)
    self.assertEqual(f.readline().rstrip(), about)
@skip_if_no_iobase
def test_copy_bytes(self):
    """copy_to() into an io.BytesIO must write bytes."""
    self.conn.set_client_encoding('latin1')
    self._create_temp_table()  # the above call closed the transaction

    # Test data: the latin1 code points 32-126 and 160-254.
    # 'about' is the expected COPY output, where backslashes are doubled.
    if sys.version_info[0] < 3:
        abin = ''.join(map(chr, range(32, 127) + range(160, 255)))
        about = abin.replace('\\', '\\\\')
    else:
        # On Python 3 range objects cannot be concatenated with '+':
        # convert them to lists before joining the two intervals.
        abin = bytes(
            list(range(32, 127)) + list(range(160, 255))).decode('latin1')
        about = abin.replace('\\', '\\\\').encode('latin1')

    curs = self.conn.cursor()
    curs.execute('insert into tcopy values (%s, %s)',
        (42, abin))

    import io
    f = io.BytesIO()
    curs.copy_to(f, 'tcopy', columns=('data',))
    f.seek(0)
    self.assertEqual(f.readline().rstrip(), about)
def _copy_from(self, curs, nrecs, srec, copykw):
f = StringIO()
for i, c in izip(xrange(nrecs), cycle(string.ascii_letters)):

View File

@ -151,6 +151,19 @@ def skip_if_tpc_disabled(f):
return skip_if_tpc_disabled_
def skip_if_no_iobase(f):
    """Skip a test if io.TextIOBase is not available.

    Decorator for test methods: if ``io.TextIOBase`` cannot be imported the
    test is skipped through ``self.skipTest()``, otherwise the decorated
    test is called and its return value is passed through.
    """
    def skip_if_no_iobase_(self):
        try:
            from io import TextIOBase
        except ImportError:
            return self.skipTest("io.TextIOBase not found.")
        else:
            return f(self)

    # Keep the decorated test's name and docstring on the wrapper, so that
    # introspection and docstring-based test descriptions still refer to the
    # real test rather than to this helper.
    skip_if_no_iobase_.__name__ = f.__name__
    skip_if_no_iobase_.__doc__ = f.__doc__
    return skip_if_no_iobase_
def skip_on_python2(f):
"""Skip a test on Python 3 and following."""
def skip_on_python2_(self):