Fixed NamedTupleCursor rownumber during iteration.

The correction is similar to the other one for the other subclasses.

Also added tests for rowcount and rownumber during different fetch styles.
Just in case.
This commit is contained in:
Daniele Varrazzo 2012-02-23 22:58:58 +00:00
parent ebec522a07
commit b8597dc1d3
2 changed files with 40 additions and 8 deletions

View File

@ -310,14 +310,17 @@ class NamedTupleCursor(_cursor):
return [nt(*t) for t in ts]
def __iter__(self):
# Invoking _cursor.__iter__(self) goes to infinite recursion,
# so we do pagination by hand
it = _cursor.__iter__(self)
t = it.next()
nt = self.Record
if nt is None:
nt = self.Record = self._make_nt()
yield nt(*t)
while 1:
recs = self.fetchmany(self.itersize)
if not recs:
return
for rec in recs:
yield rec
yield nt(*it.next())
try:
from collections import namedtuple

View File

@ -222,12 +222,14 @@ class NamedTupleCursorTest(unittest.TestCase):
@skip_if_no_namedtuple
def test_fetchone(self):
curs = self.conn.cursor()
curs.execute("select * from nttest where i = 1")
curs.execute("select * from nttest order by 1")
t = curs.fetchone()
self.assertEqual(t[0], 1)
self.assertEqual(t.i, 1)
self.assertEqual(t[1], 'foo')
self.assertEqual(t.s, 'foo')
self.assertEqual(curs.rownumber, 1)
self.assertEqual(curs.rowcount, 3)
@skip_if_no_namedtuple
def test_fetchmany_noarg(self):
@ -240,6 +242,8 @@ class NamedTupleCursorTest(unittest.TestCase):
self.assertEqual(res[0].s, 'foo')
self.assertEqual(res[1].i, 2)
self.assertEqual(res[1].s, 'bar')
self.assertEqual(curs.rownumber, 2)
self.assertEqual(curs.rowcount, 3)
@skip_if_no_namedtuple
def test_fetchmany(self):
@ -251,6 +255,8 @@ class NamedTupleCursorTest(unittest.TestCase):
self.assertEqual(res[0].s, 'foo')
self.assertEqual(res[1].i, 2)
self.assertEqual(res[1].s, 'bar')
self.assertEqual(curs.rownumber, 2)
self.assertEqual(curs.rowcount, 3)
@skip_if_no_namedtuple
def test_fetchall(self):
@ -264,6 +270,8 @@ class NamedTupleCursorTest(unittest.TestCase):
self.assertEqual(res[1].s, 'bar')
self.assertEqual(res[2].i, 3)
self.assertEqual(res[2].s, 'baz')
self.assertEqual(curs.rownumber, 3)
self.assertEqual(curs.rowcount, 3)
@skip_if_no_namedtuple
def test_executemany(self):
@ -281,16 +289,26 @@ class NamedTupleCursorTest(unittest.TestCase):
curs = self.conn.cursor()
curs.execute("select * from nttest order by 1")
i = iter(curs)
self.assertEqual(curs.rownumber, 0)
t = i.next()
self.assertEqual(t.i, 1)
self.assertEqual(t.s, 'foo')
self.assertEqual(curs.rownumber, 1)
self.assertEqual(curs.rowcount, 3)
t = i.next()
self.assertEqual(t.i, 2)
self.assertEqual(t.s, 'bar')
self.assertEqual(curs.rownumber, 2)
self.assertEqual(curs.rowcount, 3)
t = i.next()
self.assertEqual(t.i, 3)
self.assertEqual(t.s, 'baz')
self.assertRaises(StopIteration, i.next)
self.assertEqual(curs.rownumber, 3)
self.assertEqual(curs.rowcount, 3)
def test_error_message(self):
try:
@ -415,6 +433,17 @@ class NamedTupleCursorTest(unittest.TestCase):
self.assert_(recs[1].ts - recs[0].ts < timedelta(seconds=0.005))
self.assert_(recs[2].ts - recs[1].ts > timedelta(seconds=0.0099))
@skip_if_no_namedtuple
@skip_before_postgres(8, 0)
def test_named_rownumber(self):
curs = self.conn.cursor('tmp')
# Only checking for dataset < itersize:
# see CursorTests.test_iter_named_cursor_rownumber
curs.itersize = 4
curs.execute("""select * from generate_series(1,3)""")
for i, t in enumerate(curs):
self.assertEqual(i + 1, curs.rownumber)
def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__)