diff --git a/channels/asgi.py b/channels/asgi.py index 737a422..6156a48 100644 --- a/channels/asgi.py +++ b/channels/asgi.py @@ -56,6 +56,16 @@ class ChannelLayerManager(object): def __contains__(self, key): return key in self.configs + def set(self, key, layer): + """ + Sets an alias to point to a new ChannelLayerWrapper instance, and + returns the old one that it replaced. Useful for swapping out the + backend during tests. + """ + old = self.backends.get(key, None) + self.backends[key] = layer + return old + class ChannelLayerWrapper(object): """ diff --git a/channels/tests/__init__.py b/channels/tests/__init__.py index e69de29..566d872 100644 --- a/channels/tests/__init__.py +++ b/channels/tests/__init__.py @@ -0,0 +1 @@ +from .base import ChannelTestCase diff --git a/channels/tests/base.py b/channels/tests/base.py new file mode 100644 index 0000000..bed25f1 --- /dev/null +++ b/channels/tests/base.py @@ -0,0 +1,61 @@ +from django.test import TestCase +from channels import DEFAULT_CHANNEL_LAYER +from channels.asgi import channel_layers, ChannelLayerWrapper +from channels.message import Message +from asgiref.inmemory import ChannelLayer as InMemoryChannelLayer + + +class ChannelTestCase(TestCase): + """ + TestCase subclass that provides easy methods for testing channels using + an in-memory backend to capture messages, and assertion methods to allow + checking of what was sent. + + Inherits from TestCase, so provides per-test transactions as long as the + database backend supports it. + """ + + # Customizable so users can test multi-layer setups + test_channel_aliases = [DEFAULT_CHANNEL_LAYER] + + def setUp(self): + """ + Initialises in memory channel layer for the duration of the test + """ + super(ChannelTestCase, self).setUp() + self._old_layers = {} + for alias in self.test_channel_aliases: + # Swap in an in memory layer wrapper and keep the old one around + self._old_layers[alias] = channel_layers.set( + alias, + ChannelLayerWrapper( + InMemoryChannelLayer(), + alias, + channel_layers[alias].routing, + ) + ) + + def tearDown(self): + """ + Undoes the channel rerouting + """ + for alias in self.test_channel_aliases: + # Swap in an in memory layer wrapper and keep the old one around + channel_layers.set(alias, self._old_layers[alias]) + del self._old_layers + super(ChannelTestCase, self).tearDown() + + def get_next_message(self, channel, alias=DEFAULT_CHANNEL_LAYER, require=False): + """ + Gets the next message that was sent to the channel during the test, + or None if no message is available. + + If require is true, will fail the test if no message is received. + """ + recv_channel, content = channel_layers[alias].receive_many([channel]) + if recv_channel is None: + if require: + self.fail("Expected a message on channel %s, got none" % channel) + else: + return None + return Message(content, recv_channel, channel_layers[alias]) diff --git a/channels/tests/settings.py b/channels/tests/settings.py index c472c7e..c06b6b4 100644 --- a/channels/tests/settings.py +++ b/channels/tests/settings.py @@ -6,4 +6,11 @@ DATABASES = { } } +CHANNEL_LAYERS = { + 'default': { + 'BACKEND': 'asgiref.inmemory.ChannelLayer', + 'ROUTING': [], + }, +} + MIDDLEWARE_CLASSES = [] diff --git a/channels/tests/test_handler.py b/channels/tests/test_handler.py index 516ec16..450f06b 100644 --- a/channels/tests/test_handler.py +++ b/channels/tests/test_handler.py @@ -1,10 +1,9 @@ from __future__ import unicode_literals -from django.test import SimpleTestCase from django.http import HttpResponse -from asgiref.inmemory import ChannelLayer +from channels import Channel from channels.handler import AsgiHandler -from channels.message import Message +from channels.tests import ChannelTestCase class FakeAsgiHandler(AsgiHandler): @@ -24,34 +23,27 @@ class FakeAsgiHandler(AsgiHandler): return self._response -class HandlerTests(SimpleTestCase): +class HandlerTests(ChannelTestCase): """ Tests that the handler works correctly and round-trips things into a correct response. """ - def setUp(self): - """ - Make an in memory channel layer for testing - """ - self.channel_layer = ChannelLayer() - self.make_message = lambda m, c: Message(m, c, self.channel_layer) - def test_basic(self): """ Tests a simple request """ # Make stub request and desired response - message = self.make_message({ + Channel("test").send({ "reply_channel": "test", "http_version": "1.1", "method": "GET", "path": b"/test/", - }, "test") + }) response = HttpResponse(b"Hi there!", content_type="text/plain") # Run the handler handler = FakeAsgiHandler(response) - reply_messages = list(handler(message)) + reply_messages = list(handler(self.get_next_message("test", require=True))) # Make sure we got the right number of messages self.assertEqual(len(reply_messages), 1) reply_message = reply_messages[0] @@ -69,16 +61,16 @@ class HandlerTests(SimpleTestCase): Tests a large response (will need chunking) """ # Make stub request and desired response - message = self.make_message({ + Channel("test").send({ "reply_channel": "test", "http_version": "1.1", "method": "GET", "path": b"/test/", - }, "test") + }) response = HttpResponse(b"Thefirstthirtybytesisrighthereandhereistherest") # Run the handler handler = FakeAsgiHandler(response) - reply_messages = list(handler(message)) + reply_messages = list(handler(self.get_next_message("test", require=True))) # Make sure we got the right number of messages self.assertEqual(len(reply_messages), 2) # Make sure the messages look correct diff --git a/channels/tests/test_request.py b/channels/tests/test_request.py index fe6f76c..7d1c4ec 100644 --- a/channels/tests/test_request.py +++ b/channels/tests/test_request.py @@ -1,36 +1,28 @@ from __future__ import unicode_literals -from django.test import SimpleTestCase from django.utils import six -from asgiref.inmemory import ChannelLayer +from channels import Channel +from channels.tests import ChannelTestCase from channels.handler import AsgiRequest -from channels.message import Message -class RequestTests(SimpleTestCase): +class RequestTests(ChannelTestCase): """ Tests that ASGI request handling correctly decodes HTTP requests. """ - def setUp(self): - """ - Make an in memory channel layer for testing - """ - self.channel_layer = ChannelLayer() - self.make_message = lambda m, c: Message(m, c, self.channel_layer) - def test_basic(self): """ Tests that the handler can decode the most basic request message, with all optional fields omitted. """ - message = self.make_message({ + Channel("test").send({ "reply_channel": "test-reply", "http_version": "1.1", "method": "GET", "path": b"/test/", - }, "test") - request = AsgiRequest(message) + }) + request = AsgiRequest(self.get_next_message("test")) self.assertEqual(request.path, "/test/") self.assertEqual(request.method, "GET") self.assertFalse(request.body) @@ -48,7 +40,7 @@ class RequestTests(SimpleTestCase): """ Tests a more fully-featured GET request """ - message = self.make_message({ + Channel("test").send({ "reply_channel": "test", "http_version": "1.1", "method": "GET", @@ -60,8 +52,8 @@ class RequestTests(SimpleTestCase): }, "client": ["10.0.0.1", 1234], "server": ["10.0.0.2", 80], - }, "test") - request = AsgiRequest(message) + }) + request = AsgiRequest(self.get_next_message("test")) self.assertEqual(request.path, "/test2/") self.assertEqual(request.method, "GET") self.assertFalse(request.body) @@ -81,7 +73,7 @@ class RequestTests(SimpleTestCase): """ Tests a POST body contained within a single message. """ - message = self.make_message({ + Channel("test").send({ "reply_channel": "test", "http_version": "1.1", "method": "POST", @@ -93,8 +85,8 @@ class RequestTests(SimpleTestCase): "content-type": b"application/x-www-form-urlencoded", "content-length": b"18", }, - }, "test") - request = AsgiRequest(message) + }) + request = AsgiRequest(self.get_next_message("test")) self.assertEqual(request.path, "/test2/") self.assertEqual(request.method, "POST") self.assertEqual(request.body, b"ponies=are+awesome") @@ -111,7 +103,7 @@ class RequestTests(SimpleTestCase): """ Tests a POST body across multiple messages (first part in 'body'). """ - message = self.make_message({ + Channel("test").send({ "reply_channel": "test", "http_version": "1.1", "method": "POST", @@ -123,15 +115,15 @@ class RequestTests(SimpleTestCase): "content-type": b"application/x-www-form-urlencoded", "content-length": b"21", }, - }, "test") - self.channel_layer.send("test-input", { + }) + Channel("test-input").send({ "content": b"re=fou", "more_content": True, }) - self.channel_layer.send("test-input", { + Channel("test-input").send({ "content": b"r+lights", }) - request = AsgiRequest(message) + request = AsgiRequest(self.get_next_message("test")) self.assertEqual(request.method, "POST") self.assertEqual(request.body, b"there_are=four+lights") self.assertEqual(request.META["CONTENT_TYPE"], "application/x-www-form-urlencoded") @@ -151,7 +143,7 @@ class RequestTests(SimpleTestCase): b'FAKEPDFBYTESGOHERE' + b'--BOUNDARY--' ) - message = self.make_message({ + Channel("test").send({ "reply_channel": "test", "http_version": "1.1", "method": "POST", @@ -161,15 +153,15 @@ class RequestTests(SimpleTestCase): "content-type": b"multipart/form-data; boundary=BOUNDARY", "content-length": six.text_type(len(body)).encode("ascii"), }, - }, "test") - self.channel_layer.send("test-input", { + }) + Channel("test-input").send({ "content": body[:20], "more_content": True, }) - self.channel_layer.send("test-input", { + Channel("test-input").send({ "content": body[20:], }) - request = AsgiRequest(message) + request = AsgiRequest(self.get_next_message("test")) self.assertEqual(request.method, "POST") self.assertEqual(len(request.body), len(body)) self.assertTrue(request.META["CONTENT_TYPE"].startswith("multipart/form-data")) @@ -181,7 +173,7 @@ class RequestTests(SimpleTestCase): """ Tests the body stream is emulated correctly. """ - message = self.make_message({ + Channel("test").send({ "reply_channel": "test", "http_version": "1.1", "method": "PUT", @@ -191,8 +183,8 @@ class RequestTests(SimpleTestCase): "host": b"example.com", "content-length": b"11", }, - }, "test") - request = AsgiRequest(message) + }) + request = AsgiRequest(self.get_next_message("test", require=True)) self.assertEqual(request.method, "PUT") self.assertEqual(request.read(3), b"one") self.assertEqual(request.read(), b"twothree")