mirror of
				https://github.com/django/daphne.git
				synced 2025-11-01 00:07:26 +03:00 
			
		
		
		
	
		
			
				
	
	
		
			265 lines
		
	
	
		
			9.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			265 lines
		
	
	
		
			9.3 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from __future__ import unicode_literals
 | |
| 
 | |
| import importlib
 | |
| import re
 | |
| 
 | |
| from django.core.exceptions import ImproperlyConfigured
 | |
| from django.utils import six
 | |
| 
 | |
| from .utils import name_that_thing
 | |
| 
 | |
| 
 | |
| class Router(object):
 | |
|     """
 | |
|     Manages the available consumers in the project and which channels they
 | |
|     listen to.
 | |
| 
 | |
|     Generally this is attached to a backend instance as ".router"
 | |
| 
 | |
|     Anything can be a routable object as long as it provides a match()
 | |
|     method that either returns (callable, kwargs) or None.
 | |
|     """
 | |
| 
 | |
|     def __init__(self, routing):
 | |
|         # Use a blank include as the root item
 | |
|         self.root = Include(routing)
 | |
|         # Cache channel names
 | |
|         self.channels = self.root.channel_names()
 | |
| 
 | |
|     def add_route(self, route):
 | |
|         """
 | |
|         Adds a single raw Route to us at the end of the resolution list.
 | |
|         """
 | |
|         self.root.routing.append(route)
 | |
|         self.channels = self.root.channel_names()
 | |
| 
 | |
|     def match(self, message):
 | |
|         """
 | |
|         Runs through our routing and tries to find a consumer that matches
 | |
|         the message/channel. Returns (consumer, extra_kwargs) if it does,
 | |
|         and None if it doesn't.
 | |
|         """
 | |
|         # TODO: Maybe we can add some kind of caching in here if we can hash
 | |
|         # the message with only matchable keys faster than the search?
 | |
|         return self.root.match(message)
 | |
| 
 | |
|     def check_default(self, http_consumer=None):
 | |
|         """
 | |
|         Adds default handlers for Django's default handling of channels.
 | |
|         """
 | |
|         # We just add the default Django route to the bottom; if the user
 | |
|         # has defined another http.request handler, it'll get hit first and run.
 | |
|         # Inner import here to avoid circular import; this function only gets
 | |
|         # called once, thankfully.
 | |
|         from .handler import ViewConsumer
 | |
|         self.add_route(Route("http.request", http_consumer or ViewConsumer()))
 | |
|         # We also add a no-op websocket.connect consumer to the bottom, as the
 | |
|         # spec requires that this is consumed, but Channels does not. Any user
 | |
|         # consumer will override this one. Same for websocket.receive.
 | |
|         self.add_route(Route("websocket.connect", connect_consumer))
 | |
|         self.add_route(Route("websocket.receive", null_consumer))
 | |
|         self.add_route(Route("websocket.disconnect", null_consumer))
 | |
| 
 | |
|     @classmethod
 | |
|     def resolve_routing(cls, routing):
 | |
|         """
 | |
|         Takes a routing - if it's a string, it imports it, and if it's a
 | |
|         dict, converts it to a list of route()s. Used by this class and Include.
 | |
|         """
 | |
|         # If the routing was a string, import it
 | |
|         if isinstance(routing, six.string_types):
 | |
|             module_name, variable_name = routing.rsplit(".", 1)
 | |
|             try:
 | |
|                 routing = getattr(importlib.import_module(module_name), variable_name)
 | |
|             except (ImportError, AttributeError) as e:
 | |
|                 raise ImproperlyConfigured("Cannot import channel routing %r: %s" % (routing, e))
 | |
|         # If the routing is a dict, convert it
 | |
|         if isinstance(routing, dict):
 | |
|             routing = [
 | |
|                 Route(channel, consumer)
 | |
|                 for channel, consumer in routing.items()
 | |
|             ]
 | |
|         return routing
 | |
| 
 | |
|     @classmethod
 | |
|     def normalise_re_arg(cls, value):
 | |
|         """
 | |
|         Normalises regular expression patterns and string inputs to Unicode.
 | |
|         """
 | |
|         if isinstance(value, six.binary_type):
 | |
|             return value.decode("ascii")
 | |
|         else:
 | |
|             return value
 | |
| 
 | |
| 
 | |
