diff --git a/tests/test_extras_dictcursor.py b/tests/test_extras_dictcursor.py index 5f205ca9..6f1704a8 100755 --- a/tests/test_extras_dictcursor.py +++ b/tests/test_extras_dictcursor.py @@ -661,6 +661,196 @@ class NamedTupleCursorTest(ConnectingTestCase): NamedTupleCursor._cached_make_nt = old_func +class HybridRowTest(ConnectingTestCase): + + def setUp(self): + ConnectingTestCase.setUp(self) + self.conn = self.connect(cursor_factory=psycopg2.extras.HybridRowCursor) + curs = self.conn.cursor() + curs.execute("CREATE TEMPORARY TABLE test (i int, s text)") + curs.execute("INSERT INTO test VALUES (1, 'foo')") + curs.execute("INSERT INTO test VALUES (2, 'bar')") + curs.execute("INSERT INTO test VALUES (3, 'baz')") + self.conn.commit() + + def test_cursor_args(self): + cur = self.conn.cursor('foo', cursor_factory=psycopg2.extras.HybridRowCursor) + self.assertEqual(cur.name, 'foo') + self.assert_(isinstance(cur, psycopg2.extras.HybridRowCursor)) + + def test_fetchone(self): + curs = self.conn.cursor() + curs.execute("select * from test order by i") + r = curs.fetchone() + self.assert_(isinstance(r, psycopg2.extras.HybridRow)) + self.assertEqual(r[0], 1) + self.assertEqual(r.i, 1) + self.assertEqual(r['i'], 1) + self.assertEqual(r[1], 'foo') + self.assertEqual(r.s, 'foo') + self.assertEqual(r['s'], 'foo') + self.assertEqual(curs.rownumber, 1) + self.assertEqual(curs.rowcount, 3) + + def test_fetchmany_noarg(self): + curs = self.conn.cursor() + curs.arraysize = 2 + curs.execute("select * from test order by i") + res = curs.fetchmany() + self.assertEqual(2, len(res)) + self.assert_(isinstance(res[0], psycopg2.extras.HybridRow)) + self.assertEqual(res[0].i, 1) + self.assertEqual(res[0].s, 'foo') + self.assertEqual(res[1].i, 2) + self.assertEqual(res[1].s, 'bar') + self.assertEqual(curs.rownumber, 2) + self.assertEqual(curs.rowcount, 3) + + def test_fetchmany(self): + curs = self.conn.cursor() + curs.execute("select * from test order by i") + res = curs.fetchmany(2) + self.assertEqual(2, len(res)) + self.assert_(isinstance(res[0], psycopg2.extras.HybridRow)) + self.assertEqual(res[0].i, 1) + self.assertEqual(res[0].s, 'foo') + self.assertEqual(res[1].i, 2) + self.assertEqual(res[1].s, 'bar') + self.assertEqual(curs.rownumber, 2) + self.assertEqual(curs.rowcount, 3) + + def test_fetchall(self): + curs = self.conn.cursor() + curs.execute("select * from test order by i") + res = curs.fetchall() + self.assertEqual(3, len(res)) + self.assert_(isinstance(res[0], psycopg2.extras.HybridRow)) + self.assertEqual(res[0].i, 1) + self.assertEqual(res[0].s, 'foo') + self.assertEqual(res[1].i, 2) + self.assertEqual(res[1].s, 'bar') + self.assertEqual(res[2].i, 3) + self.assertEqual(res[2].s, 'baz') + self.assertEqual(curs.rownumber, 3) + self.assertEqual(curs.rowcount, 3) + + def test_executemany(self): + curs = self.conn.cursor() + curs.executemany("delete from test where i = %s", + [(1,), (2,)]) + curs.execute("select * from test order by i") + res = curs.fetchall() + self.assertEqual(1, len(res)) + self.assert_(isinstance(res[0], psycopg2.extras.HybridRow)) + self.assertEqual(res[0].i, 3) + self.assertEqual(res[0].s, 'baz') + + def test_iter(self): + curs = self.conn.cursor() + curs.execute("select * from test order by i") + i = iter(curs) + self.assertEqual(curs.rownumber, 0) + + t = next(i) + self.assert_(isinstance(t, psycopg2.extras.HybridRow)) + self.assertEqual(t.i, 1) + self.assertEqual(t.s, 'foo') + self.assertEqual(curs.rownumber, 1) + self.assertEqual(curs.rowcount, 3) + + t = next(i) + self.assert_(isinstance(t, psycopg2.extras.HybridRow)) + self.assertEqual(t.i, 2) + self.assertEqual(t.s, 'bar') + self.assertEqual(curs.rownumber, 2) + self.assertEqual(curs.rowcount, 3) + + t = next(i) + self.assert_(isinstance(t, psycopg2.extras.HybridRow)) + self.assertEqual(t.i, 3) + self.assertEqual(t.s, 'baz') + self.assertRaises(StopIteration, next, i) + self.assertEqual(curs.rownumber, 3) + self.assertEqual(curs.rowcount, 3) + + def test_record_updated(self): + curs = self.conn.cursor() + curs.execute("select 1 as foo;") + r = curs.fetchone() + self.assertEqual(r.foo, 1) + + curs.execute("select 2 as bar;") + r = curs.fetchone() + self.assertEqual(r.bar, 2) + self.assertRaises(AttributeError, getattr, r, 'foo') + + def test_row_unpack(self): + curs = self.conn.cursor() + curs.execute("select 1 as foo, 2 as bar;") + r = curs.fetchone() + foo, bar = r + self.assertEqual(foo, 1) + self.assertEqual(bar, 2) + + def test_row_comparison(self): + curs = self.conn.cursor() + curs.execute("select 1 as foo, 2 as bar;") + r = curs.fetchone() + self.assertEqual(r, r) + self.assertEqual(r, (1, 2)) + self.assertEqual(r, {'foo': 1, 'bar': 2}) + self.assertNotEqual(r, None) + + def test_no_result_no_surprise(self): + curs = self.conn.cursor() + curs.execute("update test set s = s") + self.assertRaises(psycopg2.ProgrammingError, curs.fetchone) + + curs.execute("update test set s = s") + self.assertRaises(psycopg2.ProgrammingError, curs.fetchall) + + def test_special_col_names(self): + curs = self.conn.cursor() + curs.execute('select 1 as "foo.bar_baz", 2 as "?column?", 3 as "3"') + r = curs.fetchone() + self.assertEqual(r['foo.bar_baz'], 1) + self.assertEqual(r['?column?'], 2) + self.assertEqual(r['3'], 3) + + @skip_before_python(3) + @skip_before_postgres(8) + def test_nonascii_name(self): + curs = self.conn.cursor() + curs.execute('select 1 as \xe5h\xe9') + rv = curs.fetchone() + self.assertEqual(getattr(rv, '\xe5h\xe9'), 1) + + @skip_before_postgres(8, 0) + def test_named_cursor(self): + curs = self.conn.cursor('tmp') + curs.execute("""select i from generate_series(0,9) i""") + recs = [] + recs.extend(curs.fetchmany(5)) + recs.append(curs.fetchone()) + recs.extend(curs.fetchall()) + self.assert_(all(isinstance(r, psycopg2.extras.HybridRow) for r in recs)) + self.assertEqual(list(range(10)), [t.i for t in recs]) + + @skip_before_postgres(8, 2) + def test_not_greedy(self): + curs = self.conn.cursor('tmp') + curs.itersize = 2 + curs.execute("""select clock_timestamp() as ts from generate_series(1,3)""") + recs = [] + for t in curs: + time.sleep(0.01) + recs.append(t) + + # check that the dataset was not fetched in a single gulp + self.assert_(recs[1].ts - recs[0].ts < timedelta(seconds=0.005)) + self.assert_(recs[2].ts - recs[1].ts > timedelta(seconds=0.0099)) + + def test_suite(): return unittest.TestLoader().loadTestsFromName(__name__)