Fix tests for new non-immediate sending

This commit is contained in:
Andrew Godwin 2016-10-05 15:32:37 -07:00
parent 0b8b199212
commit 0ed04a9c06
7 changed files with 38 additions and 81 deletions

View File

@ -11,6 +11,7 @@ from ..channel import Group
from ..routing import Router, include from ..routing import Router, include
from ..asgi import channel_layers, ChannelLayerWrapper from ..asgi import channel_layers, ChannelLayerWrapper
from ..message import Message from ..message import Message
from ..signals import consumer_finished, consumer_started
from asgiref.inmemory import ChannelLayer as InMemoryChannelLayer from asgiref.inmemory import ChannelLayer as InMemoryChannelLayer
@ -121,7 +122,11 @@ class Client(object):
match = self.channel_layer.router.match(message) match = self.channel_layer.router.match(message)
if match: if match:
consumer, kwargs = match consumer, kwargs = match
return consumer(message, **kwargs) try:
consumer_started.send(sender=self.__class__)
return consumer(message, **kwargs)
finally:
consumer_finished.send(sender=self.__class__)
elif fail_on_none: elif fail_on_none:
raise AssertionError("Can't find consumer for message %s" % message) raise AssertionError("Can't find consumer for message %s" % message)
elif fail_on_none: elif fail_on_none:

View File

@ -6,6 +6,7 @@ from channels.binding.base import CREATE, UPDATE, DELETE
from channels.binding.websockets import WebsocketBinding from channels.binding.websockets import WebsocketBinding
from channels.generic.websockets import WebsocketDemultiplexer from channels.generic.websockets import WebsocketDemultiplexer
from channels.tests import ChannelTestCase, apply_routes, HttpClient from channels.tests import ChannelTestCase, apply_routes, HttpClient
from channels.signals import consumer_finished
from channels import route, Group from channels import route, Group
User = get_user_model() User = get_user_model()
@ -33,6 +34,7 @@ class TestsBinding(ChannelTestCase):
user = User.objects.create(username='test', email='test@test.com') user = User.objects.create(username='test', email='test@test.com')
consumer_finished.send(sender=None)
received = client.receive() received = client.receive()
self.assertTrue('payload' in received) self.assertTrue('payload' in received)
self.assertTrue('action' in received['payload']) self.assertTrue('action' in received['payload'])
@ -69,7 +71,9 @@ class TestsBinding(ChannelTestCase):
def has_permission(self, user, action, pk): def has_permission(self, user, action, pk):
return True return True
# Make model and clear out pending sends
user = User.objects.create(username='test', email='test@test.com') user = User.objects.create(username='test', email='test@test.com')
consumer_finished.send(sender=None)
with apply_routes([route('test', TestBinding.consumer)]): with apply_routes([route('test', TestBinding.consumer)]):
client = HttpClient() client = HttpClient()
@ -78,6 +82,7 @@ class TestsBinding(ChannelTestCase):
user.username = 'test_new' user.username = 'test_new'
user.save() user.save()
consumer_finished.send(sender=None)
received = client.receive() received = client.receive()
self.assertTrue('payload' in received) self.assertTrue('payload' in received)
self.assertTrue('action' in received['payload']) self.assertTrue('action' in received['payload'])
@ -114,7 +119,9 @@ class TestsBinding(ChannelTestCase):
def has_permission(self, user, action, pk): def has_permission(self, user, action, pk):
return True return True
# Make model and clear out pending sends
user = User.objects.create(username='test', email='test@test.com') user = User.objects.create(username='test', email='test@test.com')
consumer_finished.send(sender=None)
with apply_routes([route('test', TestBinding.consumer)]): with apply_routes([route('test', TestBinding.consumer)]):
client = HttpClient() client = HttpClient()
@ -122,6 +129,7 @@ class TestsBinding(ChannelTestCase):
user.delete() user.delete()
consumer_finished.send(sender=None)
received = client.receive() received = client.receive()
self.assertTrue('payload' in received) self.assertTrue('payload' in received)
self.assertTrue('action' in received['payload']) self.assertTrue('action' in received['payload'])
@ -151,7 +159,7 @@ class TestsBinding(ChannelTestCase):
client.send_and_consume('websocket.connect', path='/') client.send_and_consume('websocket.connect', path='/')
# assert in group # assert in group
Group('inbound').send({'text': json.dumps({'test': 'yes'})}) Group('inbound').send({'text': json.dumps({'test': 'yes'})}, immediately=True)
self.assertEqual(client.receive(), {'test': 'yes'}) self.assertEqual(client.receive(), {'test': 'yes'})
# assert that demultiplexer stream message # assert that demultiplexer stream message

