mirror of
				https://github.com/django/daphne.git
				synced 2025-11-04 01:27:33 +03:00 
			
		
		
		
	
		
			
				
	
	
		
			246 lines
		
	
	
		
			8.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			246 lines
		
	
	
		
			8.6 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from __future__ import unicode_literals
 | 
						|
 | 
						|
import re
 | 
						|
import importlib
 | 
						|
 | 
						|
from django.core.exceptions import ImproperlyConfigured
 | 
						|
from django.utils import six
 | 
						|
 | 
						|
from .utils import name_that_thing
 | 
						|
 | 
						|
 | 
						|
class Router(object):
    """
    Manages the available consumers in the project and which channels they
    listen to.

    Generally this is attached to a backend instance as ".router"

    Anything can be a routable object as long as it provides a match()
    method that either returns (callable, kwargs) or None, and a
    channel_names() method returning the set of channels it listens on.
    """

    def __init__(self, routing):
        """
        Builds the router from ``routing``, which is handed to Include
        unchanged (so it may be anything resolve_routing() accepts).
        """
        # Use a blank include (no prefixes) as the root item
        self.root = Include(routing)
        # Cache channel names so they are not recomputed on every access
        self.channels = self.root.channel_names()

    def add_route(self, route):
        """
        Adds a single raw Route to us at the end of the resolution list,
        then refreshes the cached channel names.
        """
        self.root.routing.append(route)
        self.channels = self.root.channel_names()

    def match(self, message):
        """
        Runs through our routing and tries to find a consumer that matches
        the message/channel. Returns (consumer, extra_kwargs) if it does,
        and None if it doesn't.
        """
        # TODO: Maybe we can add some kind of caching in here if we can hash
        # the message with only matchable keys faster than the search?
        return self.root.match(message)

    def check_default(self, http_consumer=None):
        """
        Adds default handlers for Django's default handling of channels.

        ``http_consumer`` is used for "http.request" if given (and truthy);
        otherwise a fresh ViewConsumer instance is used.
        """
        # We just add the default Django route to the bottom; if the user
        # has defined another http.request handler, it'll get hit first and run.
        # Inner import here to avoid circular import; this function only gets
        # called once, thankfully.
        from .handler import ViewConsumer
        self.add_route(Route("http.request", http_consumer or ViewConsumer()))

    @classmethod
    def resolve_routing(cls, routing):
        """
        Takes a routing - if it's a string, it imports it, and if it's a
        dict, converts it to a list of route()s. Used by this class and Include.

        Anything else is returned untouched. Raises ImproperlyConfigured if
        a dotted-path string cannot be imported.
        """
        # If the routing was a string, import it
        if isinstance(routing, six.string_types):
            module_name, variable_name = routing.rsplit(".", 1)
            try:
                routing = getattr(importlib.import_module(module_name), variable_name)
            except (ImportError, AttributeError) as e:
                # "routing" is still the original dotted path here, because
                # the assignment above never completed.
                raise ImproperlyConfigured("Cannot import channel routing %r: %s" % (routing, e))
        # If the routing is a dict, convert it
        if isinstance(routing, dict):
            routing = [
                Route(channel, consumer)
                for channel, consumer in routing.items()
            ]
        return routing

    @classmethod
    def normalise_re_arg(cls, value):
        """
        Normalises regular expression patterns and string inputs to Unicode.

        Byte strings are decoded as ASCII; every other value is returned
        unchanged.
        """
        if isinstance(value, six.binary_type):
            return value.decode("ascii")
        return value
 | 
						|
 | 
						|
 | 
						|
