Introduce ChannelTestCase to make testing easier

This commit is contained in:
Andrew Godwin 2016-04-03 18:32:42 +02:00
parent 3576267be2
commit 0e3c742a80
6 changed files with 113 additions and 50 deletions

View File

@ -56,6 +56,16 @@ class ChannelLayerManager(object):
def __contains__(self, key):
return key in self.configs
def set(self, key, layer):
"""
Sets an alias to point to a new ChannelLayerWrapper instance, and
returns the old one that it replaced. Useful for swapping out the
backend during tests.
"""
old = self.backends.get(key, None)
self.backends[key] = layer
return old
class ChannelLayerWrapper(object):
"""

View File

@ -0,0 +1 @@
from .base import ChannelTestCase

61
channels/tests/base.py Normal file
View File

@ -0,0 +1,61 @@
from django.test import TestCase
from channels import DEFAULT_CHANNEL_LAYER
from channels.asgi import channel_layers, ChannelLayerWrapper
from channels.message import Message
from asgiref.inmemory import ChannelLayer as InMemoryChannelLayer
class ChannelTestCase(TestCase):
"""
TestCase subclass that provides easy methods for testing channels using
an in-memory backend to capture messages, and assertion methods to allow
checking of what was sent.
Inherits from TestCase, so provides per-test transactions as long as the
database backend supports it.
"""
# Customizable so users can test multi-layer setups
test_channel_aliases = [DEFAULT_CHANNEL_LAYER]
def setUp(self):
"""
Initialises in memory channel layer for the duration of the test
"""
super(ChannelTestCase, self).setUp()
self._old_layers = {}
for alias in self.test_channel_aliases:
# Swap in an in memory layer wrapper and keep the old one around
self._old_layers[alias] = channel_layers.set(
alias,
ChannelLayerWrapper(
InMemoryChannelLayer(),
alias,
channel_layers[alias].routing,
)
)
def tearDown(self):
"""
Undoes the channel rerouting
"""
for alias in self.test_channel_aliases:
# Swap in an in memory layer wrapper and keep the old one around
channel_layers.set(alias, self._old_layers[alias])
del self._old_layers
super(ChannelTestCase, self).tearDown()
def get_next_message(self, channel, alias=DEFAULT_CHANNEL_LAYER, require=False):
"""
Gets the next message that was sent to the channel during the test,
or None if no message is available.
If require is true, will fail the test if no message is received.
"""
recv_channel, content = channel_layers[alias].receive_many([channel])
if recv_channel is None:
if require:
self.fail("Expected a message on channel %s, got none" % channel)
else:
return None
return Message(content, recv_channel, channel_layers[alias])

View File

@ -6,4 +6,11 @@ DATABASES = {
}
}
CHANNEL_LAYERS = {
'default': {
'BACKEND': 'asgiref.inmemory.ChannelLayer',
'ROUTING': [],
},
}
MIDDLEWARE_CLASSES = []

View File

@ -1,10 +1,9 @@
from __future__ import unicode_literals
from django.test import SimpleTestCase
from django.http import HttpResponse
from asgiref.inmemory import ChannelLayer
from channels import Channel
from channels.handler import AsgiHandler
from channels.message import Message
from channels.tests import ChannelTestCase
class FakeAsgiHandler(AsgiHandler):
@ -24,34 +23,27 @@ class FakeAsgiHandler(AsgiHandler):
return self._response
class HandlerTests(SimpleTestCase):
class HandlerTests(ChannelTestCase):
"""
Tests that the handler works correctly and round-trips things into a
correct response.
"""
def setUp(self):
"""
Make an in memory channel layer for testing
"""
self.channel_layer = ChannelLayer()
self.make_message = lambda m, c: Message(m, c, self.channel_layer)
def test_basic(self):
"""
Tests a simple request
"""
# Make stub request and desired response
message = self.make_message({
Channel("test").send({
"reply_channel": "test",
"http_version": "1.1",
"method": "GET",
"path": b"/test/",
}, "test")
})
response = HttpResponse(b"Hi there!", content_type="text/plain")
# Run the handler
handler = FakeAsgiHandler(response)
reply_messages = list(handler(message))
reply_messages = list(handler(self.get_next_message("test", require=True)))
# Make sure we got the right number of messages
self.assertEqual(len(reply_messages), 1)
reply_message = reply_messages[0]
@ -69,16 +61,16 @@ class HandlerTests(SimpleTestCase):
Tests a large response (will need chunking)
"""
# Make stub request and desired response
message = self.make_message({
Channel("test").send({
"reply_channel": "test",
"http_version": "1.1",
"method": "GET",
"path": b"/test/",
}, "test")
})
response = HttpResponse(b"Thefirstthirtybytesisrighthereandhereistherest")
# Run the handler
handler = FakeAsgiHandler(response)
reply_messages = list(handler(message))
reply_messages = list(handler(self.get_next_message("test", require=True)))
# Make sure we got the right number of messages
self.assertEqual(len(reply_messages), 2)
# Make sure the messages look correct

View File

