Accept Connection at WebsocketConsumer (#467)

* Added accept at default behavior for websocket generic cbv and pass message instead of dict

* Fix flake8

* Use HttpClient Instead of Client

* Fix lsort
This commit is contained in:
Krukov D 2017-01-09 21:08:00 +03:00 committed by Andrew Godwin
parent 827fcd25b1
commit 8a93dfc401
5 changed files with 47 additions and 56 deletions

View File

@ -87,18 +87,17 @@ class WebsocketBinding(Binding):
# Only allow received packets through further.
if message.channel.name != "websocket.receive":
return
# Call superclass, unpacking the payload in the process
payload = json.loads(message['text'])
super(WebsocketBinding, cls).trigger_inbound(payload, **kwargs)
super(WebsocketBinding, cls).trigger_inbound(message, **kwargs)
def deserialize(self, message):
"""
You must hook this up behind a Deserializer, so we expect the JSON
already dealt with.
"""
action = message['action']
pk = message.get('pk', None)
data = message.get('data', None)
body = json.loads(message['text'])
action = body['action']
pk = body.get('pk', None)
data = body.get('data', None)
return action, pk, data
def _hydrate(self, pk, data):

View File

@ -72,7 +72,7 @@ class WebsocketConsumer(BaseConsumer):
"""
Called when a WebSocket connection is opened.
"""
pass
self.message.reply_channel.send({"accept": True})
def raw_receive(self, message, **kwargs):
"""
@ -220,6 +220,7 @@ class WebsocketDemultiplexer(JsonWebsocketConsumer):
def connect(self, message, **kwargs):
"""Forward connection to all consumers."""
self.message.reply_channel.send({"accept": True})
for stream, consumer in self.consumers.items():
kwargs['multiplexer'] = WebsocketMultiplexer(stream, self.message.reply_channel)
consumer(message, **kwargs)

View File

@ -81,5 +81,6 @@ class PendingMessageStore(object):
sender.send(message, immediately=True)
self.threadlocal.messages = []
pending_message_store = PendingMessageStore()
consumer_finished.connect(pending_message_store.send_and_flush)

View File

@ -1,3 +1,4 @@
from __future__ import unicode_literals
import copy
import json
@ -66,9 +67,9 @@ class HttpClient(Client):
Return text content of a message for client channel and decoding it if json kwarg is set
"""
content = super(HttpClient, self).receive()
if content and json:
if content and json and 'text' in content and isinstance(content['text'], six.string_types):
return json_module.loads(content['text'])
return content['text'] if content else None
return content.get('text', content) if content else None
def send(self, to, content={}, text=None, path='/'):
"""
@ -87,12 +88,20 @@ class HttpClient(Client):
content['text'] = text
self.channel_layer.send(to, content)
def send_and_consume(self, channel, content={}, text=None, path='/', fail_on_none=True):
def send_and_consume(self, channel, content={}, text=None, path='/', fail_on_none=True, check_accept=True):
"""
Reproduce full life cycle of the message
"""
self.send(channel, content, text, path)
return self.consume(channel, fail_on_none=fail_on_none)
return self.consume(channel, fail_on_none=fail_on_none, check_accept=check_accept)
def consume(self, channel, fail_on_none=True, check_accept=True):
result = super(HttpClient, self).consume(channel, fail_on_none=fail_on_none)
if channel == "websocket.connect" and check_accept:
received = self.receive(json=False)
if received != {"accept": True}:
raise AssertionError("Connection rejected: %s != '{accept: True}'" % received)
return result
def login(self, **credentials):
"""

View File

@ -1,13 +1,11 @@
from __future__ import unicode_literals
import json
from django.test import override_settings
from channels import route_class
from channels.exceptions import SendNotAvailableOnDemultiplexer
from channels.generic import BaseConsumer, websockets
from channels.tests import ChannelTestCase, Client, apply_routes
from channels.tests import ChannelTestCase, Client, HttpClient, apply_routes
@override_settings(SESSION_ENGINE="django.contrib.sessions.backends.cache")
@ -92,6 +90,7 @@ class GenericTests(ChannelTestCase):
class WebsocketConsumer(websockets.WebsocketConsumer):
def connect(self, message, **kwargs):
self.message.reply_channel.send({'accept': True})
self.send(text=message.get('order'))
routes = [
@ -103,18 +102,18 @@ class GenericTests(ChannelTestCase):
self.assertIs(routes[1].consumer, WebsocketConsumer)
with apply_routes(routes):
client = Client()
client = HttpClient()
client.send('websocket.connect', {'path': '/path', 'order': 1})
client.send('websocket.connect', {'path': '/path', 'order': 0})
client.consume('websocket.connect', check_accept=False)
client.consume('websocket.connect')
self.assertEqual(client.receive(json=False), 0)
client.consume('websocket.connect')
client.consume('websocket.connect')
self.assertEqual(client.receive(), {'text': 0})
self.assertEqual(client.receive(), {'text': 1})
self.assertEqual(client.receive(json=False), 1)
client.send_and_consume('websocket.connect', {'path': '/path/2', 'order': 'next'})
self.assertEqual(client.receive(), {'text': 'next'})
self.assertEqual(client.receive(json=False), 'next')
def test_as_route_method(self):
class WebsocketConsumer(BaseConsumer):
@ -154,40 +153,28 @@ class GenericTests(ChannelTestCase):
"mystream": MyWebsocketConsumer
}
with apply_routes([
route_class(Demultiplexer, path='/path/(?P<id>\d+)'),
route_class(MyWebsocketConsumer),
]):
client = Client()
with apply_routes([route_class(Demultiplexer, path='/path/(?P<id>\d+)')]):
client = HttpClient()
client.send_and_consume('websocket.connect', {'path': '/path/1'})
client.send_and_consume('websocket.connect', path='/path/1')
self.assertEqual(client.receive(), {
"text": json.dumps({
"stream": "mystream",
"payload": {"id": "1"},
})
"stream": "mystream",
"payload": {"id": "1"},
})
client.send_and_consume('websocket.receive', {
'path': '/path/1',
'text': json.dumps({
"stream": "mystream",
"payload": {"text_field": "mytext"}
})
})
client.send_and_consume('websocket.receive', text={
"stream": "mystream",
"payload": {"text_field": "mytext"},
}, path='/path/1')
self.assertEqual(client.receive(), {
"text": json.dumps({
"stream": "mystream",
"payload": {"text_field": "mytext"},
})
"stream": "mystream",
"payload": {"text_field": "mytext"},
})
client.send_and_consume('websocket.disconnect', {'path': '/path/1'})
client.send_and_consume('websocket.disconnect', path='/path/1')
self.assertEqual(client.receive(), {
"text": json.dumps({
"stream": "mystream",
"payload": {"id": "1"},
})
"stream": "mystream",
"payload": {"id": "1"},
})
def test_websocket_demultiplexer_send(self):
@ -202,19 +189,13 @@ class GenericTests(ChannelTestCase):
"mystream": MyWebsocketConsumer
}
with apply_routes([
route_class(Demultiplexer, path='/path/(?P<id>\d+)'),
route_class(MyWebsocketConsumer),
]):
client = Client()
with apply_routes([route_class(Demultiplexer, path='/path/(?P<id>\d+)')]):
client = HttpClient()
with self.assertRaises(SendNotAvailableOnDemultiplexer):
client.send_and_consume('websocket.receive', {
'path': '/path/1',
'text': json.dumps({
"stream": "mystream",
"payload": {"text_field": "mytext"}
})
client.send_and_consume('websocket.receive', path='/path/1', text={
"stream": "mystream",
"payload": {"text_field": "mytext"},
})
client.receive()