From 2101f285cb735c73cda3e273e249725b0729d232 Mon Sep 17 00:00:00 2001 From: Doug Keen Date: Tue, 28 Feb 2017 18:51:48 -0800 Subject: [PATCH] Allow custom json encoder and decoder in `JsonWebsocketConsumer` (#535) Lets you override the JSON encoding on both the consumer and the multiplexer. --- channels/generic/websockets.py | 103 +++++++++++++++++++-------------- docs/generics.rst | 14 ++++- tests/test_generic.py | 68 ++++++++++++++++++++++ 3 files changed, 139 insertions(+), 46 deletions(-) diff --git a/channels/generic/websockets.py b/channels/generic/websockets.py index 3e77e3b..f4bb714 100644 --- a/channels/generic/websockets.py +++ b/channels/generic/websockets.py @@ -148,7 +148,7 @@ class JsonWebsocketConsumer(WebsocketConsumer): def raw_receive(self, message, **kwargs): if "text" in message: - self.receive(json.loads(message['text']), **kwargs) + self.receive(self.decode_json(message['text']), **kwargs) else: raise ValueError("No text section for incoming WebSocket frame!") @@ -162,11 +162,58 @@ class JsonWebsocketConsumer(WebsocketConsumer): """ Encode the given content as JSON and send it to the client. """ - super(JsonWebsocketConsumer, self).send(text=json.dumps(content), close=close) + super(JsonWebsocketConsumer, self).send(text=self.encode_json(content), close=close) + + @classmethod + def decode_json(cls, text): + return json.loads(text) + + @classmethod + def encode_json(cls, content): + return json.dumps(content) @classmethod def group_send(cls, name, content, close=False): - WebsocketConsumer.group_send(name, json.dumps(content), close=close) + WebsocketConsumer.group_send(name, cls.encode_json(content), close=close) + + +class WebsocketMultiplexer(object): + """ + The opposite of the demultiplexer, to send a message though a multiplexed channel. + + The multiplexer object is passed as a kwargs to the consumer when the message is dispatched. + This pattern allows the consumer class to be independent of the stream name. + """ + + stream = None + reply_channel = None + + def __init__(self, stream, reply_channel): + self.stream = stream + self.reply_channel = reply_channel + + def send(self, payload): + """Multiplex the payload using the stream name and send it.""" + self.reply_channel.send(self.encode(self.stream, payload)) + + @classmethod + def encode_json(cls, content): + return json.dumps(content, cls=DjangoJSONEncoder) + + @classmethod + def encode(cls, stream, payload): + """ + Encodes stream + payload for outbound sending. + """ + content = {"stream": stream, "payload": payload} + return {"text": cls.encode_json(content)} + + @classmethod + def group_send(cls, name, stream, payload, close=False): + message = cls.encode(stream, payload) + if close: + message["close"] = True + Group(name).send(message) class WebsocketDemultiplexer(JsonWebsocketConsumer): @@ -191,6 +238,9 @@ class WebsocketDemultiplexer(JsonWebsocketConsumer): # Put your JSON consumers here: {stream_name : consumer} consumers = {} + # Optionally use a custom multiplexer class + multiplexer_class = WebsocketMultiplexer + def receive(self, content, **kwargs): """Forward messages to all consumers.""" # Check the frame looks good @@ -203,10 +253,10 @@ class WebsocketDemultiplexer(JsonWebsocketConsumer): if not isinstance(payload, dict): raise ValueError("Multiplexed frame payload is not a dict") # The json consumer expects serialized JSON - self.message.content['text'] = json.dumps(payload) + self.message.content['text'] = self.encode_json(payload) # Send demultiplexer to the consumer, to be able to answer - kwargs['multiplexer'] = WebsocketMultiplexer(stream, self.message.reply_channel) - # Patch send to avoid sending not formated messages from the consumer + kwargs['multiplexer'] = self.multiplexer_class(stream, self.message.reply_channel) + # Patch send to avoid sending not formatted messages from the consumer if hasattr(consumer, "send"): consumer.send = self.send # Dispatch message @@ -221,13 +271,13 @@ class WebsocketDemultiplexer(JsonWebsocketConsumer): """Forward connection to all consumers.""" self.message.reply_channel.send({"accept": True}) for stream, consumer in self.consumers.items(): - kwargs['multiplexer'] = WebsocketMultiplexer(stream, self.message.reply_channel) + kwargs['multiplexer'] = self.multiplexer_class(stream, self.message.reply_channel) consumer(message, **kwargs) def disconnect(self, message, **kwargs): """Forward disconnection to all consumers.""" for stream, consumer in self.consumers.items(): - kwargs['multiplexer'] = WebsocketMultiplexer(stream, self.message.reply_channel) + kwargs['multiplexer'] = self.multiplexer_class(stream, self.message.reply_channel) consumer(message, **kwargs) def send(self, *args): @@ -236,40 +286,3 @@ class WebsocketDemultiplexer(JsonWebsocketConsumer): @classmethod def group_send(cls, name, stream, payload, close=False): raise SendNotAvailableOnDemultiplexer("Use WebsocketMultiplexer.group_send") - - -class WebsocketMultiplexer(object): - """ - The opposite of the demultiplexer, to send a message though a multiplexed channel. - - The multiplexer object is passed as a kwargs to the consumer when the message is dispatched. - This pattern allows the consumer class to be independant of the stream name. - """ - - stream = None - reply_channel = None - - def __init__(self, stream, reply_channel): - self.stream = stream - self.reply_channel = reply_channel - - def send(self, payload): - """Multiplex the payload using the stream name and send it.""" - self.reply_channel.send(self.encode(self.stream, payload)) - - @classmethod - def encode(cls, stream, payload): - """ - Encodes stream + payload for outbound sending. - """ - return {"text": json.dumps({ - "stream": stream, - "payload": payload, - }, cls=DjangoJSONEncoder)} - - @classmethod - def group_send(cls, name, stream, payload, close=False): - message = WebsocketMultiplexer.encode(stream, payload) - if close: - message["close"] = True - Group(name).send(message) diff --git a/docs/generics.rst b/docs/generics.rst index 459ce36..17768da 100644 --- a/docs/generics.rst +++ b/docs/generics.rst @@ -163,6 +163,15 @@ The JSON-enabled consumer looks slightly different:: """ pass + # Optionally provide your own custom json encoder and decoder + # @classmethod + # def decode_json(cls, text): + # return my_custom_json_decoder(text) + # + # @classmethod + # def encode_json(cls, content): + # return my_custom_json_encoder(content) + For this subclass, ``receive`` only gets a ``content`` argument that is the already-decoded JSON as Python datastructures; similarly, ``send`` now only takes a single argument, which it JSON-encodes before sending down to the @@ -221,8 +230,11 @@ Example using class-based consumer:: "other": AnotherConsumer, } + # Optionally provide a custom multiplexer class + # multiplexer_class = MyCustomJsonEncodingMultiplexer -The ``multiplexer`` allows the consumer class to be independant of the stream name. + +The ``multiplexer`` allows the consumer class to be independent of the stream name. It holds the stream name and the demultiplexer on the attributes ``stream`` and ``demultiplexer``. The :doc:`data binding ` code will also send out messages to clients diff --git a/tests/test_generic.py b/tests/test_generic.py index 1751ecf..6b0c83a 100644 --- a/tests/test_generic.py +++ b/tests/test_generic.py @@ -1,5 +1,7 @@ from __future__ import unicode_literals +import json + from django.test import override_settings from channels import route_class @@ -199,3 +201,69 @@ class GenericTests(ChannelTestCase): }) client.receive() + + def test_websocket_custom_json_serialization(self): + + class WebsocketConsumer(websockets.JsonWebsocketConsumer): + @classmethod + def decode_json(cls, text): + obj = json.loads(text) + return dict((key.upper(), obj[key]) for key in obj) + + @classmethod + def encode_json(cls, content): + lowered = dict((key.lower(), content[key]) for key in content) + return json.dumps(lowered) + + def receive(self, content, multiplexer=None, **kwargs): + self.content_received = content + self.send({"RESPONSE": "HI"}) + + class MyMultiplexer(websockets.WebsocketMultiplexer): + @classmethod + def encode_json(cls, content): + lowered = dict((key.lower(), content[key]) for key in content) + return json.dumps(lowered) + + with apply_routes([route_class(WebsocketConsumer, path='/path')]): + client = HttpClient() + + consumer = client.send_and_consume('websocket.receive', path='/path', text={"key": "value"}) + self.assertEqual(consumer.content_received, {"KEY": "value"}) + + self.assertEqual(client.receive(), {"response": "HI"}) + + client.join_group('test_group') + WebsocketConsumer.group_send('test_group', {"KEY": "VALUE"}) + self.assertEqual(client.receive(), {"key": "VALUE"}) + + def test_websockets_demultiplexer_custom_multiplexer(self): + + class MyWebsocketConsumer(websockets.JsonWebsocketConsumer): + def connect(self, message, multiplexer=None, **kwargs): + multiplexer.send({"THIS_SHOULD_BE_LOWERCASED": "1"}) + + class MyMultiplexer(websockets.WebsocketMultiplexer): + @classmethod + def encode_json(cls, content): + lowered = { + "stream": content["stream"], + "payload": dict((key.lower(), content["payload"][key]) for key in content["payload"]) + } + return json.dumps(lowered) + + class Demultiplexer(websockets.WebsocketDemultiplexer): + multiplexer_class = MyMultiplexer + + consumers = { + "mystream": MyWebsocketConsumer + } + + with apply_routes([route_class(Demultiplexer, path='/path/(?P\d+)')]): + client = HttpClient() + + client.send_and_consume('websocket.connect', path='/path/1') + self.assertEqual(client.receive(), { + "stream": "mystream", + "payload": {"this_should_be_lowercased": "1"}, + })