View File

@ -70,7 +70,7 @@ class GenericTests(ChannelTestCase):
def test_websockets_decorators(self): def test_websockets_decorators(self):
class WebsocketConsumer(websockets.WebsocketConsumer): class WebsocketConsumer(websockets.WebsocketConsumer):
slight_ordering = True strict_ordering = True
def connect(self, message, **kwargs): def connect(self, message, **kwargs):
self.order = message['order'] self.order = message['order']
@ -92,7 +92,7 @@ class GenericTests(ChannelTestCase):
self.send(text=message.get('order')) self.send(text=message.get('order'))
routes = [ routes = [
WebsocketConsumer.as_route(attrs={'slight_ordering': True}, path='^/path$'), WebsocketConsumer.as_route(attrs={"strict_ordering": True}, path='^/path$'),
WebsocketConsumer.as_route(path='^/path/2$'), WebsocketConsumer.as_route(path='^/path/2$'),
] ]

View File

@ -13,6 +13,7 @@ from six import BytesIO
from channels import Channel from channels import Channel
from channels.handler import AsgiHandler from channels.handler import AsgiHandler
from channels.tests import ChannelTestCase from channels.tests import ChannelTestCase
from channels.signals import consumer_finished
class FakeAsgiHandler(AsgiHandler): class FakeAsgiHandler(AsgiHandler):
@ -26,6 +27,7 @@ class FakeAsgiHandler(AsgiHandler):
def __init__(self, response): def __init__(self, response):
assert isinstance(response, (HttpResponse, StreamingHttpResponse)) assert isinstance(response, (HttpResponse, StreamingHttpResponse))
self._response = response self._response = response
consumer_finished.send(sender=self.__class__)
super(FakeAsgiHandler, self).__init__() super(FakeAsgiHandler, self).__init__()
def get_response(self, request): def get_response(self, request):

View File

