diff --git a/channels/generic/base.py b/channels/generic/base.py index 89c387c..5e3c82a 100644 --- a/channels/generic/base.py +++ b/channels/generic/base.py @@ -6,11 +6,11 @@ class BaseConsumer(object): Base class-based consumer class. Provides the mechanisms to be a direct routing object and a few other things. - Class-based consumers should be used directly in routing with their - filters, like so:: + Class-based consumers should be used with route_class in routing, like so:: + from channels import route_class routing = [ - JsonWebsocketConsumer(path=r"^/liveblog/(?P[^/]+)/"), + route_class(JsonWebsocketConsumer, path=r"^/liveblog/(?P[^/]+)/"), ] """ @@ -32,9 +32,14 @@ class BaseConsumer(object): """ return set(cls.method_mapping.keys()) + def get_handler(self, message, **kwargs): + """ + Return handler uses method_mapping to return the right method to call. + """ + return getattr(self, self.method_mapping[message.channel.name]) + def dispatch(self, message, **kwargs): """ - Called with the message and all keyword arguments; uses method_mapping - to choose the right method to call. + Call handler with the message and all keyword arguments. """ - return getattr(self, self.method_mapping[message.channel.name])(message, **kwargs) + return self.get_handler(message, **kwargs)(message, **kwargs) diff --git a/channels/generic/websockets.py b/channels/generic/websockets.py index af81ab2..b633fc0 100644 --- a/channels/generic/websockets.py +++ b/channels/generic/websockets.py @@ -22,18 +22,19 @@ class WebsocketConsumer(BaseConsumer): slight_ordering = False strict_ordering = False - def dispatch(self, message, **kwargs): + def get_handler(self, message, **kwargs): """ Pulls out the path onto an instance variable, and optionally adds the ordering decorator. """ self.path = message['path'] + handler = super(WebsocketConsumer, self).get_handler(message, **kwargs) if self.strict_ordering: - return enforce_ordering(super(WebsocketConsumer, self).dispatch(message, **kwargs), slight=False) + return enforce_ordering(handler, slight=False) elif self.slight_ordering: - return enforce_ordering(super(WebsocketConsumer, self).dispatch(message, **kwargs), slight=True) + return enforce_ordering(handler, slight=True) else: - return super(WebsocketConsumer, self).dispatch(message, **kwargs) + return handler def connection_groups(self, **kwargs): """ diff --git a/channels/tests/test_generic.py b/channels/tests/test_generic.py new file mode 100644 index 0000000..6a8c380 --- /dev/null +++ b/channels/tests/test_generic.py @@ -0,0 +1,85 @@ +from __future__ import unicode_literals + +from django.test import override_settings +from channels import route_class +from channels.generic import BaseConsumer, websockets +from channels.tests import ChannelTestCase +from channels.tests import apply_routes, Client + + +@override_settings(SESSION_ENGINE="django.contrib.sessions.backends.cache") +class GenericTests(ChannelTestCase): + + def test_base_consumer(self): + + class Consumers(BaseConsumer): + + method_mapping = { + 'test.create': 'create', + 'test.test': 'test', + } + + def create(self, message, **kwargs): + self.called = 'create' + + def test(self, message, **kwargs): + self.called = 'test' + + with apply_routes([route_class(Consumers)]): + client = Client() + + # check that methods for certain channels routes successfully + self.assertEqual(client.send_and_consume('test.create').called, 'create') + self.assertEqual(client.send_and_consume('test.test').called, 'test') + + # send to the channels without routes + client.send('test.wrong') + message = self.get_next_message('test.wrong') + self.assertEqual(client.channel_layer.router.match(message), None) + + client.send('test') + message = self.get_next_message('test') + self.assertEqual(client.channel_layer.router.match(message), None) + + def test_websockets_consumers_handlers(self): + + class WebsocketConsumer(websockets.WebsocketConsumer): + + def connect(self, message, **kwargs): + self.called = 'connect' + self.id = kwargs['id'] + + def disconnect(self, message, **kwargs): + self.called = 'disconnect' + + def receive(self, text=None, bytes=None, **kwargs): + self.text = text + + with apply_routes([route_class(WebsocketConsumer, path='/path/(?P\d+)')]): + client = Client() + + consumer = client.send_and_consume('websocket.connect', {'path': '/path/1'}) + self.assertEqual(consumer.called, 'connect') + self.assertEqual(consumer.id, '1') + + consumer = client.send_and_consume('websocket.receive', {'path': '/path/1', 'text': 'text'}) + self.assertEqual(consumer.text, 'text') + + consumer = client.send_and_consume('websocket.disconnect', {'path': '/path/1'}) + self.assertEqual(consumer.called, 'disconnect') + + def test_websockets_decorators(self): + class WebsocketConsumer(websockets.WebsocketConsumer): + slight_ordering = True + + def connect(self, message, **kwargs): + self.order = message['order'] + + with apply_routes([route_class(WebsocketConsumer, path='/path')]): + client = Client() + + client.send('websocket.connect', {'path': '/path', 'order': 1}) + client.send('websocket.connect', {'path': '/path', 'order': 0}) + client.consume('websocket.connect') + self.assertEqual(client.consume('websocket.connect').order, 0) + self.assertEqual(client.consume('websocket.connect').order, 1)