diff --git a/tests/test_copy.py b/tests/test_copy.py index b6da4b1b..9026abc5 100755 --- a/tests/test_copy.py +++ b/tests/test_copy.py @@ -78,7 +78,7 @@ class CopyTests(unittest.TestCase): curs = self.conn.cursor() curs.execute(''' CREATE TEMPORARY TABLE tcopy ( - id int PRIMARY KEY, + id serial PRIMARY KEY, data text )''') @@ -180,6 +180,39 @@ class CopyTests(unittest.TestCase): f.seek(0) self.assertEqual(f.readline().rstrip(), about) + @skip_if_no_iobase + def test_copy_expert_textiobase(self): + self.conn.set_client_encoding('latin1') + self._create_temp_table() # the above call closed the xn + + if sys.version_info[0] < 3: + abin = ''.join(map(chr, range(32, 127) + range(160, 256))) + abin = abin.decode('latin1') + about = abin.replace('\\', '\\\\') + + else: + abin = bytes(range(32, 127) + range(160, 256)).decode('latin1') + about = abin.replace('\\', '\\\\') + + import io + f = io.StringIO() + f.write(about) + f.seek(0) + + curs = self.conn.cursor() + psycopg2.extensions.register_type( + psycopg2.extensions.UNICODE, curs) + + curs.copy_expert('COPY tcopy (data) FROM STDIN', f) + curs.execute("select data from tcopy;") + self.assertEqual(curs.fetchone()[0], abin) + + f = io.StringIO() + curs.copy_expert('COPY tcopy (data) TO STDOUT', f) + f.seek(0) + self.assertEqual(f.readline().rstrip(), about) + + def _copy_from(self, curs, nrecs, srec, copykw): f = StringIO() for i, c in izip(xrange(nrecs), cycle(string.ascii_letters)):