daphne/channels/routing.py

265 lines
9.3 KiB
Python

from __future__ import unicode_literals
import re
import importlib
from django.core.exceptions import ImproperlyConfigured
from django.utils import six
from .utils import name_that_thing
class Router(object):
    """
    Manages the available consumers in the project and which channels they
    listen to.

    Generally this is attached to a backend instance as ".router"

    Anything can be a routable object as long as it provides a match()
    method that either returns (callable, kwargs) or None.
    """

    def __init__(self, routing):
        """
        Builds the router from a routing definition: a list of routable
        objects, a dict of channel name -> consumer, or a dotted-path string
        pointing at either of those.
        """
        # Use a blank include as the root item
        self.root = Include(routing)
        # Cache channel names
        self.channels = self.root.channel_names()

    def add_route(self, route):
        """
        Adds a single raw Route to us at the end of the resolution list.
        """
        self.root.routing.append(route)
        # The cached channel names are stale once the routing changes
        self.channels = self.root.channel_names()

    def match(self, message):
        """
        Runs through our routing and tries to find a consumer that matches
        the message/channel. Returns (consumer, extra_kwargs) if it does,
        and None if it doesn't.
        """
        # TODO: Maybe we can add some kind of caching in here if we can hash
        # the message with only matchable keys faster than the search?
        return self.root.match(message)

    def check_default(self, http_consumer=None):
        """
        Adds default handlers for Django's default handling of channels.

        http_consumer, if given, is used for http.request instead of a
        freshly created ViewConsumer.
        """
        # We just add the default Django route to the bottom; if the user
        # has defined another http.request handler, it'll get hit first and run.
        # Inner import here to avoid circular import; this function only gets
        # called once, thankfully.
        from .handler import ViewConsumer
        self.add_route(Route("http.request", http_consumer or ViewConsumer()))
        # We also add a no-op websocket.connect consumer to the bottom, as the
        # spec requires that this is consumed, but Channels does not. Any user
        # consumer will override this one. Same for websocket.receive.
        self.add_route(Route("websocket.connect", connect_consumer))
        self.add_route(Route("websocket.receive", null_consumer))
        self.add_route(Route("websocket.disconnect", null_consumer))

    @classmethod
    def resolve_routing(cls, routing):
        """
        Takes a routing - if it's a string, it imports it, and if it's a
        dict, converts it to a list of route()s. Used by this class and Include.

        Raises ImproperlyConfigured if a string routing is not a dotted
        "module.variable" path or cannot be imported.
        """
        # If the routing was a string, import it
        if isinstance(routing, six.string_types):
            # rpartition never fails to unpack, so a string with no module
            # part is reported as a configuration error rather than as a
            # bare ValueError from tuple unpacking or from import_module("").
            module_name, _, variable_name = routing.rpartition(".")
            if not module_name:
                raise ImproperlyConfigured(
                    "Cannot import channel routing %r: not a dotted module path" % (routing,)
                )
            try:
                routing = getattr(importlib.import_module(module_name), variable_name)
            except (ImportError, AttributeError) as e:
                # "routing" is still the original string here, because the
                # assignment above never completed.
                raise ImproperlyConfigured("Cannot import channel routing %r: %s" % (routing, e))
        # If the routing is a dict, convert it
        if isinstance(routing, dict):
            routing = [
                Route(channel, consumer)
                for channel, consumer in routing.items()
            ]
        return routing

    @classmethod
    def normalise_re_arg(cls, value):
        """
        Normalises regular expression patterns and string inputs to Unicode.

        Byte strings are decoded as ASCII; anything else is returned as-is.
        """
        if isinstance(value, six.binary_type):
            return value.decode("ascii")
        else:
            return value
