Fixed NamedTupleCursor rownumber during iteration.

The correction is similar to the other one for the other subclasses.

Also added tests for rowcount and rownumber during different fetch styles.
Just in case.
This commit is contained in:
Daniele Varrazzo 2012-02-23 22:58:58 +00:00
parent ebec522a07
commit b8597dc1d3
2 changed files with 40 additions and 8 deletions

View File

@ -310,14 +310,17 @@ class NamedTupleCursor(_cursor):
return [nt(*t) for t in ts] return [nt(*t) for t in ts]
def __iter__(self): def __iter__(self):
# Invoking _cursor.__iter__(self) goes to infinite recursion, it = _cursor.__iter__(self)
# so we do pagination by hand t = it.next()
nt = self.Record
if nt is None:
nt = self.Record = self._make_nt()
yield nt(*t)
while 1: while 1:
recs = self.fetchmany(self.itersize) yield nt(*it.next())
if not recs:
return
for rec in recs:
yield rec
try: try:
from collections import namedtuple from collections import namedtuple

View File

@ -222,12 +222,14 @@ class NamedTupleCursorTest(unittest.TestCase):
@skip_if_no_namedtuple @skip_if_no_namedtuple
def test_fetchone(self): def test_fetchone(self):
curs = self.conn.cursor() curs = self.conn.cursor()
curs.execute("select * from nttest where i = 1") curs.execute("select * from nttest order by 1")
t = curs.fetchone() t = curs.fetchone()
self.assertEqual(t[0], 1) self.assertEqual(t[0], 1)
self.assertEqual(t.i, 1) self.assertEqual(t.i, 1)
self.assertEqual(t[1], 'foo') self.assertEqual(t[1], 'foo')
self.assertEqual(t.s, 'foo') self.assertEqual(t.s, 'foo')
self.assertEqual(curs.rownumber, 1)
self.assertEqual(curs.rowcount, 3)
@skip_if_no_namedtuple @skip_if_no_namedtuple
def test_fetchmany_noarg(self): def test_fetchmany_noarg(self):
@ -240,6 +242,8 @@ class NamedTupleCursorTest(unittest.TestCase):
self.assertEqual(res[0].s, 'foo') self.assertEqual(res[0].s, 'foo')
self.assertEqual(res[1].i, 2) self.assertEqual(res[1].i, 2)
self.assertEqual(res[1].s, 'bar') self.assertEqual(res[1].s, 'bar')
self.assertEqual(curs.rownumber, 2)
self.assertEqual(curs.rowcount, 3)
@skip_if_no_namedtuple @skip_if_no_namedtuple
def test_fetchmany(self): def test_fetchmany(self):
@ -251,6 +255,8 @@ class NamedTupleCursorTest(unittest.TestCase):
self.assertEqual(res[0].s, 'foo') self.assertEqual(res[0].s, 'foo')
self.assertEqual(res[1].i, 2) self.assertEqual(res[1].i, 2)
self.assertEqual(res[1].s, 'bar') self.assertEqual(res[1].s, 'bar')
self.assertEqual(curs.rownumber, 2)
self.assertEqual(curs.rowcount, 3)
@skip_if_no_namedtuple @skip_if_no_namedtuple
def test_fetchall(self): def test_fetchall(self):
@ -264,6 +270,8 @@ class NamedTupleCursorTest(unittest.TestCase):
self.assertEqual(res[1].s, 'bar') self.assertEqual(res[1].s, 'bar')
self.assertEqual(res[2].i, 3) self.assertEqual(res[2].i, 3)
self.assertEqual(res[2].s, 'baz') self.assertEqual(res[2].s, 'baz')
self.assertEqual(curs.rownumber, 3)
self.assertEqual(curs.rowcount, 3)
@skip_if_no_namedtuple @skip_if_no_namedtuple
def test_executemany(self): def test_executemany(self):
@ -281,16 +289,26 @@ class NamedTupleCursorTest(unittest.TestCase):
curs = self.conn.cursor() curs = self.conn.cursor()
curs.execute("select * from nttest order by 1") curs.execute("select * from nttest order by 1")
i = iter(curs) i = iter(curs)
self.assertEqual(curs.rownumber, 0)
t = i.next() t = i.next()
self.assertEqual(t.i, 1) self.assertEqual(t.i, 1)
self.assertEqual(t.s, 'foo') self.assertEqual(t.s, 'foo')
self.assertEqual(curs.rownumber, 1)
self.assertEqual(curs.rowcount, 3)
t = i.next() t = i.next()
self.assertEqual(t.i, 2) self.assertEqual(t.i, 2)
self.assertEqual(t.s, 'bar') self.assertEqual(t.s, 'bar')
self.assertEqual(curs.rownumber, 2)
self.assertEqual(curs.rowcount, 3)
t = i.next() t = i.next()
self.assertEqual(t.i, 3) self.assertEqual(t.i, 3)
self.assertEqual(t.s, 'baz') self.assertEqual(t.s, 'baz')
self.assertRaises(StopIteration, i.next) self.assertRaises(StopIteration, i.next)
self.assertEqual(curs.rownumber, 3)
self.assertEqual(curs.rowcount, 3)
def test_error_message(self): def test_error_message(self):
try: try:
@ -415,6 +433,17 @@ class NamedTupleCursorTest(unittest.TestCase):
self.assert_(recs[1].ts - recs[0].ts < timedelta(seconds=0.005)) self.assert_(recs[1].ts - recs[0].ts < timedelta(seconds=0.005))
self.assert_(recs[2].ts - recs[1].ts > timedelta(seconds=0.0099)) self.assert_(recs[2].ts - recs[1].ts > timedelta(seconds=0.0099))
@skip_if_no_namedtuple
@skip_before_postgres(8, 0)
def test_named_rownumber(self):
curs = self.conn.cursor('tmp')
# Only checking for dataset < itersize:
# see CursorTests.test_iter_named_cursor_rownumber
curs.itersize = 4
curs.execute("""select * from generate_series(1,3)""")
for i, t in enumerate(curs):
self.assertEqual(i + 1, curs.rownumber)
def test_suite(): def test_suite():
return unittest.TestLoader().loadTestsFromName(__name__) return unittest.TestLoader().loadTestsFromName(__name__)