mirror of
https://github.com/django/daphne.git
synced 2025-07-13 09:22:17 +03:00
Accept Connection at WebsocketConsumer (#467)
* Added accept at default behavior for websocket generic cbv and pass message instead of dict * Fix flake8 * Use HttpClient Instead of Client * Fix lsort
This commit is contained in:
parent
827fcd25b1
commit
8a93dfc401
|
@ -87,18 +87,17 @@ class WebsocketBinding(Binding):
|
||||||
# Only allow received packets through further.
|
# Only allow received packets through further.
|
||||||
if message.channel.name != "websocket.receive":
|
if message.channel.name != "websocket.receive":
|
||||||
return
|
return
|
||||||
# Call superclass, unpacking the payload in the process
|
super(WebsocketBinding, cls).trigger_inbound(message, **kwargs)
|
||||||
payload = json.loads(message['text'])
|
|
||||||
super(WebsocketBinding, cls).trigger_inbound(payload, **kwargs)
|
|
||||||
|
|
||||||
def deserialize(self, message):
|
def deserialize(self, message):
|
||||||
"""
|
"""
|
||||||
You must hook this up behind a Deserializer, so we expect the JSON
|
You must hook this up behind a Deserializer, so we expect the JSON
|
||||||
already dealt with.
|
already dealt with.
|
||||||
"""
|
"""
|
||||||
action = message['action']
|
body = json.loads(message['text'])
|
||||||
pk = message.get('pk', None)
|
action = body['action']
|
||||||
data = message.get('data', None)
|
pk = body.get('pk', None)
|
||||||
|
data = body.get('data', None)
|
||||||
return action, pk, data
|
return action, pk, data
|
||||||
|
|
||||||
def _hydrate(self, pk, data):
|
def _hydrate(self, pk, data):
|
||||||
|
|
|
@ -72,7 +72,7 @@ class WebsocketConsumer(BaseConsumer):
|
||||||
"""
|
"""
|
||||||
Called when a WebSocket connection is opened.
|
Called when a WebSocket connection is opened.
|
||||||
"""
|
"""
|
||||||
pass
|
self.message.reply_channel.send({"accept": True})
|
||||||
|
|
||||||
def raw_receive(self, message, **kwargs):
|
def raw_receive(self, message, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -220,6 +220,7 @@ class WebsocketDemultiplexer(JsonWebsocketConsumer):
|
||||||
|
|
||||||
def connect(self, message, **kwargs):
|
def connect(self, message, **kwargs):
|
||||||
"""Forward connection to all consumers."""
|
"""Forward connection to all consumers."""
|
||||||
|
self.message.reply_channel.send({"accept": True})
|
||||||
for stream, consumer in self.consumers.items():
|
for stream, consumer in self.consumers.items():
|
||||||
kwargs['multiplexer'] = WebsocketMultiplexer(stream, self.message.reply_channel)
|
kwargs['multiplexer'] = WebsocketMultiplexer(stream, self.message.reply_channel)
|
||||||
consumer(message, **kwargs)
|
consumer(message, **kwargs)
|
||||||
|
|
|
@ -81,5 +81,6 @@ class PendingMessageStore(object):
|
||||||
sender.send(message, immediately=True)
|
sender.send(message, immediately=True)
|
||||||
self.threadlocal.messages = []
|
self.threadlocal.messages = []
|
||||||
|
|
||||||
|
|
||||||
pending_message_store = PendingMessageStore()
|
pending_message_store = PendingMessageStore()
|
||||||
consumer_finished.connect(pending_message_store.send_and_flush)
|
consumer_finished.connect(pending_message_store.send_and_flush)
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
from __future__ import unicode_literals
|
||||||
import copy
|
import copy
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
@ -66,9 +67,9 @@ class HttpClient(Client):
|
||||||
Return text content of a message for client channel and decoding it if json kwarg is set
|
Return text content of a message for client channel and decoding it if json kwarg is set
|
||||||
"""
|
"""
|
||||||
content = super(HttpClient, self).receive()
|
content = super(HttpClient, self).receive()
|
||||||
if content and json:
|
if content and json and 'text' in content and isinstance(content['text'], six.string_types):
|
||||||
return json_module.loads(content['text'])
|
return json_module.loads(content['text'])
|
||||||
return content['text'] if content else None
|
return content.get('text', content) if content else None
|
||||||
|
|
||||||
def send(self, to, content={}, text=None, path='/'):
|
def send(self, to, content={}, text=None, path='/'):
|
||||||
"""
|
"""
|
||||||
|
@ -87,12 +88,20 @@ class HttpClient(Client):
|
||||||
content['text'] = text
|
content['text'] = text
|
||||||
self.channel_layer.send(to, content)
|
self.channel_layer.send(to, content)
|
||||||
|
|
||||||
def send_and_consume(self, channel, content={}, text=None, path='/', fail_on_none=True):
|
def send_and_consume(self, channel, content={}, text=None, path='/', fail_on_none=True, check_accept=True):
|
||||||
"""
|
"""
|
||||||
Reproduce full life cycle of the message
|
Reproduce full life cycle of the message
|
||||||
"""
|
"""
|
||||||
self.send(channel, content, text, path)
|
self.send(channel, content, text, path)
|
||||||
return self.consume(channel, fail_on_none=fail_on_none)
|
return self.consume(channel, fail_on_none=fail_on_none, check_accept=check_accept)
|
||||||
|
|
||||||
|
def consume(self, channel, fail_on_none=True, check_accept=True):
|
||||||
|
result = super(HttpClient, self).consume(channel, fail_on_none=fail_on_none)
|
||||||
|
if channel == "websocket.connect" and check_accept:
|
||||||
|
received = self.receive(json=False)
|
||||||
|
if received != {"accept": True}:
|
||||||
|
raise AssertionError("Connection rejected: %s != '{accept: True}'" % received)
|
||||||
|
return result
|
||||||
|
|
||||||
def login(self, **credentials):
|
def login(self, **credentials):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1,13 +1,11 @@
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import json
|
|
||||||
|
|
||||||
from django.test import override_settings
|
from django.test import override_settings
|
||||||
|
|
||||||
from channels import route_class
|
from channels import route_class
|
||||||
from channels.exceptions import SendNotAvailableOnDemultiplexer
|
from channels.exceptions import SendNotAvailableOnDemultiplexer
|
||||||
from channels.generic import BaseConsumer, websockets
|
from channels.generic import BaseConsumer, websockets
|
||||||
from channels.tests import ChannelTestCase, Client, apply_routes
|
from channels.tests import ChannelTestCase, Client, HttpClient, apply_routes
|
||||||
|
|
||||||
|
|
||||||
@override_settings(SESSION_ENGINE="django.contrib.sessions.backends.cache")
|
@override_settings(SESSION_ENGINE="django.contrib.sessions.backends.cache")
|
||||||
|
@ -92,6 +90,7 @@ class GenericTests(ChannelTestCase):
|
||||||
class WebsocketConsumer(websockets.WebsocketConsumer):
|
class WebsocketConsumer(websockets.WebsocketConsumer):
|
||||||
|
|
||||||
def connect(self, message, **kwargs):
|
def connect(self, message, **kwargs):
|
||||||
|
self.message.reply_channel.send({'accept': True})
|
||||||
self.send(text=message.get('order'))
|
self.send(text=message.get('order'))
|
||||||
|
|
||||||
routes = [
|
routes = [
|
||||||
|
@ -103,18 +102,18 @@ class GenericTests(ChannelTestCase):
|
||||||
self.assertIs(routes[1].consumer, WebsocketConsumer)
|
self.assertIs(routes[1].consumer, WebsocketConsumer)
|
||||||
|
|
||||||
with apply_routes(routes):
|
with apply_routes(routes):
|
||||||
client = Client()
|
client = HttpClient()
|
||||||
|
|
||||||
client.send('websocket.connect', {'path': '/path', 'order': 1})
|
client.send('websocket.connect', {'path': '/path', 'order': 1})
|
||||||
client.send('websocket.connect', {'path': '/path', 'order': 0})
|
client.send('websocket.connect', {'path': '/path', 'order': 0})
|
||||||
|
client.consume('websocket.connect', check_accept=False)
|
||||||
client.consume('websocket.connect')
|
client.consume('websocket.connect')
|
||||||
|
self.assertEqual(client.receive(json=False), 0)
|
||||||
client.consume('websocket.connect')
|
client.consume('websocket.connect')
|
||||||
client.consume('websocket.connect')
|
self.assertEqual(client.receive(json=False), 1)
|
||||||
self.assertEqual(client.receive(), {'text': 0})
|
|
||||||
self.assertEqual(client.receive(), {'text': 1})
|
|
||||||
|
|
||||||
client.send_and_consume('websocket.connect', {'path': '/path/2', 'order': 'next'})
|
client.send_and_consume('websocket.connect', {'path': '/path/2', 'order': 'next'})
|
||||||
self.assertEqual(client.receive(), {'text': 'next'})
|
self.assertEqual(client.receive(json=False), 'next')
|
||||||
|
|
||||||
def test_as_route_method(self):
|
def test_as_route_method(self):
|
||||||
class WebsocketConsumer(BaseConsumer):
|
class WebsocketConsumer(BaseConsumer):
|
||||||
|
@ -154,40 +153,28 @@ class GenericTests(ChannelTestCase):
|
||||||
"mystream": MyWebsocketConsumer
|
"mystream": MyWebsocketConsumer
|
||||||
}
|
}
|
||||||
|
|
||||||
with apply_routes([
|
with apply_routes([route_class(Demultiplexer, path='/path/(?P<id>\d+)')]):
|
||||||
route_class(Demultiplexer, path='/path/(?P<id>\d+)'),
|
client = HttpClient()
|
||||||
route_class(MyWebsocketConsumer),
|
|
||||||
]):
|
|
||||||
client = Client()
|
|
||||||
|
|
||||||
client.send_and_consume('websocket.connect', {'path': '/path/1'})
|
client.send_and_consume('websocket.connect', path='/path/1')
|
||||||
self.assertEqual(client.receive(), {
|
self.assertEqual(client.receive(), {
|
||||||
"text": json.dumps({
|
"stream": "mystream",
|
||||||
"stream": "mystream",
|
"payload": {"id": "1"},
|
||||||
"payload": {"id": "1"},
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
|
|
||||||
client.send_and_consume('websocket.receive', {
|
client.send_and_consume('websocket.receive', text={
|
||||||
'path': '/path/1',
|
"stream": "mystream",
|
||||||
'text': json.dumps({
|
"payload": {"text_field": "mytext"},
|
||||||
"stream": "mystream",
|
}, path='/path/1')
|
||||||
"payload": {"text_field": "mytext"}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
self.assertEqual(client.receive(), {
|
self.assertEqual(client.receive(), {
|
||||||
"text": json.dumps({
|
"stream": "mystream",
|
||||||
"stream": "mystream",
|
"payload": {"text_field": "mytext"},
|
||||||
"payload": {"text_field": "mytext"},
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
|
|
||||||
client.send_and_consume('websocket.disconnect', {'path': '/path/1'})
|
client.send_and_consume('websocket.disconnect', path='/path/1')
|
||||||
self.assertEqual(client.receive(), {
|
self.assertEqual(client.receive(), {
|
||||||
"text": json.dumps({
|
"stream": "mystream",
|
||||||
"stream": "mystream",
|
"payload": {"id": "1"},
|
||||||
"payload": {"id": "1"},
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
|
|
||||||
def test_websocket_demultiplexer_send(self):
|
def test_websocket_demultiplexer_send(self):
|
||||||
|
@ -202,19 +189,13 @@ class GenericTests(ChannelTestCase):
|
||||||
"mystream": MyWebsocketConsumer
|
"mystream": MyWebsocketConsumer
|
||||||
}
|
}
|
||||||
|
|
||||||
with apply_routes([
|
with apply_routes([route_class(Demultiplexer, path='/path/(?P<id>\d+)')]):
|
||||||
route_class(Demultiplexer, path='/path/(?P<id>\d+)'),
|
client = HttpClient()
|
||||||
route_class(MyWebsocketConsumer),
|
|
||||||
]):
|
|
||||||
client = Client()
|
|
||||||
|
|
||||||
with self.assertRaises(SendNotAvailableOnDemultiplexer):
|
with self.assertRaises(SendNotAvailableOnDemultiplexer):
|
||||||
client.send_and_consume('websocket.receive', {
|
client.send_and_consume('websocket.receive', path='/path/1', text={
|
||||||
'path': '/path/1',
|
"stream": "mystream",
|
||||||
'text': json.dumps({
|
"payload": {"text_field": "mytext"},
|
||||||
"stream": "mystream",
|
|
||||||
"payload": {"text_field": "mytext"}
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
|
|
||||||
client.receive()
|
client.receive()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user