diff --git a/lib/extras.py b/lib/extras.py index aa7bc877..870b5ca7 100644 --- a/lib/extras.py +++ b/lib/extras.py @@ -310,14 +310,17 @@ class NamedTupleCursor(_cursor): return [nt(*t) for t in ts] def __iter__(self): - # Invoking _cursor.__iter__(self) goes to infinite recursion, - # so we do pagination by hand + it = _cursor.__iter__(self) + t = it.next() + + nt = self.Record + if nt is None: + nt = self.Record = self._make_nt() + + yield nt(*t) + while 1: - recs = self.fetchmany(self.itersize) - if not recs: - return - for rec in recs: - yield rec + yield nt(*it.next()) try: from collections import namedtuple diff --git a/tests/test_extras_dictcursor.py b/tests/test_extras_dictcursor.py index 3bdb3ba3..dd746379 100755 --- a/tests/test_extras_dictcursor.py +++ b/tests/test_extras_dictcursor.py @@ -222,12 +222,14 @@ class NamedTupleCursorTest(unittest.TestCase): @skip_if_no_namedtuple def test_fetchone(self): curs = self.conn.cursor() - curs.execute("select * from nttest where i = 1") + curs.execute("select * from nttest order by 1") t = curs.fetchone() self.assertEqual(t[0], 1) self.assertEqual(t.i, 1) self.assertEqual(t[1], 'foo') self.assertEqual(t.s, 'foo') + self.assertEqual(curs.rownumber, 1) + self.assertEqual(curs.rowcount, 3) @skip_if_no_namedtuple def test_fetchmany_noarg(self): @@ -240,6 +242,8 @@ class NamedTupleCursorTest(unittest.TestCase): self.assertEqual(res[0].s, 'foo') self.assertEqual(res[1].i, 2) self.assertEqual(res[1].s, 'bar') + self.assertEqual(curs.rownumber, 2) + self.assertEqual(curs.rowcount, 3) @skip_if_no_namedtuple def test_fetchmany(self): @@ -251,6 +255,8 @@ class NamedTupleCursorTest(unittest.TestCase): self.assertEqual(res[0].s, 'foo') self.assertEqual(res[1].i, 2) self.assertEqual(res[1].s, 'bar') + self.assertEqual(curs.rownumber, 2) + self.assertEqual(curs.rowcount, 3) @skip_if_no_namedtuple def test_fetchall(self): @@ -264,6 +270,8 @@ class NamedTupleCursorTest(unittest.TestCase): self.assertEqual(res[1].s, 'bar') self.assertEqual(res[2].i, 3) self.assertEqual(res[2].s, 'baz') + self.assertEqual(curs.rownumber, 3) + self.assertEqual(curs.rowcount, 3) @skip_if_no_namedtuple def test_executemany(self): @@ -281,16 +289,26 @@ class NamedTupleCursorTest(unittest.TestCase): curs = self.conn.cursor() curs.execute("select * from nttest order by 1") i = iter(curs) + self.assertEqual(curs.rownumber, 0) + t = i.next() self.assertEqual(t.i, 1) self.assertEqual(t.s, 'foo') + self.assertEqual(curs.rownumber, 1) + self.assertEqual(curs.rowcount, 3) + t = i.next() self.assertEqual(t.i, 2) self.assertEqual(t.s, 'bar') + self.assertEqual(curs.rownumber, 2) + self.assertEqual(curs.rowcount, 3) + t = i.next() self.assertEqual(t.i, 3) self.assertEqual(t.s, 'baz') self.assertRaises(StopIteration, i.next) + self.assertEqual(curs.rownumber, 3) + self.assertEqual(curs.rowcount, 3) def test_error_message(self): try: @@ -415,6 +433,17 @@ class NamedTupleCursorTest(unittest.TestCase): self.assert_(recs[1].ts - recs[0].ts < timedelta(seconds=0.005)) self.assert_(recs[2].ts - recs[1].ts > timedelta(seconds=0.0099)) + @skip_if_no_namedtuple + @skip_before_postgres(8, 0) + def test_named_rownumber(self): + curs = self.conn.cursor('tmp') + # Only checking for dataset < itersize: + # see CursorTests.test_iter_named_cursor_rownumber + curs.itersize = 4 + curs.execute("""select * from generate_series(1,3)""") + for i, t in enumerate(curs): + self.assertEqual(i + 1, curs.rownumber) + def test_suite(): return unittest.TestLoader().loadTestsFromName(__name__)