From b8597dc1d369d69788a446f65d1f8741d09e064a Mon Sep 17 00:00:00 2001 From: Daniele Varrazzo Date: Thu, 23 Feb 2012 22:58:58 +0000 Subject: [PATCH] Fixed NamedTupleCursor rownumber during iteration. The correction is similar to the other one for the other subclasses. Also added tests for rowcount and rownumber during different fetch styles. Just in case. --- lib/extras.py | 17 ++++++++++------- tests/test_extras_dictcursor.py | 31 ++++++++++++++++++++++++++++++- 2 files changed, 40 insertions(+), 8 deletions(-) diff --git a/lib/extras.py b/lib/extras.py index aa7bc877..870b5ca7 100644 --- a/lib/extras.py +++ b/lib/extras.py @@ -310,14 +310,17 @@ class NamedTupleCursor(_cursor): return [nt(*t) for t in ts] def __iter__(self): - # Invoking _cursor.__iter__(self) goes to infinite recursion, - # so we do pagination by hand + it = _cursor.__iter__(self) + t = it.next() + + nt = self.Record + if nt is None: + nt = self.Record = self._make_nt() + + yield nt(*t) + while 1: - recs = self.fetchmany(self.itersize) - if not recs: - return - for rec in recs: - yield rec + yield nt(*it.next()) try: from collections import namedtuple diff --git a/tests/test_extras_dictcursor.py b/tests/test_extras_dictcursor.py index 3bdb3ba3..dd746379 100755 --- a/tests/test_extras_dictcursor.py +++ b/tests/test_extras_dictcursor.py @@ -222,12 +222,14 @@ class NamedTupleCursorTest(unittest.TestCase): @skip_if_no_namedtuple def test_fetchone(self): curs = self.conn.cursor() - curs.execute("select * from nttest where i = 1") + curs.execute("select * from nttest order by 1") t = curs.fetchone() self.assertEqual(t[0], 1) self.assertEqual(t.i, 1) self.assertEqual(t[1], 'foo') self.assertEqual(t.s, 'foo') + self.assertEqual(curs.rownumber, 1) + self.assertEqual(curs.rowcount, 3) @skip_if_no_namedtuple def test_fetchmany_noarg(self): @@ -240,6 +242,8 @@ class NamedTupleCursorTest(unittest.TestCase): self.assertEqual(res[0].s, 'foo') self.assertEqual(res[1].i, 2) self.assertEqual(res[1].s, 'bar') + self.assertEqual(curs.rownumber, 2) + self.assertEqual(curs.rowcount, 3) @skip_if_no_namedtuple def test_fetchmany(self): @@ -251,6 +255,8 @@ class NamedTupleCursorTest(unittest.TestCase): self.assertEqual(res[0].s, 'foo') self.assertEqual(res[1].i, 2) self.assertEqual(res[1].s, 'bar') + self.assertEqual(curs.rownumber, 2) + self.assertEqual(curs.rowcount, 3) @skip_if_no_namedtuple def test_fetchall(self): @@ -264,6 +270,8 @@ class NamedTupleCursorTest(unittest.TestCase): self.assertEqual(res[1].s, 'bar') self.assertEqual(res[2].i, 3) self.assertEqual(res[2].s, 'baz') + self.assertEqual(curs.rownumber, 3) + self.assertEqual(curs.rowcount, 3) @skip_if_no_namedtuple def test_executemany(self): @@ -281,16 +289,26 @@ class NamedTupleCursorTest(unittest.TestCase): curs = self.conn.cursor() curs.execute("select * from nttest order by 1") i = iter(curs) + self.assertEqual(curs.rownumber, 0) + t = i.next() self.assertEqual(t.i, 1) self.assertEqual(t.s, 'foo') + self.assertEqual(curs.rownumber, 1) + self.assertEqual(curs.rowcount, 3) + t = i.next() self.assertEqual(t.i, 2) self.assertEqual(t.s, 'bar') + self.assertEqual(curs.rownumber, 2) + self.assertEqual(curs.rowcount, 3) + t = i.next() self.assertEqual(t.i, 3) self.assertEqual(t.s, 'baz') self.assertRaises(StopIteration, i.next) + self.assertEqual(curs.rownumber, 3) + self.assertEqual(curs.rowcount, 3) def test_error_message(self): try: @@ -415,6 +433,17 @@ class NamedTupleCursorTest(unittest.TestCase): self.assert_(recs[1].ts - recs[0].ts < timedelta(seconds=0.005)) self.assert_(recs[2].ts - recs[1].ts > timedelta(seconds=0.0099)) + @skip_if_no_namedtuple + @skip_before_postgres(8, 0) + def test_named_rownumber(self): + curs = self.conn.cursor('tmp') + # Only checking for dataset < itersize: + # see CursorTests.test_iter_named_cursor_rownumber + curs.itersize = 4 + curs.execute("""select * from generate_series(1,3)""") + for i, t in enumerate(curs): + self.assertEqual(i + 1, curs.rownumber) + def test_suite(): return unittest.TestLoader().loadTestsFromName(__name__)