diff --git a/channels/asgi.py b/channels/asgi.py index 9911578..737a422 100644 --- a/channels/asgi.py +++ b/channels/asgi.py @@ -4,7 +4,7 @@ import django from django.conf import settings from django.utils.module_loading import import_string -from .consumer_registry import ConsumerRegistry +from .routing import Router from .utils import name_that_thing @@ -67,7 +67,7 @@ class ChannelLayerWrapper(object): self.channel_layer = channel_layer self.alias = alias self.routing = routing - self.registry = ConsumerRegistry(self.routing) + self.router = Router(self.routing) def __getattr__(self, name): return getattr(self.channel_layer, name) diff --git a/channels/consumer_registry.py b/channels/consumer_registry.py deleted file mode 100644 index 7cd7a57..0000000 --- a/channels/consumer_registry.py +++ /dev/null @@ -1,76 +0,0 @@ -from __future__ import unicode_literals - -import importlib - -from django.core.exceptions import ImproperlyConfigured -from django.utils import six - -from .handler import ViewConsumer -from .utils import name_that_thing - - -class ConsumerRegistry(object): - """ - Manages the available consumers in the project and which channels they - listen to. - - Generally this is attached to a backend instance as ".registry" - """ - - def __init__(self, routing=None): - self.consumers = {} - # Initialise with any routing that was passed in - if routing: - # If the routing was a string, import it - if isinstance(routing, six.string_types): - module_name, variable_name = routing.rsplit(".", 1) - try: - routing = getattr(importlib.import_module(module_name), variable_name) - except (ImportError, AttributeError) as e: - raise ImproperlyConfigured("Cannot import channel routing %r: %s" % (routing, e)) - # Load consumers into us - for channel, handler in routing.items(): - self.add_consumer(handler, [channel]) - - def add_consumer(self, consumer, channels): - # Upconvert if you just pass in a string for channels - if isinstance(channels, six.string_types): - channels = [channels] - # Make sure all channels are byte strings - channels = [ - channel.decode("ascii") if isinstance(channel, six.binary_type) else channel - for channel in channels - ] - # Import any consumer referenced as string - if isinstance(consumer, six.string_types): - module_name, variable_name = consumer.rsplit(".", 1) - try: - consumer = getattr(importlib.import_module(module_name), variable_name) - except (ImportError, AttributeError): - raise ImproperlyConfigured("Cannot import consumer %r" % consumer) - # Register on each channel, checking it's unique - for channel in channels: - if channel in self.consumers: - raise ValueError("Cannot register consumer %s - channel %r already consumed by %s" % ( - name_that_thing(consumer), - channel, - name_that_thing(self.consumers[channel]), - )) - self.consumers[channel] = consumer - - def all_channel_names(self): - return self.consumers.keys() - - def consumer_for_channel(self, channel): - try: - return self.consumers[channel] - except KeyError: - return None - - def check_default(self, http_consumer=None): - """ - Checks to see if default handlers need to be registered - for channels, and adds them if they need to be. - """ - if not self.consumer_for_channel("http.request"): - self.add_consumer(http_consumer or ViewConsumer(), ["http.request"]) diff --git a/channels/management/commands/runserver.py b/channels/management/commands/runserver.py index 33185b6..3fdd93b 100644 --- a/channels/management/commands/runserver.py +++ b/channels/management/commands/runserver.py @@ -35,7 +35,7 @@ class Command(RunserverCommand): return RunserverCommand.inner_run(self, *args, **options) # Check a handler is registered for http reqs; if not, add default one self.channel_layer = channel_layers[DEFAULT_CHANNEL_LAYER] - self.channel_layer.registry.check_default( + self.channel_layer.router.check_default( http_consumer=self.get_consumer(), ) # Run checks diff --git a/channels/management/commands/runworker.py b/channels/management/commands/runworker.py index 075a323..069e6ec 100644 --- a/channels/management/commands/runworker.py +++ b/channels/management/commands/runworker.py @@ -27,7 +27,7 @@ class Command(BaseCommand): "Change your settings to use a cross-process channel layer." ) # Check a handler is registered for http reqs - self.channel_layer.registry.check_default() + self.channel_layer.router.check_default() # Launch a worker self.logger.info("Running worker against channel layer %s", self.channel_layer) # Optionally provide an output callback diff --git a/channels/routing.py b/channels/routing.py new file mode 100644 index 0000000..8e1fd44 --- /dev/null +++ b/channels/routing.py @@ -0,0 +1,220 @@ +from __future__ import unicode_literals + +import re +import importlib + +from django.core.exceptions import ImproperlyConfigured +from django.utils import six + +from .handler import ViewConsumer +from .utils import name_that_thing + + +class Router(object): + """ + Manages the available consumers in the project and which channels they + listen to. + + Generally this is attached to a backend instance as ".router" + """ + + def __init__(self, routing): + # Resolve routing into a list if it's a dict or string + routing = self.resolve_routing(routing) + # Expand those entries recursively into a flat list of Routes + self.routing = [] + for entry in routing: + self.routing.extend(entry.expand_routes()) + # Now go through that list and collect channel names into a set + self.channels = { + route.channel + for route in self.routing + } + + def add_route(self, route): + """ + Adds a single raw Route to us at the end of the resolution list. + """ + self.routing.append(route) + self.channels.add(route.channel) + + def match(self, message): + """ + Runs through our routing and tries to find a consumer that matches + the message/channel. Returns (consumer, extra_kwargs) if it does, + and None if it doesn't. + """ + # TODO: Maybe we can add some kind of caching in here if we can hash + # the message with only matchable keys faster than the search? + for route in self.routing: + match = route.match(message) + if match is not None: + return match + return None + + def check_default(self, http_consumer=None): + """ + Adds default handlers for Django's default handling of channels. + """ + # We just add the default Django route to the bottom; if the user + # has defined another http.request handler, it'll get hit first and run. + self.add_route(Route("http.request", http_consumer or ViewConsumer())) + + @classmethod + def resolve_routing(cls, routing): + """ + Takes a routing - if it's a string, it imports it, and if it's a + dict, converts it to a list of route()s. Used by this class and Include. + """ + # If the routing was a string, import it + if isinstance(routing, six.string_types): + module_name, variable_name = routing.rsplit(".", 1) + try: + routing = getattr(importlib.import_module(module_name), variable_name) + except (ImportError, AttributeError) as e: + raise ImproperlyConfigured("Cannot import channel routing %r: %s" % (routing, e)) + # If the routing is a dict, convert it + if isinstance(routing, dict): + routing = [ + Route(channel, consumer) + for channel, consumer in routing.items() + ] + return routing + + +class Route(object): + """ + Represents a route to a single consumer, with a channel name + and optional message parameter matching. + """ + + def __init__(self, channel, consumer, **kwargs): + # Get channel, make sure it's a unicode string + self.channel = channel + if isinstance(self.channel, six.binary_type): + self.channel = self.channel.decode("ascii") + # Get consumer, optionally importing it + if isinstance(consumer, six.string_types): + module_name, variable_name = consumer.rsplit(".", 1) + try: + consumer = getattr(importlib.import_module(module_name), variable_name) + except (ImportError, AttributeError): + raise ImproperlyConfigured("Cannot import consumer %r" % consumer) + self.consumer = consumer + # Compile filter regexes up front + self.filters = { + name: re.compile(value) + for name, value in kwargs.items() + } + # Check filters don't use positional groups + for name, regex in self.filters.items(): + if regex.groups != len(regex.groupindex): + raise ValueError( + "Filter for %s on %s contains positional groups; " + "only named groups are allowed." % ( + name, + self, + ) + ) + + def match(self, message): + """ + Checks to see if we match the Message object. Returns + (consumer, kwargs dict) if it matches, None otherwise + """ + # Check for channel match first of all + if message.channel.name != self.channel: + return None + # Check each message filter and build consumer kwargs as we go + call_args = {} + for name, value in self.filters.items(): + if name not in message: + return None + match = re.match(value, message[name]) + # Any match failure means we pass + if match: + call_args.update(match.groupdict()) + else: + return None + return self.consumer, call_args + + def expand_routes(self): + """ + Expands this route into a list of just itself. + """ + return [self] + + def add_prefixes(self, prefixes): + """ + Returns a new Route with the given prefixes added to our filters. + """ + new_filters = {} + # Copy over our filters adding any prefixes + for name, value in self.filters.items(): + if name in prefixes: + if not value.pattern.startswith("^"): + raise ValueError("Cannot add prefix for %s on %s as inner value does not start with ^" % ( + name, + self, + )) + if "$" in prefixes[name]: + raise ValueError("Cannot add prefix for %s on %s as prefix contains $ (end of line match)" % ( + name, + self, + )) + new_filters[name] = re.compile(prefixes[name] + value.pattern.lstrip("^")) + else: + new_filters[name] = value + # Now add any prefixes that are by themselves so they're still enforced + for name, prefix in prefixes.items(): + if name not in new_filters: + new_filters[name] = prefix + # Return new copy + return self.__class__( + self.channel, + self.consumer, + **new_filters + ) + + def __str__(self): + return "%s %s -> %s" % ( + self.channel, + "" if not self.filters else "(%s)" % ( + ", ".join("%s=%s" % (n, v.pattern) for n, v in self.filters.items()) + ), + name_that_thing(self.consumer), + ) + + +class Include(object): + """ + Represents an inclusion of another routing list in another file. + Will automatically modify message match filters to add prefixes, + if specified. + """ + + def __init__(self, routing, **kwargs): + self.routing = Routing.resolve_routing(routing) + self.prefixes = kwargs + # Sanity check prefix regexes + for name, value in self.prefixes.items(): + if not value.startswith("^"): + raise ValueError("Include prefix for %s must start with the ^ character." % name) + + def expand_routes(self): + """ + Expands this Include into a list of routes, first recursively expanding + and then adding on prefixes to filters if specified. + """ + # First, expand our own subset of routes, to get a list of Route objects + routes = [] + for entry in self.routing: + routes.extend(entry.expand_routes()) + # Then, go through those and add any prefixes we have. + routes = [route.add_prefixes(self.prefixes) for route in routes] + return routes + + +# Lowercase standard to match urls.py +route = Route +include = Include diff --git a/channels/tests/test_routing.py b/channels/tests/test_routing.py new file mode 100644 index 0000000..5fcd2a3 --- /dev/null +++ b/channels/tests/test_routing.py @@ -0,0 +1,234 @@ +from __future__ import unicode_literals +from django.test import SimpleTestCase +from django.utils import six + +from channels.routing import Router, route, include +from channels.message import Message +from channels.utils import name_that_thing + + +# Fake consumers and routing sets that can be imported by string +def consumer_1(): + pass +def consumer_2(): + pass +def consumer_3(): + pass +chatroom_routing = [ + route("websocket.connect", consumer_2, path=r"^/chat/(?P[^/]+)/$"), + route("websocket.connect", consumer_3, path=r"^/mentions/$"), +] +chatroom_routing_noprefix = [ + route("websocket.connect", consumer_2, path=r"/chat/(?P[^/]+)/$"), + route("websocket.connect", consumer_3, path=r"/mentions/$"), +] + + +class RoutingTests(SimpleTestCase): + """ + Tests that the router's routing code works correctly. + """ + + # Fake consumers we can test for with the == operator + def consumer_1(self): + pass + def consumer_2(self): + pass + def consumer_3(self): + pass + + def assertRoute(self, router, channel, content, consumer, kwargs=None): + """ + Asserts that asking the `router` to route the `content` as a message + from `channel` means it returns consumer `consumer`, optionally + testing it also returns `kwargs` to be passed in + + Use `consumer` = None to assert that no route is found. + """ + message = Message(content, channel, channel_layer="fake channel layer") + match = router.match(message) + if match is None: + if consumer is None: + return + else: + self.fail("No route found for %s on %s; expecting %s" % ( + content, + channel, + name_that_thing(consumer), + )) + else: + mconsumer, mkwargs = match + if consumer is None: + self.fail("Route found for %s on %s; expecting no route." % ( + content, + channel, + )) + self.assertEqual(consumer, mconsumer, "Route found for %s on %s; but wrong consumer (%s not %s)." % ( + content, + channel, + name_that_thing(mconsumer), + name_that_thing(consumer), + )) + if kwargs is not None: + self.assertEqual(kwargs, mkwargs, "Route found for %s on %s; but wrong kwargs (%s not %s)." % ( + content, + channel, + mkwargs, + kwargs, + )) + + def test_assumption(self): + """ + Ensures the test consumers don't compare equal, as if this ever happens + this test file will pass and miss most bugs. + """ + self.assertNotEqual(consumer_1, consumer_2) + self.assertNotEqual(consumer_1, consumer_3) + + def test_dict(self): + """ + Tests dict expansion + """ + router = Router({ + "http.request": consumer_1, + "http.disconnect": consumer_2, + }) + self.assertRoute( + router, + channel="http.request", + content={}, + consumer=consumer_1, + kwargs={}, + ) + self.assertRoute( + router, + channel="http.request", + content={"path": "/chat/"}, + consumer=consumer_1, + kwargs={}, + ) + self.assertRoute( + router, + channel="http.disconnect", + content={}, + consumer=consumer_2, + kwargs={}, + ) + + def test_filters(self): + """ + Tests that filters catch things correctly. + """ + router = Router([ + route("http.request", consumer_1, path=r"^/chat/$"), + route("http.disconnect", consumer_2), + route("http.request", consumer_3), + ]) + # Filter hit + self.assertRoute( + router, + channel="http.request", + content={"path": "/chat/"}, + consumer=consumer_1, + kwargs={}, + ) + # Fall-through + self.assertRoute( + router, + channel="http.request", + content={}, + consumer=consumer_3, + kwargs={}, + ) + self.assertRoute( + router, + channel="http.request", + content={"path": "/liveblog/"}, + consumer=consumer_3, + kwargs={}, + ) + + def test_include(self): + """ + Tests inclusion without a prefix + """ + router = Router([ + include("channels.tests.test_routing.chatroom_routing"), + ]) + self.assertRoute( + router, + channel="websocket.connect", + content={"path": "/boom/"}, + consumer=None, + ) + self.assertRoute( + router, + channel="websocket.connect", + content={"path": "/chat/django/"}, + consumer=consumer_2, + kwargs={"room": "django"}, + ) + self.assertRoute( + router, + channel="websocket.connect", + content={"path": "/mentions/"}, + consumer=consumer_3, + kwargs={}, + ) + + def test_include_prefix(self): + """ + Tests inclusion with a prefix + """ + router = Router([ + include("channels.tests.test_routing.chatroom_routing", path="^/ws/v(?P[0-9]+)"), + ]) + self.assertRoute( + router, + channel="websocket.connect", + content={"path": "/boom/"}, + consumer=None, + ) + self.assertRoute( + router, + channel="websocket.connect", + content={"path": "/chat/django/"}, + consumer=None, + ) + self.assertRoute( + router, + channel="websocket.connect", + content={"path": "/ws/v2/chat/django/"}, + consumer=consumer_2, + kwargs={"version": "2", "room": "django"}, + ) + self.assertRoute( + router, + channel="websocket.connect", + content={"path": "/ws/v1/mentions/"}, + consumer=consumer_3, + kwargs={"version": "1"}, + ) + + def test_positional_pattern(self): + """ + Tests that regexes with positional groups are rejected. + """ + with self.assertRaises(ValueError): + Consumerrouter([ + route("http.request", consumer_1, path=r"^/chat/([^/]+)/$"), + ]) + + def test_bad_include_prefix(self): + """ + Tests both failure cases of prefixes for includes - the include not + starting with ^, and the included filter not starting with ^. + """ + with self.assertRaises(ValueError): + Consumerrouter([ + include("channels.tests.test_routing.chatroom_routing", path="foobar"), + ]) + with self.assertRaises(ValueError): + Consumerrouter([ + include("channels.tests.test_routing.chatroom_routing_noprefix", path="^/foobar/"), + ]) diff --git a/channels/worker.py b/channels/worker.py index ce3a7f8..4f8278f 100644 --- a/channels/worker.py +++ b/channels/worker.py @@ -44,7 +44,7 @@ class Worker(object): """ if self.signal_handlers: self.install_signal_handler() - channels = self.channel_layer.registry.all_channel_names() + channels = self.channel_layer.router.channels while not self.termed: self.in_job = False channel, content = self.channel_layer.receive_many(channels, block=True) @@ -66,11 +66,16 @@ class Worker(object): if content.get("__retries__", 0) == self.message_retries: message.__doomed__ = True # Handle the message - consumer = self.channel_layer.registry.consumer_for_channel(channel) + match = self.channel_layer.router.match(message) + if match is None: + logger.exception("Could not find match for message on %s! Check your routing.", channel) + continue + else: + consumer, kwargs = match if self.callback: self.callback(channel, message) try: - consumer(message) + consumer(message, **kwargs) except ConsumeLater: # They want to not handle it yet. Re-inject it with a number-of-tries marker. content['__retries__'] = content.get("__retries__", 0) + 1 diff --git a/docs/getting-started.rst b/docs/getting-started.rst index 25af12a..4ad97d8 100644 --- a/docs/getting-started.rst +++ b/docs/getting-started.rst @@ -62,9 +62,10 @@ Here's what that looks like:: } # In routing.py - channel_routing = { - "http.request": "myproject.myapp.consumers.http_consumer", - } + from channels.routing import route + channel_routing = [ + route("http.request", "myproject.myapp.consumers.http_consumer"), + ] .. warning:: This example, and most of the examples here, use the "in memory" channel @@ -76,7 +77,7 @@ Here's what that looks like:: As you can see, this is a little like Django's ``DATABASES`` setting; there are named channel layers, with a default one called ``default``. Each layer needs a channel layer class, some options (if the channel layer needs them), -and a routing scheme, which points to a dict containing the routing settings. +and a routing scheme, which points to a list containing the routing settings. It's recommended you call this ``routing.py`` and put it alongside ``urls.py`` in your project, but you can put it wherever you like, as long as the path is correct. @@ -111,11 +112,12 @@ for ``http.request`` - and make this WebSocket consumer instead:: Hook it up to the ``websocket.receive`` channel like this:: # In routing.py + from channels.routing import route from myproject.myapp.consumers import ws_message - channel_routing = { - "websocket.receive": ws_message, - } + channel_routing = [ + route("websocket.receive", ws_message), + ] Now, let's look at what this is doing. It's tied to the ``websocket.receive`` channel, which means that it'll get a message @@ -210,12 +212,13 @@ get the message. Here's all the code:: And what our routing should look like in ``routing.py``:: + from channels.routing import route from myproject.myapp.consumers import ws_add, ws_message, ws_disconnect - channel_routing = { - "websocket.connect": ws_add, - "websocket.receive": ws_message, - "websocket.disconnect": ws_disconnect, + channel_routing = [ + route("websocket.connect", ws_add), + route("websocket.receive", ws_message), + route("websocket.disconnect", ws_disconnect), } With all that code, you now have a working set of a logic for a chat server. @@ -366,6 +369,7 @@ If you play around with it from the console (or start building a simple JavaScript chat client that appends received messages to a div), you'll see that you can set a chat room with the initial request. + Authentication -------------- @@ -430,8 +434,6 @@ chat to people with the same first letter of their username:: # Connected to websocket.connect @channel_session_user_from_http def ws_add(message): - # Copy user from HTTP to channel session - transfer_user(message.http_session, message.channel_session) # Add them to the right group Group("chat-%s" % message.user.username[0]).add(message.reply_channel) @@ -458,6 +460,58 @@ responses can set cookies, it needs a backend it can write to to separately store state. +Routing +------- + +Channels' ``routing.py`` acts very much like Django's ``urls.py``, including the +ability to route things to different consumers based on ``path``, or any other +message attribute that's a string (for example, ``http.request`` messages have +a ``method`` key you could route based on). + +Much like urls, you route using regular expressions; the main difference is that +because the ``path`` is not special-cased - Channels doesn't know that it's a URL - +you have to start patterns with the root ``/``, and end includes without a ``/`` +so that when the patterns combine, they work correctly. + +Finally, because you're matching against message contents using keyword arguments, +you can only use named groups in your regular expressions! Here's an example of +routing our chat from above:: + + http_routing = [ + route("http.request", poll_consumer, path=r"^/poll/$", method=r"^POST$"), + ] + + chat_routing = [ + route("websocket.connect", chat_connect, path=r"^/(?P[a-zA-Z0-9_]+)/$), + route("websocket.disconnect", chat_disconnect), + ] + + routing = [ + # You can use a string import path as the first argument as well. + include(chat_routing, path=r"^/chat"), + include(http_routing), + ] + +When Channels loads this routing, it appends any match keys together, so the +``path`` match becomes ``^/chat/(?P[a-zA-Z0-9_]+)/$``. If the include match +or the route match doesn't have the ``^`` character, it will refuse to append them +and error (you can still have matches without ``^`` in either, you just can't +ask Channels to combine them). + +Because these matches come through as keyword arguments, we could modify our +consumer above to use a room based on URL rather than username:: + + # Connected to websocket.connect + @channel_session_user_from_http + def ws_add(message, room): + # Add them to the right group + Group("chat-%s" % room).add(message.reply_channel) + +In the next section, we'll change to sending the ``room`` as a part of the +WebSocket message - which you might do if you had a multiplexing client - +but you could use routing there as well. + + Models ------ diff --git a/docs/reference.rst b/docs/reference.rst index 5b875a5..adc2d1d 100644 --- a/docs/reference.rst +++ b/docs/reference.rst @@ -117,19 +117,15 @@ They have the following attributes: * ``alias``: The alias of this layer. -* ``registry``: An object which represents the layer's mapping of channels +* ``router``: An object which represents the layer's mapping of channels to consumers. Has the following attributes: - * ``add_consumer(consumer, channels)``: Registers a :ref:`consumer ` - to handle all channels passed in. ``channels`` should be an iterable of - unicode string names. + * ``channels``: The set of channels this router can handle, as unicode strings - * ``consumer_for_channel(channel)``: Takes a unicode channel name and returns - either a :ref:`consumer `, or None, if no consumer is registered. - - * ``all_channel_names()``: Returns a list of all channel names this layer has - routed to a consumer. Used by the worker threads to work out what channels - to listen on. + * ``match(message)``: Takes a :ref:`Message ` and returns either + a (consumer, kwargs) tuple specifying the consumer to run and the keyword + argument to pass that were extracted via routing patterns, or None, + meaning there's no route available. .. _ref-asgirequest: