mirror of
https://github.com/psycopg/psycopg2.git
synced 2024-11-30 04:33:45 +03:00
Added some COPY tests.
This commit is contained in:
parent
7e0dcfdda5
commit
b114e25c31
|
@ -27,6 +27,7 @@ import test_transaction
|
||||||
import types_basic
|
import types_basic
|
||||||
import types_extras
|
import types_extras
|
||||||
import test_lobject
|
import test_lobject
|
||||||
|
import test_copy
|
||||||
import test_async
|
import test_async
|
||||||
|
|
||||||
def test_suite():
|
def test_suite():
|
||||||
|
@ -41,6 +42,7 @@ def test_suite():
|
||||||
suite.addTest(types_basic.test_suite())
|
suite.addTest(types_basic.test_suite())
|
||||||
suite.addTest(types_extras.test_suite())
|
suite.addTest(types_extras.test_suite())
|
||||||
suite.addTest(test_lobject.test_suite())
|
suite.addTest(test_lobject.test_suite())
|
||||||
|
suite.addTest(test_copy.test_suite())
|
||||||
suite.addTest(test_async.test_suite())
|
suite.addTest(test_async.test_suite())
|
||||||
return suite
|
return suite
|
||||||
|
|
||||||
|
|
104
tests/test_copy.py
Normal file
104
tests/test_copy.py
Normal file
|
@ -0,0 +1,104 @@
|
||||||
|
#!/usr/bin/env python
|
||||||
|
import os
|
||||||
|
import string
|
||||||
|
import unittest
|
||||||
|
from cStringIO import StringIO
|
||||||
|
from itertools import cycle, izip
|
||||||
|
|
||||||
|
import psycopg2
|
||||||
|
import psycopg2.extensions
|
||||||
|
import tests
|
||||||
|
|
||||||
|
|
||||||
|
class MinimalRead(object):
|
||||||
|
"""A file wrapper exposing the minimal interface to copy from."""
|
||||||
|
def __init__(self, f):
|
||||||
|
self.f = f
|
||||||
|
|
||||||
|
def read(self, size):
|
||||||
|
return self.f.read(size)
|
||||||
|
|
||||||
|
def readline(self):
|
||||||
|
return self.f.readline()
|
||||||
|
|
||||||
|
class MinimalWrite(object):
|
||||||
|
"""A file wrapper exposing the minimal interface to copy to."""
|
||||||
|
def __init__(self, f):
|
||||||
|
self.f = f
|
||||||
|
|
||||||
|
def write(self, data):
|
||||||
|
return self.f.write(data)
|
||||||
|
|
||||||
|
class CopyTests(unittest.TestCase):
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.conn = psycopg2.connect(tests.dsn)
|
||||||
|
curs = self.conn.cursor()
|
||||||
|
curs.execute('''
|
||||||
|
CREATE TEMPORARY TABLE tcopy (
|
||||||
|
id int PRIMARY KEY,
|
||||||
|
data text
|
||||||
|
)''')
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
self.conn.close()
|
||||||
|
|
||||||
|
def test_copy_from(self):
|
||||||
|
curs = self.conn.cursor()
|
||||||
|
try:
|
||||||
|
self._copy_from(curs, nrecs=1024, srec=10*1024, copykw={})
|
||||||
|
finally:
|
||||||
|
curs.close()
|
||||||
|
|
||||||
|
def test_copy_from_insane_size(self):
|
||||||
|
# Trying to trigger a "would block" error
|
||||||
|
curs = self.conn.cursor()
|
||||||
|
try:
|
||||||
|
self._copy_from(curs, nrecs=10*1024, srec=10*1024, copykw={'size': 20*1024*1024})
|
||||||
|
finally:
|
||||||
|
curs.close()
|
||||||
|
|
||||||
|
def test_copy_to(self):
|
||||||
|
curs = self.conn.cursor()
|
||||||
|
try:
|
||||||
|
self._copy_from(curs, nrecs=1024, srec=10*1024, copykw={})
|
||||||
|
self._copy_to(curs, srec=10*1024)
|
||||||
|
finally:
|
||||||
|
curs.close()
|
||||||
|
|
||||||
|
def _copy_from(self, curs, nrecs, srec, copykw):
|
||||||
|
f = StringIO()
|
||||||
|
for i, c in izip(xrange(nrecs), cycle(string.letters)):
|
||||||
|
l = c * srec
|
||||||
|
f.write("%s\t%s\n" % (i,l))
|
||||||
|
|
||||||
|
f.seek(0)
|
||||||
|
curs.copy_from(MinimalRead(f), "tcopy", **copykw)
|
||||||
|
|
||||||
|
curs.execute("select count(*) from tcopy")
|
||||||
|
self.assertEqual(nrecs, curs.fetchone()[0])
|
||||||
|
|
||||||
|
curs.execute("select data from tcopy where id < %s order by id",
|
||||||
|
(len(string.letters),))
|
||||||
|
for i, (l,) in enumerate(curs):
|
||||||
|
self.assertEqual(l, string.letters[i] * srec)
|
||||||
|
|
||||||
|
def _copy_to(self, curs, srec):
|
||||||
|
f = StringIO()
|
||||||
|
curs.copy_to(MinimalWrite(f), "tcopy")
|
||||||
|
|
||||||
|
f.seek(0)
|
||||||
|
ntests = 0
|
||||||
|
for line in f:
|
||||||
|
n, s = line.split()
|
||||||
|
if int(n) < len(string.letters):
|
||||||
|
self.assertEqual(s, string.letters[int(n)] * srec)
|
||||||
|
ntests += 1
|
||||||
|
|
||||||
|
self.assertEqual(ntests, len(string.letters))
|
||||||
|
|
||||||
|
def test_suite():
|
||||||
|
return unittest.TestLoader().loadTestsFromName(__name__)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
Loading…
Reference in New Issue
Block a user