mirror of
https://github.com/psycopg/psycopg2.git
synced 2025-02-25 21:20:32 +03:00
COPY sends unicode to a file if it derives from io.TextIoBase
Fixes ticket #36.
This commit is contained in:
parent
d40b394c50
commit
b544354db2
1
NEWS-2.3
1
NEWS-2.3
|
@ -9,6 +9,7 @@ What's new in psycopg 2.3.3
|
|||
- Added 'b' and 't' mode to large objects: write can deal with both bytes
|
||||
strings and unicode; read can return either bytes strings or decoded
|
||||
unicode.
|
||||
- COPY sends Unicode data to files implementing io.TextIOBase.
|
||||
- The build script refuses to guess values if pg_config is not found.
|
||||
- Improved PostgreSQL-Python encodings mapping. Added a few
|
||||
missing encodings: EUC_CN, EUC_JIS_2004, ISO885910, ISO885916,
|
||||
|
|
|
@ -1163,7 +1163,9 @@ static int
|
|||
_pq_copy_out_v3(cursorObject *curs)
|
||||
{
|
||||
PyObject *tmp = NULL, *func;
|
||||
PyObject *obj = NULL;
|
||||
int ret = -1;
|
||||
int is_text;
|
||||
|
||||
char *buffer;
|
||||
Py_ssize_t len;
|
||||
|
@ -1173,14 +1175,28 @@ _pq_copy_out_v3(cursorObject *curs)
|
|||
goto exit;
|
||||
}
|
||||
|
||||
/* if the file is text we must pass it unicode. */
|
||||
if (-1 == (is_text = psycopg_is_text_file(curs->copyfile))) {
|
||||
goto exit;
|
||||
}
|
||||
|
||||
while (1) {
|
||||
Py_BEGIN_ALLOW_THREADS;
|
||||
len = PQgetCopyData(curs->conn->pgconn, &buffer, 0);
|
||||
Py_END_ALLOW_THREADS;
|
||||
|
||||
if (len > 0 && buffer) {
|
||||
tmp = PyObject_CallFunction(func, "s#", buffer, len);
|
||||
if (is_text) {
|
||||
obj = PyUnicode_Decode(buffer, len, curs->conn->codec, NULL);
|
||||
} else {
|
||||
obj = Bytes_FromStringAndSize(buffer, len);
|
||||
}
|
||||
|
||||
PQfreemem(buffer);
|
||||
if (!obj) { goto exit; }
|
||||
tmp = PyObject_CallFunctionObjArgs(func, obj, NULL);
|
||||
Py_DECREF(obj);
|
||||
|
||||
if (tmp == NULL) {
|
||||
goto exit;
|
||||
} else {
|
||||
|
|
|
@ -123,6 +123,7 @@ HIDDEN char *psycopg_escape_string(PyObject *conn,
|
|||
HIDDEN char *psycopg_strdup(const char *from, Py_ssize_t len);
|
||||
HIDDEN PyObject * psycopg_ensure_bytes(PyObject *obj);
|
||||
HIDDEN PyObject * psycopg_ensure_text(PyObject *obj);
|
||||
HIDDEN int psycopg_is_text_file(PyObject *f);
|
||||
|
||||
/* Exceptions docstrings */
|
||||
#define Error_doc \
|
||||
|
|
|
@ -149,3 +149,43 @@ psycopg_ensure_text(PyObject *obj)
|
|||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
/* Check if a file derives from TextIOBase.
|
||||
*
|
||||
* Return 1 if it does, else 0, -1 on errors.
|
||||
*/
|
||||
int
|
||||
psycopg_is_text_file(PyObject *f)
|
||||
{
|
||||
/* NULL before any call.
|
||||
* then io.TextIOBase if exists, else None. */
|
||||
static PyObject *base;
|
||||
|
||||
/* Try to import os.TextIOBase */
|
||||
if (NULL == base) {
|
||||
PyObject *m;
|
||||
Dprintf("psycopg_is_text_file: importing io.TextIOBase");
|
||||
if (!(m = PyImport_ImportModule("io"))) {
|
||||
Dprintf("psycopg_is_text_file: io module not found");
|
||||
PyErr_Clear();
|
||||
Py_INCREF(Py_None);
|
||||
base = Py_None;
|
||||
}
|
||||
else {
|
||||
if (!(base = PyObject_GetAttrString(m, "TextIOBase"))) {
|
||||
Dprintf("psycopg_is_text_file: io.TextIOBase not found");
|
||||
PyErr_Clear();
|
||||
Py_INCREF(Py_None);
|
||||
base = Py_None;
|
||||
}
|
||||
}
|
||||
Py_XDECREF(m);
|
||||
}
|
||||
|
||||
if (base != Py_None) {
|
||||
return PyObject_IsInstance(f, base);
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -23,8 +23,9 @@
|
|||
# License for more details.
|
||||
|
||||
import os
|
||||
import sys
|
||||
import string
|
||||
from testutils import unittest, decorate_all_tests
|
||||
from testutils import unittest, decorate_all_tests, skip_if_no_iobase
|
||||
from cStringIO import StringIO
|
||||
from itertools import cycle, izip
|
||||
|
||||
|
@ -42,7 +43,12 @@ def skip_if_green(f):
|
|||
return skip_if_green_
|
||||
|
||||
|
||||
class MinimalRead(object):
|
||||
if sys.version_info[0] < 3:
|
||||
_base = object
|
||||
else:
|
||||
from io import TextIOBase as _base
|
||||
|
||||
class MinimalRead(_base):
|
||||
"""A file wrapper exposing the minimal interface to copy from."""
|
||||
def __init__(self, f):
|
||||
self.f = f
|
||||
|
@ -53,7 +59,7 @@ class MinimalRead(object):
|
|||
def readline(self):
|
||||
return self.f.readline()
|
||||
|
||||
class MinimalWrite(object):
|
||||
class MinimalWrite(_base):
|
||||
"""A file wrapper exposing the minimal interface to copy to."""
|
||||
def __init__(self, f):
|
||||
self.f = f
|
||||
|
@ -66,6 +72,9 @@ class CopyTests(unittest.TestCase):
|
|||
|
||||
def setUp(self):
|
||||
self.conn = psycopg2.connect(dsn)
|
||||
self._create_temp_table()
|
||||
|
||||
def _create_temp_table(self):
|
||||
curs = self.conn.cursor()
|
||||
curs.execute('''
|
||||
CREATE TEMPORARY TABLE tcopy (
|
||||
|
@ -126,6 +135,51 @@ class CopyTests(unittest.TestCase):
|
|||
finally:
|
||||
curs.close()
|
||||
|
||||
@skip_if_no_iobase
|
||||
def test_copy_text(self):
|
||||
self.conn.set_client_encoding('latin1')
|
||||
self._create_temp_table() # the above call closed the xn
|
||||
|
||||
if sys.version_info[0] < 3:
|
||||
abin = ''.join(map(chr, range(32, 127) + range(160, 256)))
|
||||
about = abin.decode('latin1').replace('\\', '\\\\')
|
||||
|
||||
else:
|
||||
abin = bytes(range(32, 127) + range(160, 256)).decode('latin1')
|
||||
about = abin.replace('\\', '\\\\')
|
||||
|
||||
curs = self.conn.cursor()
|
||||
curs.execute('insert into tcopy values (%s, %s)',
|
||||
(42, abin))
|
||||
|
||||
import io
|
||||
f = io.StringIO()
|
||||
curs.copy_to(f, 'tcopy', columns=('data',))
|
||||
f.seek(0)
|
||||
self.assertEqual(f.readline().rstrip(), about)
|
||||
|
||||
@skip_if_no_iobase
|
||||
def test_copy_bytes(self):
|
||||
self.conn.set_client_encoding('latin1')
|
||||
self._create_temp_table() # the above call closed the xn
|
||||
|
||||
if sys.version_info[0] < 3:
|
||||
abin = ''.join(map(chr, range(32, 127) + range(160, 255)))
|
||||
about = abin.replace('\\', '\\\\')
|
||||
else:
|
||||
abin = bytes(range(32, 127) + range(160, 255)).decode('latin1')
|
||||
about = abin.replace('\\', '\\\\').encode('latin1')
|
||||
|
||||
curs = self.conn.cursor()
|
||||
curs.execute('insert into tcopy values (%s, %s)',
|
||||
(42, abin))
|
||||
|
||||
import io
|
||||
f = io.BytesIO()
|
||||
curs.copy_to(f, 'tcopy', columns=('data',))
|
||||
f.seek(0)
|
||||
self.assertEqual(f.readline().rstrip(), about)
|
||||
|
||||
def _copy_from(self, curs, nrecs, srec, copykw):
|
||||
f = StringIO()
|
||||
for i, c in izip(xrange(nrecs), cycle(string.ascii_letters)):
|
||||
|
|
|
@ -151,6 +151,19 @@ def skip_if_tpc_disabled(f):
|
|||
return skip_if_tpc_disabled_
|
||||
|
||||
|
||||
def skip_if_no_iobase(f):
|
||||
"""Skip a test if io.TextIOBase is not available."""
|
||||
def skip_if_no_iobase_(self):
|
||||
try:
|
||||
from io import TextIOBase
|
||||
except ImportError:
|
||||
return self.skipTest("io.TextIOBase not found.")
|
||||
else:
|
||||
return f(self)
|
||||
|
||||
return skip_if_no_iobase_
|
||||
|
||||
|
||||
def skip_on_python2(f):
|
||||
"""Skip a test on Python 3 and following."""
|
||||
def skip_on_python2_(self):
|
||||
|
|
Loading…
Reference in New Issue
Block a user