Introduce ChannelTestCase to make testing easier

This commit is contained in:
Andrew Godwin 2016-04-03 18:32:42 +02:00
parent 3576267be2
commit 0e3c742a80
6 changed files with 113 additions and 50 deletions

View File

@ -56,6 +56,16 @@ class ChannelLayerManager(object):
def __contains__(self, key): def __contains__(self, key):
return key in self.configs return key in self.configs
def set(self, key, layer):
"""
Sets an alias to point to a new ChannelLayerWrapper instance, and
returns the old one that it replaced. Useful for swapping out the
backend during tests.
"""
old = self.backends.get(key, None)
self.backends[key] = layer
return old
class ChannelLayerWrapper(object): class ChannelLayerWrapper(object):
""" """

View File

@ -0,0 +1 @@
from .base import ChannelTestCase

61
channels/tests/base.py Normal file
View File

@ -0,0 +1,61 @@
from django.test import TestCase
from channels import DEFAULT_CHANNEL_LAYER
from channels.asgi import channel_layers, ChannelLayerWrapper
from channels.message import Message
from asgiref.inmemory import ChannelLayer as InMemoryChannelLayer
class ChannelTestCase(TestCase):
"""
TestCase subclass that provides easy methods for testing channels using
an in-memory backend to capture messages, and assertion methods to allow
checking of what was sent.
Inherits from TestCase, so provides per-test transactions as long as the
database backend supports it.
"""
# Customizable so users can test multi-layer setups
test_channel_aliases = [DEFAULT_CHANNEL_LAYER]
def setUp(self):
"""
Initialises in memory channel layer for the duration of the test
"""
super(ChannelTestCase, self).setUp()
self._old_layers = {}
for alias in self.test_channel_aliases:
# Swap in an in memory layer wrapper and keep the old one around
self._old_layers[alias] = channel_layers.set(
alias,
ChannelLayerWrapper(
InMemoryChannelLayer(),
alias,
channel_layers[alias].routing,
)
)
def tearDown(self):
"""
Undoes the channel rerouting
"""
for alias in self.test_channel_aliases:
# Swap in an in memory layer wrapper and keep the old one around
channel_layers.set(alias, self._old_layers[alias])
del self._old_layers
super(ChannelTestCase, self).tearDown()
def get_next_message(self, channel, alias=DEFAULT_CHANNEL_LAYER, require=False):
"""
Gets the next message that was sent to the channel during the test,
or None if no message is available.
If require is true, will fail the test if no message is received.
"""
recv_channel, content = channel_layers[alias].receive_many([channel])
if recv_channel is None:
if require:
self.fail("Expected a message on channel %s, got none" % channel)
else:
return None
return Message(content, recv_channel, channel_layers[alias])

View File

@ -6,4 +6,11 @@ DATABASES = {
} }
} }
CHANNEL_LAYERS = {
'default': {
'BACKEND': 'asgiref.inmemory.ChannelLayer',
'ROUTING': [],
},
}
MIDDLEWARE_CLASSES = [] MIDDLEWARE_CLASSES = []

View File

