diff --git a/channels/sessions.py b/channels/sessions.py index ee23f54..3220f27 100644 --- a/channels/sessions.py +++ b/channels/sessions.py @@ -75,6 +75,26 @@ def channel_session(func): return inner +def requeue_messages(message): + """ + Requeue any pending wait channel messages for this socket connection back onto it's original channel + """ + while True: + wait_channel = "__wait__.%s" % message.reply_channel.name + channel, content = message.channel_layer.receive_many([wait_channel], block=False) + if channel: + original_channel = content.pop("original_channel") + try: + message.channel_layer.send(original_channel, content) + except message.channel_layer.ChannelFull: + raise message.channel_layer.ChannelFull( + "Cannot requeue pending __wait__ channel message " + + "back on to already full channel %s" % original_channel + ) + else: + break + + def enforce_ordering(func=None, slight=False): """ Enforces strict (all messages exactly ordered) ordering against a reply_channel. @@ -106,21 +126,7 @@ def enforce_ordering(func=None, slight=False): message.channel_session["__channels_next_order"] = order + 1 message.channel_session.save() message.channel_session.modified = False - # Requeue any pending wait channel messages for this socket connection back onto it's original channel - while True: - wait_channel = "__wait__.%s" % message.reply_channel.name - channel, content = message.channel_layer.receive_many([wait_channel], block=False) - if channel: - original_channel = content.pop("original_channel") - try: - message.channel_layer.send(original_channel, content) - except message.channel_layer.ChannelFull: - raise message.channel_layer.ChannelFull( - "Cannot requeue pending __wait__ channel message " + - "back on to already full channel %s" % original_channel - ) - else: - break + requeue_messages(message) else: # Since out of order, enqueue message temporarily to wait channel for this socket connection wait_channel = "__wait__.%s" % message.reply_channel.name @@ -132,6 +138,11 @@ def enforce_ordering(func=None, slight=False): "Cannot add unordered message to already " + "full __wait__ channel for socket %s" % message.reply_channel.name ) + # Next order may have changed while this message was being processed + # Requeue messages if this has happened + if order == message.channel_session.load().get("__channels_next_order", 0): + requeue_messages(message) + return inner if func is not None: return decorator(func) diff --git a/tests/test_sessions.py b/tests/test_sessions.py index 0f8bf00..b56b2ad 100644 --- a/tests/test_sessions.py +++ b/tests/test_sessions.py @@ -10,6 +10,11 @@ from channels.sessions import ( ) from channels.test import ChannelTestCase +try: + from unittest import mock +except ImportError: + import mock + @override_settings(SESSION_ENGINE="django.contrib.sessions.backends.cache") class SessionTests(ChannelTestCase): @@ -223,3 +228,58 @@ class SessionTests(ChannelTestCase): with self.assertRaises(ValueError): inner(message0) + + def test_enforce_ordering_concurrent(self): + """ + Tests that strict mode of enforce_ordering puts messages in the correct queue after + the current message number changes while the message is being processed + """ + # Construct messages to send + message0 = Message( + {"reply_channel": "test-reply-e", "order": 0}, + "websocket.connect", + channel_layers[DEFAULT_CHANNEL_LAYER] + ) + message2 = Message( + {"reply_channel": "test-reply-e", "order": 2}, + "websocket.receive", + channel_layers[DEFAULT_CHANNEL_LAYER] + ) + message3 = Message( + {"reply_channel": "test-reply-e", "order": 3}, + "websocket.receive", + channel_layers[DEFAULT_CHANNEL_LAYER] + ) + + @channel_session + def add_session(message): + pass + + # Run them in an acceptable strict order + @enforce_ordering + def inner(message): + pass + + inner(message0) + inner(message3) + + # Add the session now so it can be mocked + add_session(message2) + + with mock.patch.object(message2.channel_session, 'load', return_value={'__channels_next_order': 2}): + inner(message2) + + # Ensure wait channel is empty + wait_channel = "__wait__.%s" % "test-reply-e" + next_message = self.get_next_message(wait_channel) + self.assertEqual(next_message, None) + + # Ensure messages 3 and 2 both ended up back on the original channel + expected = { + 2: message2, + 3: message3 + } + for m in range(2): + message = self.get_next_message("websocket.receive") + expected.pop(message.content['order']) + self.assertEqual(expected, {})