@ -22,7 +22,7 @@ class RequestTests(ChannelTestCase):
"http_version": "1.1", "http_version": "1.1",
"method": "GET", "method": "GET",
"path": "/test/", "path": "/test/",
}) }, immediately=True)
request = AsgiRequest(self.get_next_message("test")) request = AsgiRequest(self.get_next_message("test"))
self.assertEqual(request.path, "/test/") self.assertEqual(request.path, "/test/")
self.assertEqual(request.method, "GET") self.assertEqual(request.method, "GET")
@ -53,7 +53,7 @@ class RequestTests(ChannelTestCase):
}, },
"client": ["10.0.0.1", 1234], "client": ["10.0.0.1", 1234],
"server": ["10.0.0.2", 80], "server": ["10.0.0.2", 80],
}) }, immediately=True)
request = AsgiRequest(self.get_next_message("test")) request = AsgiRequest(self.get_next_message("test"))
self.assertEqual(request.path, "/test2/") self.assertEqual(request.path, "/test2/")
self.assertEqual(request.method, "GET") self.assertEqual(request.method, "GET")
@ -86,7 +86,7 @@ class RequestTests(ChannelTestCase):
"content-type": b"application/x-www-form-urlencoded", "content-type": b"application/x-www-form-urlencoded",
"content-length": b"18", "content-length": b"18",
}, },
}) }, immediately=True)
request = AsgiRequest(self.get_next_message("test")) request = AsgiRequest(self.get_next_message("test"))
self.assertEqual(request.path, "/test2/") self.assertEqual(request.path, "/test2/")
self.assertEqual(request.method, "POST") self.assertEqual(request.method, "POST")
@ -116,14 +116,14 @@ class RequestTests(ChannelTestCase):
"content-type": b"application/x-www-form-urlencoded", "content-type": b"application/x-www-form-urlencoded",
"content-length": b"21", "content-length": b"21",
}, },
}) }, immediately=True)
Channel("test-input").send({ Channel("test-input").send({
"content": b"re=fou", "content": b"re=fou",
"more_content": True, "more_content": True,
}) }, immediately=True)
Channel("test-input").send({ Channel("test-input").send({
"content": b"r+lights", "content": b"r+lights",
}) }, immediately=True)
request = AsgiRequest(self.get_next_message("test")) request = AsgiRequest(self.get_next_message("test"))
self.assertEqual(request.method, "POST") self.assertEqual(request.method, "POST")
self.assertEqual(request.body, b"there_are=four+lights") self.assertEqual(request.body, b"there_are=four+lights")
@ -154,14 +154,14 @@ class RequestTests(ChannelTestCase):
"content-type": b"multipart/form-data; boundary=BOUNDARY", "content-type": b"multipart/form-data; boundary=BOUNDARY",
"content-length": six.text_type(len(body)).encode("ascii"), "content-length": six.text_type(len(body)).encode("ascii"),
}, },
}) }, immediately=True)
Channel("test-input").send({ Channel("test-input").send({
"content": body[:20], "content": body[:20],
"more_content": True, "more_content": True,
}) }, immediately=True)
Channel("test-input").send({ Channel("test-input").send({
"content": body[20:], "content": body[20:],
}) }, immediately=True)
request = AsgiRequest(self.get_next_message("test")) request = AsgiRequest(self.get_next_message("test"))
self.assertEqual(request.method, "POST") self.assertEqual(request.method, "POST")
self.assertEqual(len(request.body), len(body)) self.assertEqual(len(request.body), len(body))
@ -184,7 +184,7 @@ class RequestTests(ChannelTestCase):
"host": b"example.com", "host": b"example.com",
"content-length": b"11", "content-length": b"11",
}, },
}) }, immediately=True)
request = AsgiRequest(self.get_next_message("test", require=True)) request = AsgiRequest(self.get_next_message("test", require=True))
self.assertEqual(request.method, "PUT") self.assertEqual(request.method, "PUT")
self.assertEqual(request.read(3), b"one") self.assertEqual(request.read(3), b"one")
@ -206,12 +206,12 @@ class RequestTests(ChannelTestCase):
"content-type": b"application/x-www-form-urlencoded", "content-type": b"application/x-www-form-urlencoded",
"content-length": b"21", "content-length": b"21",
}, },
}) }, immediately=True)
# Say there's more content, but never provide it! Muahahaha! # Say there's more content, but never provide it! Muahahaha!
Channel("test-input").send({ Channel("test-input").send({
"content": b"re=fou", "content": b"re=fou",
"more_content": True, "more_content": True,
}) }, immediately=True)
class VeryImpatientRequest(AsgiRequest): class VeryImpatientRequest(AsgiRequest):
body_receive_timeout = 0 body_receive_timeout = 0
@ -235,9 +235,9 @@ class RequestTests(ChannelTestCase):
"content-type": b"application/x-www-form-urlencoded", "content-type": b"application/x-www-form-urlencoded",
"content-length": b"21", "content-length": b"21",
}, },
}) }, immediately=True)
Channel("test-input").send({ Channel("test-input").send({
"closed": True, "closed": True,
}) }, immediately=True)
with self.assertRaises(RequestAborted): with self.assertRaises(RequestAborted):
AsgiRequest(self.get_next_message("test")) AsgiRequest(self.get_next_message("test"))

View File