@ -1,10 +1,9 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from django.test import SimpleTestCase
from django.http import HttpResponse from django.http import HttpResponse
from asgiref.inmemory import ChannelLayer from channels import Channel
from channels.handler import AsgiHandler from channels.handler import AsgiHandler
from channels.message import Message from channels.tests import ChannelTestCase
class FakeAsgiHandler(AsgiHandler): class FakeAsgiHandler(AsgiHandler):
@ -24,34 +23,27 @@ class FakeAsgiHandler(AsgiHandler):
return self._response return self._response
class HandlerTests(SimpleTestCase): class HandlerTests(ChannelTestCase):
""" """
Tests that the handler works correctly and round-trips things into a Tests that the handler works correctly and round-trips things into a
correct response. correct response.
""" """
def setUp(self):
"""
Make an in memory channel layer for testing
"""
self.channel_layer = ChannelLayer()
self.make_message = lambda m, c: Message(m, c, self.channel_layer)
def test_basic(self): def test_basic(self):
""" """
Tests a simple request Tests a simple request
""" """
# Make stub request and desired response # Make stub request and desired response
message = self.make_message({ Channel("test").send({
"reply_channel": "test", "reply_channel": "test",
"http_version": "1.1", "http_version": "1.1",
"method": "GET", "method": "GET",
"path": b"/test/", "path": b"/test/",
}, "test") })
response = HttpResponse(b"Hi there!", content_type="text/plain") response = HttpResponse(b"Hi there!", content_type="text/plain")
# Run the handler # Run the handler
handler = FakeAsgiHandler(response) handler = FakeAsgiHandler(response)
reply_messages = list(handler(message)) reply_messages = list(handler(self.get_next_message("test", require=True)))
# Make sure we got the right number of messages # Make sure we got the right number of messages
self.assertEqual(len(reply_messages), 1) self.assertEqual(len(reply_messages), 1)
reply_message = reply_messages[0] reply_message = reply_messages[0]
@ -69,16 +61,16 @@ class HandlerTests(SimpleTestCase):
Tests a large response (will need chunking) Tests a large response (will need chunking)
""" """
# Make stub request and desired response # Make stub request and desired response
message = self.make_message({ Channel("test").send({
"reply_channel": "test", "reply_channel": "test",
"http_version": "1.1", "http_version": "1.1",
"method": "GET", "method": "GET",
"path": b"/test/", "path": b"/test/",
}, "test") })
response = HttpResponse(b"Thefirstthirtybytesisrighthereandhereistherest") response = HttpResponse(b"Thefirstthirtybytesisrighthereandhereistherest")
# Run the handler # Run the handler
handler = FakeAsgiHandler(response) handler = FakeAsgiHandler(response)
reply_messages = list(handler(message)) reply_messages = list(handler(self.get_next_message("test", require=True)))
# Make sure we got the right number of messages # Make sure we got the right number of messages
self.assertEqual(len(reply_messages), 2) self.assertEqual(len(reply_messages), 2)
# Make sure the messages look correct # Make sure the messages look correct

View File