@ -1,36 +1,28 @@
from __future__ import unicode_literals
from django.test import SimpleTestCase
from django.utils import six
from asgiref.inmemory import ChannelLayer
from channels import Channel
from channels.tests import ChannelTestCase
from channels.handler import AsgiRequest
from channels.message import Message
class RequestTests(SimpleTestCase):
class RequestTests(ChannelTestCase):
"""
Tests that ASGI request handling correctly decodes HTTP requests.
"""
def setUp(self):
"""
Make an in memory channel layer for testing
"""
self.channel_layer = ChannelLayer()
self.make_message = lambda m, c: Message(m, c, self.channel_layer)
def test_basic(self):
"""
Tests that the handler can decode the most basic request message,
with all optional fields omitted.
"""
message = self.make_message({
Channel("test").send({
"reply_channel": "test-reply",
"http_version": "1.1",
"method": "GET",
"path": b"/test/",
}, "test")
request = AsgiRequest(message)
})
request = AsgiRequest(self.get_next_message("test"))
self.assertEqual(request.path, "/test/")
self.assertEqual(request.method, "GET")
self.assertFalse(request.body)
@ -48,7 +40,7 @@ class RequestTests(SimpleTestCase):
"""
Tests a more fully-featured GET request
"""
message = self.make_message({
Channel("test").send({
"reply_channel": "test",
"http_version": "1.1",
"method": "GET",
@ -60,8 +52,8 @@ class RequestTests(SimpleTestCase):
},
"client": ["10.0.0.1", 1234],
"server": ["10.0.0.2", 80],
}, "test")
request = AsgiRequest(message)
})
request = AsgiRequest(self.get_next_message("test"))
self.assertEqual(request.path, "/test2/")
self.assertEqual(request.method, "GET")
self.assertFalse(request.body)
@ -81,7 +73,7 @@ class RequestTests(SimpleTestCase):
"""
Tests a POST body contained within a single message.
"""
message = self.make_message({
Channel("test").send({
"reply_channel": "test",
"http_version": "1.1",
"method": "POST",
@ -93,8 +85,8 @@ class RequestTests(SimpleTestCase):
"content-type": b"application/x-www-form-urlencoded",
"content-length": b"18",
},
}, "test")
request = AsgiRequest(message)
})
request = AsgiRequest(self.get_next_message("test"))
self.assertEqual(request.path, "/test2/")
self.assertEqual(request.method, "POST")
self.assertEqual(request.body, b"ponies=are+awesome")
@ -111,7 +103,7 @@ class RequestTests(SimpleTestCase):
"""
Tests a POST body across multiple messages (first part in 'body').
"""
message = self.make_message({
Channel("test").send({
"reply_channel": "test",
"http_version": "1.1",
"method": "POST",
@ -123,15 +115,15 @@ class RequestTests(SimpleTestCase):
"content-type": b"application/x-www-form-urlencoded",
"content-length": b"21",
},
}, "test")
self.channel_layer.send("test-input", {
})
Channel("test-input").send({
"content": b"re=fou",
"more_content": True,
})
self.channel_layer.send("test-input", {
Channel("test-input").send({
"content": b"r+lights",
})
request = AsgiRequest(message)
request = AsgiRequest(self.get_next_message("test"))
self.assertEqual(request.method, "POST")
self.assertEqual(request.body, b"there_are=four+lights")
self.assertEqual(request.META["CONTENT_TYPE"], "application/x-www-form-urlencoded")
@ -151,7 +143,7 @@ class RequestTests(SimpleTestCase):
b'FAKEPDFBYTESGOHERE' +
b'--BOUNDARY--'
)
message = self.make_message({
Channel("test").send({
"reply_channel": "test",
"http_version": "1.1",
"method": "POST",
@ -161,15 +153,15 @@ class RequestTests(SimpleTestCase):
"content-type": b"multipart/form-data; boundary=BOUNDARY",
"content-length": six.text_type(len(body)).encode("ascii"),
},
}, "test")
self.channel_layer.send("test-input", {
})
Channel("test-input").send({
"content": body[:20],
"more_content": True,
})
self.channel_layer.send("test-input", {
Channel("test-input").send({
"content": body[20:],
})
request = AsgiRequest(message)
request = AsgiRequest(self.get_next_message("test"))
self.assertEqual(request.method, "POST")
self.assertEqual(len(request.body), len(body))
self.assertTrue(request.META["CONTENT_TYPE"].startswith("multipart/form-data"))
@ -181,7 +173,7 @@ class RequestTests(SimpleTestCase):
"""
Tests the body stream is emulated correctly.
"""
message = self.make_message({
Channel("test").send({
"reply_channel": "test",
"http_version": "1.1",
"method": "PUT",
@ -191,8 +183,8 @@ class RequestTests(SimpleTestCase):
"host": b"example.com",
"content-length": b"11",
},
}, "test")
request = AsgiRequest(message)
})
request = AsgiRequest(self.get_next_message("test", require=True))
self.assertEqual(request.method, "PUT")
self.assertEqual(request.read(3), b"one")
self.assertEqual(request.read(), b"twothree")