Accept Connection at WebsocketConsumer (#467)

* Added accept at default behavior for websocket generic cbv and pass message instead of dict

* Fix flake8

* Use HttpClient Instead of Client

* Fix lsort
This commit is contained in:
Krukov D 2017-01-09 21:08:00 +03:00 committed by Andrew Godwin
parent 827fcd25b1
commit 8a93dfc401
5 changed files with 47 additions and 56 deletions

View File

@ -87,18 +87,17 @@ class WebsocketBinding(Binding):
# Only allow received packets through further. # Only allow received packets through further.
if message.channel.name != "websocket.receive": if message.channel.name != "websocket.receive":
return return
# Call superclass, unpacking the payload in the process super(WebsocketBinding, cls).trigger_inbound(message, **kwargs)
payload = json.loads(message['text'])
super(WebsocketBinding, cls).trigger_inbound(payload, **kwargs)
def deserialize(self, message): def deserialize(self, message):
""" """
You must hook this up behind a Deserializer, so we expect the JSON You must hook this up behind a Deserializer, so we expect the JSON
already dealt with. already dealt with.
""" """
action = message['action'] body = json.loads(message['text'])
pk = message.get('pk', None) action = body['action']
data = message.get('data', None) pk = body.get('pk', None)
data = body.get('data', None)
return action, pk, data return action, pk, data
def _hydrate(self, pk, data): def _hydrate(self, pk, data):

View File

@ -72,7 +72,7 @@ class WebsocketConsumer(BaseConsumer):
""" """
Called when a WebSocket connection is opened. Called when a WebSocket connection is opened.
""" """
pass self.message.reply_channel.send({"accept": True})
def raw_receive(self, message, **kwargs): def raw_receive(self, message, **kwargs):
""" """
@ -220,6 +220,7 @@ class WebsocketDemultiplexer(JsonWebsocketConsumer):
def connect(self, message, **kwargs): def connect(self, message, **kwargs):
"""Forward connection to all consumers.""" """Forward connection to all consumers."""
self.message.reply_channel.send({"accept": True})
for stream, consumer in self.consumers.items(): for stream, consumer in self.consumers.items():
kwargs['multiplexer'] = WebsocketMultiplexer(stream, self.message.reply_channel) kwargs['multiplexer'] = WebsocketMultiplexer(stream, self.message.reply_channel)
consumer(message, **kwargs) consumer(message, **kwargs)

View File

@ -81,5 +81,6 @@ class PendingMessageStore(object):
sender.send(message, immediately=True) sender.send(message, immediately=True)
self.threadlocal.messages = [] self.threadlocal.messages = []
pending_message_store = PendingMessageStore() pending_message_store = PendingMessageStore()
consumer_finished.connect(pending_message_store.send_and_flush) consumer_finished.connect(pending_message_store.send_and_flush)

View File

@ -1,3 +1,4 @@
from __future__ import unicode_literals
import copy import copy
import json import json
@ -66,9 +67,9 @@ class HttpClient(Client):
Return text content of a message for client channel and decoding it if json kwarg is set Return text content of a message for client channel and decoding it if json kwarg is set
""" """
content = super(HttpClient, self).receive() content = super(HttpClient, self).receive()
if content and json: if content and json and 'text' in content and isinstance(content['text'], six.string_types):
return json_module.loads(content['text']) return json_module.loads(content['text'])
return content['text'] if content else None return content.get('text', content) if content else None
def send(self, to, content={}, text=None, path='/'): def send(self, to, content={}, text=None, path='/'):
""" """
@ -87,12 +88,20 @@ class HttpClient(Client):
content['text'] = text content['text'] = text
self.channel_layer.send(to, content) self.channel_layer.send(to, content)
def send_and_consume(self, channel, content={}, text=None, path='/', fail_on_none=True): def send_and_consume(self, channel, content={}, text=None, path='/', fail_on_none=True, check_accept=True):
""" """
Reproduce full life cycle of the message Reproduce full life cycle of the message
""" """
self.send(channel, content, text, path) self.send(channel, content, text, path)
return self.consume(channel, fail_on_none=fail_on_none) return self.consume(channel, fail_on_none=fail_on_none, check_accept=check_accept)
def consume(self, channel, fail_on_none=True, check_accept=True):
result = super(HttpClient, self).consume(channel, fail_on_none=fail_on_none)
if channel == "websocket.connect" and check_accept:
received = self.receive(json=False)
if received != {"accept": True}:
raise AssertionError("Connection rejected: %s != '{accept: True}'" % received)
return result
def login(self, **credentials): def login(self, **credentials):
""" """

View File

@ -1,13 +1,11 @@
from __future__ import unicode_literals from __future__ import unicode_literals
import json
from django.test import override_settings from django.test import override_settings
from channels import route_class from channels import route_class
from channels.exceptions import SendNotAvailableOnDemultiplexer from channels.exceptions import SendNotAvailableOnDemultiplexer
from channels.generic import BaseConsumer, websockets from channels.generic import BaseConsumer, websockets
from channels.tests import ChannelTestCase, Client, apply_routes from channels.tests import ChannelTestCase, Client, HttpClient, apply_routes
@override_settings(SESSION_ENGINE="django.contrib.sessions.backends.cache") @override_settings(SESSION_ENGINE="django.contrib.sessions.backends.cache")
@ -92,6 +90,7 @@ class GenericTests(ChannelTestCase):
class WebsocketConsumer(websockets.WebsocketConsumer): class WebsocketConsumer(websockets.WebsocketConsumer):
def connect(self, message, **kwargs): def connect(self, message, **kwargs):
self.message.reply_channel.send({'accept': True})
self.send(text=message.get('order')) self.send(text=message.get('order'))
routes = [ routes = [
@ -103,18 +102,18 @@ class GenericTests(ChannelTestCase):
self.assertIs(routes[1].consumer, WebsocketConsumer) self.assertIs(routes[1].consumer, WebsocketConsumer)
with apply_routes(routes): with apply_routes(routes):
client = Client() client = HttpClient()
client.send('websocket.connect', {'path': '/path', 'order': 1}) client.send('websocket.connect', {'path': '/path', 'order': 1})
client.send('websocket.connect', {'path': '/path', 'order': 0}) client.send('websocket.connect', {'path': '/path', 'order': 0})
client.consume('websocket.connect', check_accept=False)
client.consume('websocket.connect') client.consume('websocket.connect')
self.assertEqual(client.receive(json=False), 0)
client.consume('websocket.connect') client.consume('websocket.connect')
client.consume('websocket.connect') self.assertEqual(client.receive(json=False), 1)
self.assertEqual(client.receive(), {'text': 0})
self.assertEqual(client.receive(), {'text': 1})
client.send_and_consume('websocket.connect', {'path': '/path/2', 'order': 'next'}) client.send_and_consume('websocket.connect', {'path': '/path/2', 'order': 'next'})
self.assertEqual(client.receive(), {'text': 'next'}) self.assertEqual(client.receive(json=False), 'next')
def test_as_route_method(self): def test_as_route_method(self):
class WebsocketConsumer(BaseConsumer): class WebsocketConsumer(BaseConsumer):
@ -154,41 +153,29 @@ class GenericTests(ChannelTestCase):
"mystream": MyWebsocketConsumer "mystream": MyWebsocketConsumer
} }
with apply_routes([ with apply_routes([route_class(Demultiplexer, path='/path/(?P<id>\d+)')]):
route_class(Demultiplexer, path='/path/(?P<id>\d+)'), client = HttpClient()
route_class(MyWebsocketConsumer),
]):
client = Client()
client.send_and_consume('websocket.connect', {'path': '/path/1'}) client.send_and_consume('websocket.connect', path='/path/1')
self.assertEqual(client.receive(), { self.assertEqual(client.receive(), {
"text": json.dumps({
"stream": "mystream", "stream": "mystream",
"payload": {"id": "1"}, "payload": {"id": "1"},
}) })
})
client.send_and_consume('websocket.receive', { client.send_and_consume('websocket.receive', text={
'path': '/path/1',
'text': json.dumps({
"stream": "mystream", "stream": "mystream",
"payload": {"text_field": "mytext"} "payload": {"text_field": "mytext"},
}) }, path='/path/1')
})
self.assertEqual(client.receive(), { self.assertEqual(client.receive(), {
"text": json.dumps({
"stream": "mystream", "stream": "mystream",
"payload": {"text_field": "mytext"}, "payload": {"text_field": "mytext"},
}) })
})
client.send_and_consume('websocket.disconnect', {'path': '/path/1'}) client.send_and_consume('websocket.disconnect', path='/path/1')
self.assertEqual(client.receive(), { self.assertEqual(client.receive(), {
"text": json.dumps({
"stream": "mystream", "stream": "mystream",
"payload": {"id": "1"}, "payload": {"id": "1"},
}) })
})
def test_websocket_demultiplexer_send(self): def test_websocket_demultiplexer_send(self):
@ -202,19 +189,13 @@ class GenericTests(ChannelTestCase):
"mystream": MyWebsocketConsumer "mystream": MyWebsocketConsumer
} }
with apply_routes([ with apply_routes([route_class(Demultiplexer, path='/path/(?P<id>\d+)')]):
route_class(Demultiplexer, path='/path/(?P<id>\d+)'), client = HttpClient()
route_class(MyWebsocketConsumer),
]):
client = Client()
with self.assertRaises(SendNotAvailableOnDemultiplexer): with self.assertRaises(SendNotAvailableOnDemultiplexer):
client.send_and_consume('websocket.receive', { client.send_and_consume('websocket.receive', path='/path/1', text={
'path': '/path/1',
'text': json.dumps({
"stream": "mystream", "stream": "mystream",
"payload": {"text_field": "mytext"} "payload": {"text_field": "mytext"},
})
}) })
client.receive() client.receive()