@ -141,65 +141,7 @@ class SessionTests(ChannelTestCase):
# It should hydrate the http_session # It should hydrate the http_session
self.assertEqual(message2.http_session.session_key, session.session_key) self.assertEqual(message2.http_session.session_key, session.session_key)
def test_enforce_ordering_slight(self): def test_enforce_ordering(self):
"""
Tests that slight mode of enforce_ordering works
"""
# Construct messages to send
message0 = Message(
{"reply_channel": "test-reply-a", "order": 0},
"websocket.connect",
channel_layers[DEFAULT_CHANNEL_LAYER]
)
message1 = Message(
{"reply_channel": "test-reply-a", "order": 1},
"websocket.receive",
channel_layers[DEFAULT_CHANNEL_LAYER]
)
message2 = Message(
{"reply_channel": "test-reply-a", "order": 2},
"websocket.receive",
channel_layers[DEFAULT_CHANNEL_LAYER]
)
# Run them in an acceptable slight order
@enforce_ordering(slight=True)
def inner(message):
pass
inner(message0)
inner(message2)
inner(message1)
# Ensure wait channel is empty
wait_channel = "__wait__.%s" % "test-reply-a"
next_message = self.get_next_message(wait_channel)
self.assertEqual(next_message, None)
def test_enforce_ordering_slight_fail(self):
"""
Tests that slight mode of enforce_ordering fails on bad ordering
"""
# Construct messages to send
message2 = Message(
{"reply_channel": "test-reply-e", "order": 2},
"websocket.receive",
channel_layers[DEFAULT_CHANNEL_LAYER]
)
# Run them in an acceptable strict order
@enforce_ordering(slight=True)
def inner(message):
pass
inner(message2)
# Ensure wait channel is not empty
wait_channel = "__wait__.%s" % "test-reply-e"
next_message = self.get_next_message(wait_channel)
self.assertNotEqual(next_message, None)
def test_enforce_ordering_strict(self):
""" """
Tests that strict mode of enforce_ordering works Tests that strict mode of enforce_ordering works
""" """
@ -234,7 +176,7 @@ class SessionTests(ChannelTestCase):
next_message = self.get_next_message(wait_channel) next_message = self.get_next_message(wait_channel)
self.assertEqual(next_message, None) self.assertEqual(next_message, None)
def test_enforce_ordering_strict_fail(self): def test_enforce_ordering_fail(self):
""" """
Tests that strict mode of enforce_ordering fails on bad ordering Tests that strict mode of enforce_ordering fails on bad ordering
""" """
@ -273,7 +215,7 @@ class SessionTests(ChannelTestCase):
channel_layers[DEFAULT_CHANNEL_LAYER] channel_layers[DEFAULT_CHANNEL_LAYER]
) )
@enforce_ordering(slight=True) @enforce_ordering
def inner(message): def inner(message):
pass pass

View File

@ -68,7 +68,7 @@ class WorkerTests(ChannelTestCase):
if _consumer._call_count == 1: if _consumer._call_count == 1:
raise ConsumeLater() raise ConsumeLater()
Channel('test').send({'test': 'test'}) Channel('test').send({'test': 'test'}, immediately=True)
channel_layer = channel_layers[DEFAULT_CHANNEL_LAYER] channel_layer = channel_layers[DEFAULT_CHANNEL_LAYER]
channel_layer.router.add_route(route('test', _consumer)) channel_layer.router.add_route(route('test', _consumer))
old_send = channel_layer.send old_send = channel_layer.send
@ -83,7 +83,7 @@ class WorkerTests(ChannelTestCase):
def test_normal_run(self): def test_normal_run(self):
consumer = mock.Mock() consumer = mock.Mock()
Channel('test').send({'test': 'test'}) Channel('test').send({'test': 'test'}, immediately=True)
channel_layer = channel_layers[DEFAULT_CHANNEL_LAYER] channel_layer = channel_layers[DEFAULT_CHANNEL_LAYER]
channel_layer.router.add_route(route('test', consumer)) channel_layer.router.add_route(route('test', consumer))
old_send = channel_layer.send old_send = channel_layer.send