class Route(object):
    """
    Represents a route to a single consumer, with a channel name
    and optional message parameter matching.
    """

    def __init__(self, channels, consumer, **kwargs):
        """
        ``channels`` is a single channel name or an iterable of them (text
        or ASCII byte strings). ``consumer`` is a callable or a dotted import
        path to one. Each keyword argument is a regex (pattern text or ASCII
        bytes) that the message key of the same name must match.

        Raises ValueError if any filter regex contains positional groups.
        """
        # Get channels, make sure it's a list of unicode strings. A lone
        # byte string must be wrapped too: on Python 3 it is not part of
        # six.string_types, and iterating it would yield integers.
        if isinstance(channels, six.string_types + (six.binary_type,)):
            channels = [channels]
        self.channels = [
            channel.decode("ascii") if isinstance(channel, six.binary_type) else channel
            for channel in channels
        ]
        # Get consumer, optionally importing it
        self.consumer = self._resolve_consumer(consumer)
        # Compile filter regexes up front
        self.filters = {
            name: re.compile(Router.normalise_re_arg(value))
            for name, value in kwargs.items()
        }
        # Check filters don't use positional groups; only named groups can
        # be turned into consumer keyword arguments.
        for name, regex in self.filters.items():
            if regex.groups != len(regex.groupindex):
                raise ValueError(
                    "Filter for %s on %s contains positional groups; "
                    "only named groups are allowed." % (
                        name,
                        self,
                    )
                )

    def _resolve_consumer(self, consumer):
        """
        Turns the consumer from a string into an object if it's a string,
        passes it through otherwise.

        Raises ImproperlyConfigured if the dotted path cannot be imported.
        """
        if isinstance(consumer, six.string_types):
            module_name, variable_name = consumer.rsplit(".", 1)
            try:
                consumer = getattr(importlib.import_module(module_name), variable_name)
            except (ImportError, AttributeError):
                # "consumer" is still the dotted path string at this point
                raise ImproperlyConfigured("Cannot import consumer %r" % consumer)
        return consumer

    def match(self, message):
        """
        Checks to see if we match the Message object. Returns
        (consumer, kwargs dict) if it matches, None otherwise.

        The message must expose ``channel.name`` and support ``in`` and
        item lookup for the filtered keys.
        """
        # Check for channel match first of all
        if message.channel.name not in self.channels:
            return None
        # Check each message filter and build consumer kwargs as we go
        call_args = {}
        for name, value in self.filters.items():
            if name not in message:
                return None
            match = value.match(Router.normalise_re_arg(message[name]))
            # Any match failure means we pass
            if not match:
                return None
            call_args.update(match.groupdict())
        return self.consumer, call_args

    def channel_names(self):
        """
        Returns the channel names this route listens on, as a set.
        """
        return set(self.channels)

    def __str__(self):
        """
        Human-readable form: "<channels> (<filters>) -> <consumer name>".
        """
        return "%s %s -> %s" % (
            "/".join(self.channels),
            "" if not self.filters else "(%s)" % (
                ", ".join("%s=%s" % (n, v.pattern) for n, v in self.filters.items())
            ),
            name_that_thing(self.consumer),
        )
 | 
						|
 | 
						|
 | 
						|
class RouteClass(Route):
    """
    Like Route, but targets a class-based consumer rather than a functional
    one, meaning it looks for a callable called "channel_names()" on the
    consumer rather than having a single channel passed in.
    """

    def __init__(self, consumer, **kwargs):
        """
        ``consumer`` is the consumer object or a dotted import path to it;
        keyword arguments are message filters, passed through to Route.

        Raises ValueError if the resolved consumer has no callable
        ``channel_names`` attribute.
        """
        # Resolve first so a dotted path can be inspected like an object
        consumer = self._resolve_consumer(consumer)
        # Check the consumer provides a callable channel_names
        if not callable(getattr(consumer, "channel_names", None)):
            raise ValueError("The consumer passed to RouteClass has no valid channel_names method")
        # Call super with list of channels
        super(RouteClass, self).__init__(consumer.channel_names(), consumer, **kwargs)
 | 
						|
 | 
						|
 | 
						|
class Include(object):
    """
    Represents an inclusion of another routing list in another file.
    Will automatically modify message match filters to add prefixes,
    if specified.
    """

    def __init__(self, routing, **kwargs):
        """
        ``routing`` is anything Router.resolve_routing() accepts. Each
        keyword argument is a regex (pattern text or ASCII bytes) that must
        match the start of the message value stored under the same key.
        """
        # Keep a private list: routes appended to ``self.routing`` later must
        # not leak into the caller's (often module-level) routing list, and
        # this also lets non-list iterables such as tuples be appended to.
        self.routing = list(Router.resolve_routing(routing))
        self.prefixes = {
            name: re.compile(Router.normalise_re_arg(value))
            for name, value in kwargs.items()
        }

    def match(self, message):
        """
        Tries to match the message against our own prefixes, possibly modifying
        what we send to included things, then tries all included items.

        Returns (consumer, kwargs) from the first included entry that
        matches, with named groups captured by our prefixes merged in, or
        None if a prefix fails or no entry matches.
        """
        # Check our prefixes match. Do this against a copy of the message so
        # we can write back any changed values.
        message = message.copy()
        call_args = {}
        for name, prefix in self.prefixes.items():
            if name not in message:
                return None
            value = Router.normalise_re_arg(message[name])
            match = prefix.match(value)
            # Any match failure means we pass
            if not match:
                return None
            call_args.update(match.groupdict())
            # Modify the message value to remove the part we matched on
            message[name] = value[match.end():]
        # Alright, if we got this far our prefixes match. Try all of our
        # included objects now.
        for entry in self.routing:
            match = entry.match(message)
            if match is not None:
                # Kwargs from the inner entry override ours on a name clash
                call_args.update(match[1])
                return match[0], call_args
        # Nothing matched :(
        return None

    def channel_names(self):
        """
        Returns the union of the channel names all included entries listen on.
        """
        result = set()
        for entry in self.routing:
            result.update(entry.channel_names())
        return result
 | 
						|
 | 
						|
 | 
						|
# Lowercase aliases, matching the naming convention used in urls.py
route = Route
route_class = RouteClass
include = Include
 |