class Route(object):
    """
    Represents a route to a single consumer, with a channel name
    and optional message parameter matching.
    """

    def __init__(self, channels, consumer, **kwargs):
        """
        channels is a single channel name or an iterable of names (text or
        ASCII bytes); consumer is the consumer object or a dotted-path string
        to import; kwargs map message keys to regular expression filters.

        Raises ValueError if a filter uses positional groups, and
        ImproperlyConfigured if a string consumer cannot be imported.
        """
        # Get channels, make sure it's a list of unicode strings.
        # A lone byte string has to be wrapped as well: on Python 3 bytes are
        # not part of six.string_types, and iterating over them would yield
        # integers rather than channel names.
        if isinstance(channels, six.string_types + (six.binary_type,)):
            channels = [channels]
        self.channels = [
            channel.decode("ascii") if isinstance(channel, six.binary_type) else channel
            for channel in channels
        ]
        # Get consumer, optionally importing it
        self.consumer = self._resolve_consumer(consumer)
        # Compile filter regexes up front
        self.filters = {
            name: re.compile(Router.normalise_re_arg(value))
            for name, value in kwargs.items()
        }
        # Check filters don't use positional groups: match() only passes
        # named groups (groupdict) on to the consumer.
        for name, regex in self.filters.items():
            if regex.groups != len(regex.groupindex):
                raise ValueError(
                    "Filter for %s on %s contains positional groups; "
                    "only named groups are allowed." % (
                        name,
                        self,
                    )
                )

    def _resolve_consumer(self, consumer):
        """
        Turns the consumer from a string into an object if it's a string,
        passes it through otherwise.

        Raises ImproperlyConfigured if the string is not a dotted
        "module.variable" path or cannot be imported.
        """
        if isinstance(consumer, six.string_types):
            # rpartition never fails to unpack, so a string with no module
            # part is reported as a configuration error rather than as a
            # bare ValueError from tuple unpacking or from import_module("").
            module_name, _, variable_name = consumer.rpartition(".")
            if not module_name:
                raise ImproperlyConfigured(
                    "Cannot import consumer %r: not a dotted module path" % (consumer,)
                )
            try:
                consumer = getattr(importlib.import_module(module_name), variable_name)
            except (ImportError, AttributeError) as e:
                # Include the underlying error so a failing import inside the
                # consumer's module is not hidden.
                raise ImproperlyConfigured("Cannot import consumer %r: %s" % (consumer, e))
        return consumer

    def match(self, message):
        """
        Checks to see if we match the Message object. Returns
        (consumer, kwargs dict) if it matches, None otherwise
        """
        # Check for channel match first of all
        if message.channel.name not in self.channels:
            return None
        # Check each message filter and build consumer kwargs as we go
        call_args = {}
        for name, value in self.filters.items():
            # A filter on a key the message lacks can never match
            if name not in message:
                return None
            match = value.match(Router.normalise_re_arg(message[name]))
            # Any match failure means we pass
            if match:
                call_args.update(match.groupdict())
            else:
                return None
        return self.consumer, call_args

    def channel_names(self):
        """
        Returns the channel names this route listens on
        """
        return set(self.channels)

    def __str__(self):
        return "%s %s -> %s" % (
            "/".join(self.channels),
            "" if not self.filters else "(%s)" % (
                ", ".join("%s=%s" % (n, v.pattern) for n, v in self.filters.items())
            ),
            name_that_thing(self.consumer),
        )
class RouteClass(Route):
    """
    Like Route, but targets a class-based consumer rather than a functional
    one: instead of having channels passed in, it asks the consumer for them
    by calling its "channel_names()" method.
    """

    def __init__(self, consumer, **kwargs):
        # Import the consumer first if it was given as a dotted path
        resolved = self._resolve_consumer(consumer)
        # The consumer must expose a callable channel_names attribute
        provides_names = (
            hasattr(resolved, "channel_names") and callable(resolved.channel_names)
        )
        if not provides_names:
            raise ValueError("The consumer passed to RouteClass has no valid channel_names method")
        # Hand the consumer's own channel list to the regular Route setup
        super(RouteClass, self).__init__(resolved.channel_names(), resolved, **kwargs)
class Include(object):
    """
    Represents an inclusion of another routing list in another file.

    Keyword arguments are regular expression prefixes keyed by message
    field; a matched prefix is stripped from that field before the message
    is offered to the included entries.
    """

    def __init__(self, routing, **kwargs):
        self.routing = Router.resolve_routing(routing)
        # Compile every prefix pattern once, up front
        compiled = {}
        for name, value in kwargs.items():
            compiled[name] = re.compile(Router.normalise_re_arg(value))
        self.prefixes = compiled

    def match(self, message):
        """
        Matches the message against our prefixes, then offers the (possibly
        trimmed) message to each included entry in order. Returns
        (consumer, kwargs dict) for the first entry that matches, or None.
        """
        # Work on a copy so the trimmed values never leak back to the caller
        message = message.copy()
        collected = {}
        for key, pattern in self.prefixes.items():
            # A prefix on a missing key can never match
            if key not in message:
                return None
            text = Router.normalise_re_arg(message[key])
            found = pattern.match(text)
            if not found:
                return None
            collected.update(found.groupdict())
            # Drop the matched prefix so included entries see the remainder
            message[key] = text[found.end():]
        # All prefixes matched; the first included entry that matches wins
        for entry in self.routing:
            result = entry.match(message)
            if result is None:
                continue
            collected.update(result[1])
            return result[0], collected
        # Nothing matched
        return None

    def channel_names(self):
        """
        Returns the union of the channel names of all included entries
        """
        return set().union(*[entry.channel_names() for entry in self.routing])
def null_consumer(*args, **kwargs):
    """
    No-op consumer: accepts any arguments, does nothing and returns None.
    """
    return None
def connect_consumer(message, *args, **kwargs):
    """
    websocket.connect consumer that accepts every connection by sending
    {"accept": True} on the message's reply channel.
    """
    reply = message.reply_channel
    reply.send({"accept": True})
# Lowercase aliases, so routing files can follow the same naming
# standard as urls.py
include = Include
route = Route
route_class = RouteClass