class Route(object):
    """
    Represents a route to a single consumer, with a channel name
    and optional message parameter matching.
    """

    def __init__(self, channels, consumer, **kwargs):
        """
        :param channels: one channel name or an iterable of names (str or
            ASCII bytes).
        :param consumer: the consumer callable, or a dotted import path to it.
        :param kwargs: message key -> regex filter; named groups become
            consumer kwargs.
        :raises ValueError: if a filter regex uses positional groups.
        """
        # Get channels, make sure it's a list of unicode strings. A single
        # bytes name must be wrapped as well: on Python 3 bytes is not in
        # six.string_types, and iterating it would yield ints, not names.
        if isinstance(channels, six.string_types + (six.binary_type,)):
            channels = [channels]
        self.channels = [Router.normalise_re_arg(channel) for channel in channels]
        # Get consumer, optionally importing it
        self.consumer = self._resolve_consumer(consumer)
        # Compile filter regexes up front
        self.filters = {
            name: re.compile(Router.normalise_re_arg(value))
            for name, value in kwargs.items()
        }
        # Check filters don't use positional groups
        for name, regex in self.filters.items():
            if regex.groups != len(regex.groupindex):
                raise ValueError(
                    "Filter for %s on %s contains positional groups; "
                    "only named groups are allowed." % (
                        name,
                        self,
                    )
                )

    def _resolve_consumer(self, consumer):
        """
        Turns the consumer from a string into an object if it's a string,
        passes it through otherwise.

        :raises ImproperlyConfigured: if a string consumer is not a dotted
            ``module.attribute`` path or cannot be imported.
        """
        if isinstance(consumer, six.string_types):
            module_name, _, variable_name = consumer.rpartition(".")
            # Without a module part rsplit-unpacking used to raise ValueError.
            if not module_name:
                raise ImproperlyConfigured(
                    "Cannot import consumer %r: not a dotted path" % consumer
                )
            try:
                consumer = getattr(importlib.import_module(module_name), variable_name)
            except (ImportError, AttributeError) as e:
                # Keep the underlying cause, as resolve_routing does.
                raise ImproperlyConfigured("Cannot import consumer %r: %s" % (consumer, e))
        return consumer

    def match(self, message):
        """
        Checks to see if we match the Message object. Returns
        (consumer, kwargs dict) if it matches, None otherwise
        """
        # Check for channel match first of all
        if message.channel.name not in self.channels:
            return None
        # Check each message filter and build consumer kwargs as we go
        call_args = {}
        for name, value in self.filters.items():
            if name not in message:
                return None
            match = value.match(Router.normalise_re_arg(message[name]))
            # Any match failure means we pass
            if match:
                call_args.update(match.groupdict())
            else:
                return None
        return self.consumer, call_args

    def channel_names(self):
        """
        Returns the channel names this route listens on
        """
        return set(self.channels)

    def __str__(self):
        return "%s %s -> %s" % (
            "/".join(self.channels),
            "" if not self.filters else "(%s)" % (
                ", ".join("%s=%s" % (n, v.pattern) for n, v in self.filters.items())
            ),
            name_that_thing(self.consumer),
        )
 | |
| 
 | |
| 
 | |
class RouteClass(Route):
    """
    Route for a class-based consumer.

    Rather than being told which channels to listen on, it asks the
    consumer by calling ``consumer.channel_names()``. Filter kwargs are
    forwarded to Route unchanged.
    """

    def __init__(self, consumer, **kwargs):
        # Import the consumer first when it is a dotted path, so the actual
        # object can be inspected below.
        resolved = self._resolve_consumer(consumer)
        # The consumer must expose a callable channel_names (typically a
        # classmethod - not verified here).
        provides_names = hasattr(resolved, "channel_names") and callable(resolved.channel_names)
        if not provides_names:
            raise ValueError("The consumer passed to RouteClass has no valid channel_names method")
        channel_list = resolved.channel_names()
        super(RouteClass, self).__init__(channel_list, resolved, **kwargs)
 | |
| 
 | |
| 
 | |
class Include(object):
    """
    A routing entry that nests another routing (list, dict, or dotted import
    path) inside this one.

    Keyword arguments map message keys to regex prefixes. The nested routing
    is consulted only if every prefix matches the start of the corresponding
    message value. The matched part is stripped from a copy of the message
    before nested entries see it, and named groups become consumer kwargs.
    """

    def __init__(self, routing, **kwargs):
        self.routing = Router.resolve_routing(routing)
        # Precompile every prefix pattern (bytes patterns are decoded first).
        compiled = {}
        for key, pattern in kwargs.items():
            compiled[key] = re.compile(Router.normalise_re_arg(pattern))
        self.prefixes = compiled

    def match(self, message):
        """
        Returns (consumer, kwargs) from the first nested entry that matches
        once prefixes are stripped, or None if a prefix is missing or fails
        to match, or no nested entry matches.
        """
        # Operate on a copy so stripping prefixes never alters the caller's message.
        trimmed = message.copy()
        captured = {}
        for key, prefix_regex in self.prefixes.items():
            if key not in trimmed:
                return None
            text = Router.normalise_re_arg(trimmed[key])
            found = prefix_regex.match(text)
            if not found:
                # A single prefix mismatch means this include does not apply.
                return None
            captured.update(found.groupdict())
            # Nested entries only see what follows the matched prefix.
            trimmed[key] = text[found.end():]
        for entry in self.routing:
            result = entry.match(trimmed)
            if result is None:
                continue
            # Kwargs from the nested match override same-named prefix captures.
            captured.update(result[1])
            return result[0], captured
        return None

    def channel_names(self):
        """
        Returns the union of the channel names of all nested entries.
        """
        return set().union(*[entry.channel_names() for entry in self.routing])
 | |
| 
 | |
| 
 | |
def null_consumer(*args, **kwargs):
    """
    No-op consumer: accepts any arguments, does nothing, returns None.

    Router.check_default registers it for websocket.receive and
    websocket.disconnect so those channels always have a consumer.
    """
    return None
 | |
| 
 | |
| 
 | |
def connect_consumer(message, *args, **kwargs):
    """
    websocket.connect consumer that accepts every incoming connection.

    Sends ``{"accept": True}`` on ``message.reply_channel``; any extra
    positional or keyword arguments are ignored.
    """
    acceptance = {"accept": True}
    message.reply_channel.send(acceptance)
 | |
| 
 | |
| 
 | |
# Lowercase aliases so routing modules read like Django's urls.py,
# e.g. route(...), route_class(...), include(...).
route, route_class, include = Route, RouteClass, Include
 |