@ -1,36 +1,28 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from django.test import SimpleTestCase
from django.utils import six from django.utils import six
from asgiref.inmemory import ChannelLayer from channels import Channel
from channels.tests import ChannelTestCase
from channels.handler import AsgiRequest from channels.handler import AsgiRequest
from channels.message import Message
class RequestTests(SimpleTestCase): class RequestTests(ChannelTestCase):
""" """
Tests that ASGI request handling correctly decodes HTTP requests. Tests that ASGI request handling correctly decodes HTTP requests.
""" """
def setUp(self):
"""
Make an in memory channel layer for testing
"""
self.channel_layer = ChannelLayer()
self.make_message = lambda m, c: Message(m, c, self.channel_layer)
def test_basic(self): def test_basic(self):
""" """
Tests that the handler can decode the most basic request message, Tests that the handler can decode the most basic request message,
with all optional fields omitted. with all optional fields omitted.
""" """
message = self.make_message({ Channel("test").send({
"reply_channel": "test-reply", "reply_channel": "test-reply",
"http_version": "1.1", "http_version": "1.1",
"method": "GET", "method": "GET",
"path": b"/test/", "path": b"/test/",
}, "test") })
request = AsgiRequest(message) request = AsgiRequest(self.get_next_message("test"))
self.assertEqual(request.path, "/test/") self.assertEqual(request.path, "/test/")
self.assertEqual(request.method, "GET") self.assertEqual(request.method, "GET")
self.assertFalse(request.body) self.assertFalse(request.body)
@ -48,7 +40,7 @@ class RequestTests(SimpleTestCase):
""" """
Tests a more fully-featured GET request Tests a more fully-featured GET request
""" """
message = self.make_message({ Channel("test").send({
"reply_channel": "test", "reply_channel": "test",
"http_version": "1.1", "http_version": "1.1",
"method": "GET", "method": "GET",
@ -60,8 +52,8 @@ class RequestTests(SimpleTestCase):
}, },
"client": ["10.0.0.1", 1234], "client": ["10.0.0.1", 1234],
"server": ["10.0.0.2", 80], "server": ["10.0.0.2", 80],
}, "test") })
request = AsgiRequest(message) request = AsgiRequest(self.get_next_message("test"))
self.assertEqual(request.path, "/test2/") self.assertEqual(request.path, "/test2/")
self.assertEqual(request.method, "GET") self.assertEqual(request.method, "GET")
self.assertFalse(request.body) self.assertFalse(request.body)
@ -81,7 +73,7 @@ class RequestTests(SimpleTestCase):
""" """
Tests a POST body contained within a single message. Tests a POST body contained within a single message.
""" """
message = self.make_message({ Channel("test").send({
"reply_channel": "test", "reply_channel": "test",
"http_version": "1.1", "http_version": "1.1",
"method": "POST", "method": "POST",
@ -93,8 +85,8 @@ class RequestTests(SimpleTestCase):
"content-type": b"application/x-www-form-urlencoded", "content-type": b"application/x-www-form-urlencoded",
"content-length": b"18", "content-length": b"18",
}, },
}, "test") })
request = AsgiRequest(message) request = AsgiRequest(self.get_next_message("test"))
self.assertEqual(request.path, "/test2/") self.assertEqual(request.path, "/test2/")
self.assertEqual(request.method, "POST") self.assertEqual(request.method, "POST")
self.assertEqual(request.body, b"ponies=are+awesome") self.assertEqual(request.body, b"ponies=are+awesome")
@ -111,7 +103,7 @@ class RequestTests(SimpleTestCase):
""" """
Tests a POST body across multiple messages (first part in 'body'). Tests a POST body across multiple messages (first part in 'body').
""" """
message = self.make_message({ Channel("test").send({
"reply_channel": "test", "reply_channel": "test",
"http_version": "1.1", "http_version": "1.1",
"method": "POST", "method": "POST",
@ -123,15 +115,15 @@ class RequestTests(SimpleTestCase):
"content-type": b"application/x-www-form-urlencoded", "content-type": b"application/x-www-form-urlencoded",
"content-length": b"21", "content-length": b"21",
}, },
}, "test") })
self.channel_layer.send("test-input", { Channel("test-input").send({
"content": b"re=fou", "content": b"re=fou",
"more_content": True, "more_content": True,
}) })
self.channel_layer.send("test-input", { Channel("test-input").send({
"content": b"r+lights", "content": b"r+lights",
}) })
request = AsgiRequest(message) request = AsgiRequest(self.get_next_message("test"))
self.assertEqual(request.method, "POST") self.assertEqual(request.method, "POST")
self.assertEqual(request.body, b"there_are=four+lights") self.assertEqual(request.body, b"there_are=four+lights")
self.assertEqual(request.META["CONTENT_TYPE"], "application/x-www-form-urlencoded") self.assertEqual(request.META["CONTENT_TYPE"], "application/x-www-form-urlencoded")
@ -151,7 +143,7 @@ class RequestTests(SimpleTestCase):
b'FAKEPDFBYTESGOHERE' + b'FAKEPDFBYTESGOHERE' +
b'--BOUNDARY--' b'--BOUNDARY--'
) )
message = self.make_message({ Channel("test").send({
"reply_channel": "test", "reply_channel": "test",
"http_version": "1.1", "http_version": "1.1",
"method": "POST", "method": "POST",
@ -161,15 +153,15 @@ class RequestTests(SimpleTestCase):
"content-type": b"multipart/form-data; boundary=BOUNDARY", "content-type": b"multipart/form-data; boundary=BOUNDARY",
"content-length": six.text_type(len(body)).encode("ascii"), "content-length": six.text_type(len(body)).encode("ascii"),
}, },
}, "test") })
self.channel_layer.send("test-input", { Channel("test-input").send({
"content": body[:20], "content": body[:20],
"more_content": True, "more_content": True,
}) })
self.channel_layer.send("test-input", { Channel("test-input").send({
"content": body[20:], "content": body[20:],
}) })
request = AsgiRequest(message) request = AsgiRequest(self.get_next_message("test"))
self.assertEqual(request.method, "POST") self.assertEqual(request.method, "POST")
self.assertEqual(len(request.body), len(body)) self.assertEqual(len(request.body), len(body))
self.assertTrue(request.META["CONTENT_TYPE"].startswith("multipart/form-data")) self.assertTrue(request.META["CONTENT_TYPE"].startswith("multipart/form-data"))
@ -181,7 +173,7 @@ class RequestTests(SimpleTestCase):
""" """
Tests the body stream is emulated correctly. Tests the body stream is emulated correctly.
""" """
message = self.make_message({ Channel("test").send({
"reply_channel": "test", "reply_channel": "test",
"http_version": "1.1", "http_version": "1.1",
"method": "PUT", "method": "PUT",
@ -191,8 +183,8 @@ class RequestTests(SimpleTestCase):
"host": b"example.com", "host": b"example.com",
"content-length": b"11", "content-length": b"11",
}, },
}, "test") })
request = AsgiRequest(message) request = AsgiRequest(self.get_next_message("test", require=True))
self.assertEqual(request.method, "PUT") self.assertEqual(request.method, "PUT")
self.assertEqual(request.read(3), b"one") self.assertEqual(request.read(3), b"one")
self.assertEqual(request.read(), b"twothree") self.assertEqual(request.read(), b"twothree")