diff --git a/lib/extensions.py b/lib/extensions.py index eea90d48..51461af3 100644 --- a/lib/extensions.py +++ b/lib/extensions.py @@ -58,7 +58,7 @@ except: from _psycopg import adapt, adapters, encodings, connection, cursor, lobject from _psycopg import string_types, binary_types, new_type, register_type -from _psycopg import ISQLQuote +from _psycopg import ISQLQuote, Notify from _psycopg import QueryCanceledError, TransactionRollbackError diff --git a/psycopg/psycopgmodule.c b/psycopg/psycopgmodule.c index 783bf159..1e5f8403 100644 --- a/psycopg/psycopgmodule.c +++ b/psycopg/psycopgmodule.c @@ -822,6 +822,7 @@ init_psycopg(void) PyModule_AddObject(module, "connection", (PyObject*)&connectionType); PyModule_AddObject(module, "cursor", (PyObject*)&cursorType); PyModule_AddObject(module, "ISQLQuote", (PyObject*)&isqlquoteType); + PyModule_AddObject(module, "Notify", (PyObject*)&NotifyType); #ifdef PSYCOPG_EXTENSIONS PyModule_AddObject(module, "lobject", (PyObject*)&lobjectType); #endif diff --git a/tests/test_notify.py b/tests/test_notify.py index acad27cd..fa6c654f 100755 --- a/tests/test_notify.py +++ b/tests/test_notify.py @@ -105,6 +105,14 @@ conn.close() self.assertEqual(pid, self.conn.notifies[0][0]) self.assertEqual('foo', self.conn.notifies[0][1]) + def test_notify_object(self): + self.autocommit(self.conn) + self.listen('foo') + self.notify('foo').communicate() + self.conn.poll() + notify = self.conn.notifies[0] + self.assert_(isinstance(notify, psycopg2.extensions.Notify)) + def test_notify_attributes(self): self.autocommit(self.conn) self.listen('foo') @@ -131,6 +139,21 @@ conn.close() self.assertEqual('foo', notify.channel) self.assertEqual('Hello, world!', notify.payload) + def test_notify_init(self): + n = psycopg2.extensions.Notify(10, 'foo') + self.assertEqual(10, n.pid) + self.assertEqual('foo', n.channel) + self.assertEqual(None, n.payload) + (pid, channel) = n + self.assertEqual((pid, channel), (10, 'foo')) + + n = psycopg2.extensions.Notify(42, 'bar', 'baz') + self.assertEqual(42, n.pid) + self.assertEqual('bar', n.channel) + self.assertEqual('baz', n.payload) + (pid, channel) = n + self.assertEqual((pid, channel), (42, 'bar')) + def test_suite(): return unittest.TestLoader().loadTestsFromName(__name__)