diff --git a/channels/sessions.py b/channels/sessions.py index 3220f27..3af6da9 100644 --- a/channels/sessions.py +++ b/channels/sessions.py @@ -8,6 +8,7 @@ from django.contrib.sessions.backends.base import CreateError from .exceptions import ConsumeLater from .handler import AsgiRequest +from .message import Message def session_for_reply_channel(reply_channel): @@ -39,11 +40,18 @@ def channel_session(func): Use this to persist data across the lifetime of a connection. """ @functools.wraps(func) - def inner(message, *args, **kwargs): + def inner(*args, **kwargs): + message = None + for arg in args[:2]: + if isinstance(arg, Message): + message = arg + break + if message is None: + raise ValueError('channel_session called without Message instance') # Make sure there's NOT a channel_session already if hasattr(message, "channel_session"): try: - return func(message, *args, **kwargs) + return func(*args, **kwargs) finally: # Persist session if needed if message.channel_session.modified: @@ -67,7 +75,7 @@ def channel_session(func): message.channel_session = session # Run the consumer try: - return func(message, *args, **kwargs) + return func(*args, **kwargs) finally: # Persist session if needed if session.modified and not session.is_empty(): diff --git a/tests/test_sessions.py b/tests/test_sessions.py index b56b2ad..d1d507a 100644 --- a/tests/test_sessions.py +++ b/tests/test_sessions.py @@ -49,6 +49,39 @@ class SessionTests(ChannelTestCase): session2 = session_for_reply_channel("test-reply") self.assertEqual(session2["num_ponies"], -1) + def test_channel_session_method(self): + """ + Tests the channel_session decorator works on methods + """ + # Construct message to send + message = Message({"reply_channel": "test-reply"}, None, None) + + # Run through a simple fake consumer that assigns to it + class Consumer(object): + @channel_session + def inner(self, message): + message.channel_session["num_ponies"] = -1 + + Consumer().inner(message) + # Test the session worked + session2 = session_for_reply_channel("test-reply") + self.assertEqual(session2["num_ponies"], -1) + + def test_channel_session_third_arg(self): + """ + Tests the channel_session decorator with message as 3rd argument + """ + # Construct message to send + message = Message({"reply_channel": "test-reply"}, None, None) + + # Run through a simple fake consumer that assigns to it + @channel_session + def inner(a, b, message): + message.channel_session["num_ponies"] = -1 + + with self.assertRaisesMessage(ValueError, 'channel_session called without Message instance'): + inner(None, None, message) + def test_channel_session_double(self): """ Tests the channel_session decorator detects being wrapped in itself @@ -68,6 +101,42 @@ class SessionTests(ChannelTestCase): session2 = session_for_reply_channel("test-reply") self.assertEqual(session2["num_ponies"], -1) + def test_channel_session_double_method(self): + """ + Tests the channel_session decorator detects being wrapped in itself + and doesn't blow up. Method version. + """ + # Construct message to send + message = Message({"reply_channel": "test-reply"}, None, None) + + # Run through a simple fake consumer that should trigger the error + class Consumer(object): + @channel_session + @channel_session + def inner(self, message): + message.channel_session["num_ponies"] = -1 + Consumer().inner(message) + + # Test the session worked + session2 = session_for_reply_channel("test-reply") + self.assertEqual(session2["num_ponies"], -1) + + def test_channel_session_double_third_arg(self): + """ + Tests the channel_session decorator detects being wrapped in itself + and doesn't blow up. + """ + # Construct message to send + message = Message({"reply_channel": "test-reply"}, None, None) + + # Run through a simple fake consumer that should trigger the error + @channel_session + @channel_session + def inner(a, b, message): + message.channel_session["num_ponies"] = -1 + with self.assertRaisesMessage(ValueError, 'channel_session called without Message instance'): + inner(None, None, message) + def test_channel_session_no_reply(self): """ Tests the channel_session decorator detects no reply channel @@ -84,6 +153,39 @@ class SessionTests(ChannelTestCase): with self.assertRaises(ValueError): inner(message) + def test_channel_session_no_reply_method(self): + """ + Tests the channel_session decorator detects no reply channel + """ + # Construct message to send + message = Message({}, None, None) + + # Run through a simple fake consumer that should trigger the error + class Consumer(object): + @channel_session + @channel_session + def inner(self, message): + message.channel_session["num_ponies"] = -1 + + with self.assertRaises(ValueError): + Consumer().inner(message) + + def test_channel_session_no_reply_third_arg(self): + """ + Tests the channel_session decorator detects no reply channel + """ + # Construct message to send + message = Message({}, None, None) + + # Run through a simple fake consumer that should trigger the error + @channel_session + @channel_session + def inner(a, b, message): + message.channel_session["num_ponies"] = -1 + + with self.assertRaisesMessage(ValueError, 'channel_session called without Message instance'): + inner(None, None, message) + def test_http_session(self): """ Tests that http_session correctly extracts a session cookie.