From 8a93dfc4019f71c88bb5c6c256af35adff275fbb Mon Sep 17 00:00:00 2001 From: Krukov D Date: Mon, 9 Jan 2017 21:08:00 +0300 Subject: [PATCH] Accept Connection at WebsocketConsumer (#467) * Added accept at default behavior for websocket generic cbv and pass message instead of dict * Fix flake8 * Use HttpClient Instead of Client * Fix lsort --- channels/binding/websockets.py | 11 +++--- channels/generic/websockets.py | 3 +- channels/message.py | 1 + channels/tests/http.py | 17 ++++++-- channels/tests/test_generic.py | 71 +++++++++++++--------------------- 5 files changed, 47 insertions(+), 56 deletions(-) diff --git a/channels/binding/websockets.py b/channels/binding/websockets.py index cd910ca..6017150 100644 --- a/channels/binding/websockets.py +++ b/channels/binding/websockets.py @@ -87,18 +87,17 @@ class WebsocketBinding(Binding): # Only allow received packets through further. if message.channel.name != "websocket.receive": return - # Call superclass, unpacking the payload in the process - payload = json.loads(message['text']) - super(WebsocketBinding, cls).trigger_inbound(payload, **kwargs) + super(WebsocketBinding, cls).trigger_inbound(message, **kwargs) def deserialize(self, message): """ You must hook this up behind a Deserializer, so we expect the JSON already dealt with. """ - action = message['action'] - pk = message.get('pk', None) - data = message.get('data', None) + body = json.loads(message['text']) + action = body['action'] + pk = body.get('pk', None) + data = body.get('data', None) return action, pk, data def _hydrate(self, pk, data): diff --git a/channels/generic/websockets.py b/channels/generic/websockets.py index 32ebe68..555b165 100644 --- a/channels/generic/websockets.py +++ b/channels/generic/websockets.py @@ -72,7 +72,7 @@ class WebsocketConsumer(BaseConsumer): """ Called when a WebSocket connection is opened. """ - pass + self.message.reply_channel.send({"accept": True}) def raw_receive(self, message, **kwargs): """ @@ -220,6 +220,7 @@ class WebsocketDemultiplexer(JsonWebsocketConsumer): def connect(self, message, **kwargs): """Forward connection to all consumers.""" + self.message.reply_channel.send({"accept": True}) for stream, consumer in self.consumers.items(): kwargs['multiplexer'] = WebsocketMultiplexer(stream, self.message.reply_channel) consumer(message, **kwargs) diff --git a/channels/message.py b/channels/message.py index 97e67a3..6a4d3f3 100644 --- a/channels/message.py +++ b/channels/message.py @@ -81,5 +81,6 @@ class PendingMessageStore(object): sender.send(message, immediately=True) self.threadlocal.messages = [] + pending_message_store = PendingMessageStore() consumer_finished.connect(pending_message_store.send_and_flush) diff --git a/channels/tests/http.py b/channels/tests/http.py index 9759063..5c3b4bb 100644 --- a/channels/tests/http.py +++ b/channels/tests/http.py @@ -1,3 +1,4 @@ +from __future__ import unicode_literals import copy import json @@ -66,9 +67,9 @@ class HttpClient(Client): Return text content of a message for client channel and decoding it if json kwarg is set """ content = super(HttpClient, self).receive() - if content and json: + if content and json and 'text' in content and isinstance(content['text'], six.string_types): return json_module.loads(content['text']) - return content['text'] if content else None + return content.get('text', content) if content else None def send(self, to, content={}, text=None, path='/'): """ @@ -87,12 +88,20 @@ class HttpClient(Client): content['text'] = text self.channel_layer.send(to, content) - def send_and_consume(self, channel, content={}, text=None, path='/', fail_on_none=True): + def send_and_consume(self, channel, content={}, text=None, path='/', fail_on_none=True, check_accept=True): """ Reproduce full life cycle of the message """ self.send(channel, content, text, path) - return self.consume(channel, fail_on_none=fail_on_none) + return self.consume(channel, fail_on_none=fail_on_none, check_accept=check_accept) + + def consume(self, channel, fail_on_none=True, check_accept=True): + result = super(HttpClient, self).consume(channel, fail_on_none=fail_on_none) + if channel == "websocket.connect" and check_accept: + received = self.receive(json=False) + if received != {"accept": True}: + raise AssertionError("Connection rejected: %s != '{accept: True}'" % received) + return result def login(self, **credentials): """ diff --git a/channels/tests/test_generic.py b/channels/tests/test_generic.py index 40ade79..2938598 100644 --- a/channels/tests/test_generic.py +++ b/channels/tests/test_generic.py @@ -1,13 +1,11 @@ from __future__ import unicode_literals -import json - from django.test import override_settings from channels import route_class from channels.exceptions import SendNotAvailableOnDemultiplexer from channels.generic import BaseConsumer, websockets -from channels.tests import ChannelTestCase, Client, apply_routes +from channels.tests import ChannelTestCase, Client, HttpClient, apply_routes @override_settings(SESSION_ENGINE="django.contrib.sessions.backends.cache") @@ -92,6 +90,7 @@ class GenericTests(ChannelTestCase): class WebsocketConsumer(websockets.WebsocketConsumer): def connect(self, message, **kwargs): + self.message.reply_channel.send({'accept': True}) self.send(text=message.get('order')) routes = [ @@ -103,18 +102,18 @@ class GenericTests(ChannelTestCase): self.assertIs(routes[1].consumer, WebsocketConsumer) with apply_routes(routes): - client = Client() + client = HttpClient() client.send('websocket.connect', {'path': '/path', 'order': 1}) client.send('websocket.connect', {'path': '/path', 'order': 0}) + client.consume('websocket.connect', check_accept=False) client.consume('websocket.connect') + self.assertEqual(client.receive(json=False), 0) client.consume('websocket.connect') - client.consume('websocket.connect') - self.assertEqual(client.receive(), {'text': 0}) - self.assertEqual(client.receive(), {'text': 1}) + self.assertEqual(client.receive(json=False), 1) client.send_and_consume('websocket.connect', {'path': '/path/2', 'order': 'next'}) - self.assertEqual(client.receive(), {'text': 'next'}) + self.assertEqual(client.receive(json=False), 'next') def test_as_route_method(self): class WebsocketConsumer(BaseConsumer): @@ -154,40 +153,28 @@ class GenericTests(ChannelTestCase): "mystream": MyWebsocketConsumer } - with apply_routes([ - route_class(Demultiplexer, path='/path/(?P\d+)'), - route_class(MyWebsocketConsumer), - ]): - client = Client() + with apply_routes([route_class(Demultiplexer, path='/path/(?P\d+)')]): + client = HttpClient() - client.send_and_consume('websocket.connect', {'path': '/path/1'}) + client.send_and_consume('websocket.connect', path='/path/1') self.assertEqual(client.receive(), { - "text": json.dumps({ - "stream": "mystream", - "payload": {"id": "1"}, - }) + "stream": "mystream", + "payload": {"id": "1"}, }) - client.send_and_consume('websocket.receive', { - 'path': '/path/1', - 'text': json.dumps({ - "stream": "mystream", - "payload": {"text_field": "mytext"} - }) - }) + client.send_and_consume('websocket.receive', text={ + "stream": "mystream", + "payload": {"text_field": "mytext"}, + }, path='/path/1') self.assertEqual(client.receive(), { - "text": json.dumps({ - "stream": "mystream", - "payload": {"text_field": "mytext"}, - }) + "stream": "mystream", + "payload": {"text_field": "mytext"}, }) - client.send_and_consume('websocket.disconnect', {'path': '/path/1'}) + client.send_and_consume('websocket.disconnect', path='/path/1') self.assertEqual(client.receive(), { - "text": json.dumps({ - "stream": "mystream", - "payload": {"id": "1"}, - }) + "stream": "mystream", + "payload": {"id": "1"}, }) def test_websocket_demultiplexer_send(self): @@ -202,19 +189,13 @@ class GenericTests(ChannelTestCase): "mystream": MyWebsocketConsumer } - with apply_routes([ - route_class(Demultiplexer, path='/path/(?P\d+)'), - route_class(MyWebsocketConsumer), - ]): - client = Client() + with apply_routes([route_class(Demultiplexer, path='/path/(?P\d+)')]): + client = HttpClient() with self.assertRaises(SendNotAvailableOnDemultiplexer): - client.send_and_consume('websocket.receive', { - 'path': '/path/1', - 'text': json.dumps({ - "stream": "mystream", - "payload": {"text_field": "mytext"} - }) + client.send_and_consume('websocket.receive', path='/path/1', text={ + "stream": "mystream", + "payload": {"text_field": "mytext"}, }) client.receive()