From 863b1cebdda1b8e561ba01738d0f7c60dc615d58 Mon Sep 17 00:00:00 2001 From: Coread Date: Wed, 22 Feb 2017 19:00:50 +0000 Subject: [PATCH] Requeue next message immediately to avoid wait queue race condition (#532) Changes the strategy so that after a message has been put on the wait queue, it is then checked to see if it became the next message during this time and if so, immediately flushed. Will hopefully fix #451. --- channels/sessions.py | 41 ++++++++++++++++++----------- tests/test_sessions.py | 60 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 15 deletions(-) diff --git a/channels/sessions.py b/channels/sessions.py index ee23f54..3220f27 100644 --- a/channels/sessions.py +++ b/channels/sessions.py @@ -75,6 +75,26 @@ def channel_session(func): return inner +def requeue_messages(message): + """ + Requeue any pending wait channel messages for this socket connection back onto it's original channel + """ + while True: + wait_channel = "__wait__.%s" % message.reply_channel.name + channel, content = message.channel_layer.receive_many([wait_channel], block=False) + if channel: + original_channel = content.pop("original_channel") + try: + message.channel_layer.send(original_channel, content) + except message.channel_layer.ChannelFull: + raise message.channel_layer.ChannelFull( + "Cannot requeue pending __wait__ channel message " + + "back on to already full channel %s" % original_channel + ) + else: + break + + def enforce_ordering(func=None, slight=False): """ Enforces strict (all messages exactly ordered) ordering against a reply_channel. @@ -106,21 +126,7 @@ def enforce_ordering(func=None, slight=False): message.channel_session["__channels_next_order"] = order + 1 message.channel_session.save() message.channel_session.modified = False - # Requeue any pending wait channel messages for this socket connection back onto it's original channel - while True: - wait_channel = "__wait__.%s" % message.reply_channel.name - channel, content = message.channel_layer.receive_many([wait_channel], block=False) - if channel: - original_channel = content.pop("original_channel") - try: - message.channel_layer.send(original_channel, content) - except message.channel_layer.ChannelFull: - raise message.channel_layer.ChannelFull( - "Cannot requeue pending __wait__ channel message " + - "back on to already full channel %s" % original_channel - ) - else: - break + requeue_messages(message) else: # Since out of order, enqueue message temporarily to wait channel for this socket connection wait_channel = "__wait__.%s" % message.reply_channel.name @@ -132,6 +138,11 @@ def enforce_ordering(func=None, slight=False): "Cannot add unordered message to already " + "full __wait__ channel for socket %s" % message.reply_channel.name ) + # Next order may have changed while this message was being processed + # Requeue messages if this has happened + if order == message.channel_session.load().get("__channels_next_order", 0): + requeue_messages(message) + return inner if func is not None: return decorator(func) diff --git a/tests/test_sessions.py b/tests/test_sessions.py index 0f8bf00..b56b2ad 100644 --- a/tests/test_sessions.py +++ b/tests/test_sessions.py @@ -10,6 +10,11 @@ from channels.sessions import ( ) from channels.test import ChannelTestCase +try: + from unittest import mock +except ImportError: + import mock + @override_settings(SESSION_ENGINE="django.contrib.sessions.backends.cache") class SessionTests(ChannelTestCase): @@ -223,3 +228,58 @@ class SessionTests(ChannelTestCase): with self.assertRaises(ValueError): inner(message0) + + def test_enforce_ordering_concurrent(self): + """ + Tests that strict mode of enforce_ordering puts messages in the correct queue after + the current message number changes while the message is being processed + """ + # Construct messages to send + message0 = Message( + {"reply_channel": "test-reply-e", "order": 0}, + "websocket.connect", + channel_layers[DEFAULT_CHANNEL_LAYER] + ) + message2 = Message( + {"reply_channel": "test-reply-e", "order": 2}, + "websocket.receive", + channel_layers[DEFAULT_CHANNEL_LAYER] + ) + message3 = Message( + {"reply_channel": "test-reply-e", "order": 3}, + "websocket.receive", + channel_layers[DEFAULT_CHANNEL_LAYER] + ) + + @channel_session + def add_session(message): + pass + + # Run them in an acceptable strict order + @enforce_ordering + def inner(message): + pass + + inner(message0) + inner(message3) + + # Add the session now so it can be mocked + add_session(message2) + + with mock.patch.object(message2.channel_session, 'load', return_value={'__channels_next_order': 2}): + inner(message2) + + # Ensure wait channel is empty + wait_channel = "__wait__.%s" % "test-reply-e" + next_message = self.get_next_message(wait_channel) + self.assertEqual(next_message, None) + + # Ensure messages 3 and 2 both ended up back on the original channel + expected = { + 2: message2, + 3: message3 + } + for m in range(2): + message = self.get_next_message("websocket.receive") + expected.pop(message.content['order']) + self.assertEqual(expected, {})