Requeue next message immediately to avoid wait queue race condition (#532)

Changes the strategy so that after a message has been put on the wait queue, it is then checked to see if it became the next message during this time and if so, immediately flushed. Will hopefully fix #451.
This commit is contained in:
Coread 2017-02-22 19:00:50 +00:00 committed by Andrew Godwin
parent db3a020122
commit 863b1cebdd
2 changed files with 86 additions and 15 deletions

View File

@ -75,6 +75,26 @@ def channel_session(func):
return inner
def requeue_messages(message):
"""
Requeue any pending wait channel messages for this socket connection back onto it's original channel
"""
while True:
wait_channel = "__wait__.%s" % message.reply_channel.name
channel, content = message.channel_layer.receive_many([wait_channel], block=False)
if channel:
original_channel = content.pop("original_channel")
try:
message.channel_layer.send(original_channel, content)
except message.channel_layer.ChannelFull:
raise message.channel_layer.ChannelFull(
"Cannot requeue pending __wait__ channel message " +
"back on to already full channel %s" % original_channel
)
else:
break
def enforce_ordering(func=None, slight=False):
"""
Enforces strict (all messages exactly ordered) ordering against a reply_channel.
@ -106,21 +126,7 @@ def enforce_ordering(func=None, slight=False):
message.channel_session["__channels_next_order"] = order + 1
message.channel_session.save()
message.channel_session.modified = False
# Requeue any pending wait channel messages for this socket connection back onto it's original channel
while True:
wait_channel = "__wait__.%s" % message.reply_channel.name
channel, content = message.channel_layer.receive_many([wait_channel], block=False)
if channel:
original_channel = content.pop("original_channel")
try:
message.channel_layer.send(original_channel, content)
except message.channel_layer.ChannelFull:
raise message.channel_layer.ChannelFull(
"Cannot requeue pending __wait__ channel message " +
"back on to already full channel %s" % original_channel
)
else:
break
requeue_messages(message)
else:
# Since out of order, enqueue message temporarily to wait channel for this socket connection
wait_channel = "__wait__.%s" % message.reply_channel.name
@ -132,6 +138,11 @@ def enforce_ordering(func=None, slight=False):
"Cannot add unordered message to already " +
"full __wait__ channel for socket %s" % message.reply_channel.name
)
# Next order may have changed while this message was being processed
# Requeue messages if this has happened
if order == message.channel_session.load().get("__channels_next_order", 0):
requeue_messages(message)
return inner
if func is not None:
return decorator(func)

View File

@ -10,6 +10,11 @@ from channels.sessions import (
)
from channels.test import ChannelTestCase
try:
from unittest import mock
except ImportError:
import mock
@override_settings(SESSION_ENGINE="django.contrib.sessions.backends.cache")
class SessionTests(ChannelTestCase):
@ -223,3 +228,58 @@ class SessionTests(ChannelTestCase):
with self.assertRaises(ValueError):
inner(message0)
def test_enforce_ordering_concurrent(self):
"""
Tests that strict mode of enforce_ordering puts messages in the correct queue after
the current message number changes while the message is being processed
"""
# Construct messages to send
message0 = Message(
{"reply_channel": "test-reply-e", "order": 0},
"websocket.connect",
channel_layers[DEFAULT_CHANNEL_LAYER]
)
message2 = Message(
{"reply_channel": "test-reply-e", "order": 2},
"websocket.receive",
channel_layers[DEFAULT_CHANNEL_LAYER]
)
message3 = Message(
{"reply_channel": "test-reply-e", "order": 3},
"websocket.receive",
channel_layers[DEFAULT_CHANNEL_LAYER]
)
@channel_session
def add_session(message):
pass
# Run them in an acceptable strict order
@enforce_ordering
def inner(message):
pass
inner(message0)
inner(message3)
# Add the session now so it can be mocked
add_session(message2)
with mock.patch.object(message2.channel_session, 'load', return_value={'__channels_next_order': 2}):
inner(message2)
# Ensure wait channel is empty
wait_channel = "__wait__.%s" % "test-reply-e"
next_message = self.get_next_message(wait_channel)
self.assertEqual(next_message, None)
# Ensure messages 3 and 2 both ended up back on the original channel
expected = {
2: message2,
3: message3
}
for m in range(2):
message = self.get_next_message("websocket.receive")
expected.pop(message.content['order'])
self.assertEqual(expected, {})