mirror of
https://github.com/django/daphne.git
synced 2025-04-21 17:22:03 +03:00
Considerably improve routing code simplicity and shortcircuiting
This commit is contained in:
parent
0e3c742a80
commit
5a22412c16
|
@ -1,4 +1,5 @@
|
|||
from __future__ import unicode_literals
|
||||
import copy
|
||||
|
||||
from .channel import Channel
|
||||
|
||||
|
@ -38,3 +39,13 @@ class Message(object):
|
|||
|
||||
def get(self, key, default=None):
|
||||
return self.content.get(key, default)
|
||||
|
||||
def copy(self):
|
||||
"""
|
||||
Returns a safely content-mutable copy of this Message.
|
||||
"""
|
||||
return self.__class__(
|
||||
copy.deepcopy(self.content),
|
||||
self.channel.name,
|
||||
self.channel_layer,
|
||||
)
|
||||
|
|
|
@ -19,28 +19,17 @@ class Router(object):
|
|||
"""
|
||||
|
||||
def __init__(self, routing):
|
||||
# Resolve routing into a list if it's a dict or string
|
||||
routing = self.resolve_routing(routing)
|
||||
# Expand those entries recursively into a flat list of Routes
|
||||
self.routing = []
|
||||
for entry in routing:
|
||||
try:
|
||||
self.routing.extend(entry.expand_routes())
|
||||
except AttributeError:
|
||||
# It's not a valid route
|
||||
raise ValueError("Encountered %r in routing config, which is not a valid route() or include()" % entry)
|
||||
# Now go through that list and collect channel names into a set
|
||||
self.channels = {
|
||||
route.channel
|
||||
for route in self.routing
|
||||
}
|
||||
# Use a blank include as the root item
|
||||
self.root = Include(routing)
|
||||
# Cache channel names
|
||||
self.channels = self.root.channel_names()
|
||||
|
||||
def add_route(self, route):
|
||||
"""
|
||||
Adds a single raw Route to us at the end of the resolution list.
|
||||
"""
|
||||
self.routing.append(route)
|
||||
self.channels.add(route.channel)
|
||||
self.root.routing.append(route)
|
||||
self.channels = self.root.channel_names()
|
||||
|
||||
def match(self, message):
|
||||
"""
|
||||
|
@ -50,11 +39,7 @@ class Router(object):
|
|||
"""
|
||||
# TODO: Maybe we can add some kind of caching in here if we can hash
|
||||
# the message with only matchable keys faster than the search?
|
||||
for route in self.routing:
|
||||
match = route.match(message)
|
||||
if match is not None:
|
||||
return match
|
||||
return None
|
||||
return self.root.match(message)
|
||||
|
||||
def check_default(self, http_consumer=None):
|
||||
"""
|
||||
|
@ -152,43 +137,11 @@ class Route(object):
|
|||
return None
|
||||
return self.consumer, call_args
|
||||
|
||||
def expand_routes(self):
|
||||
def channel_names(self):
|
||||
"""
|
||||
Expands this route into a list of just itself.
|
||||
Returns the channel names this route listens on
|
||||
"""
|
||||
return [self]
|
||||
|
||||
def add_prefixes(self, prefixes):
|
||||
"""
|
||||
Returns a new Route with the given prefixes added to our filters.
|
||||
"""
|
||||
new_filters = {}
|
||||
# Copy over our filters adding any prefixes
|
||||
for name, value in self.filters.items():
|
||||
if name in prefixes:
|
||||
if not value.pattern.startswith("^"):
|
||||
raise ValueError("Cannot add prefix for %s on %s as inner value does not start with ^" % (
|
||||
name,
|
||||
self,
|
||||
))
|
||||
if "$" in prefixes[name]:
|
||||
raise ValueError("Cannot add prefix for %s on %s as prefix contains $ (end of line match)" % (
|
||||
name,
|
||||
self,
|
||||
))
|
||||
new_filters[name] = re.compile(prefixes[name] + value.pattern.lstrip("^"))
|
||||
else:
|
||||
new_filters[name] = value
|
||||
# Now add any prefixes that are by themselves so they're still enforced
|
||||
for name, prefix in prefixes.items():
|
||||
if name not in new_filters:
|
||||
new_filters[name] = prefix
|
||||
# Return new copy
|
||||
return self.__class__(
|
||||
self.channel,
|
||||
self.consumer,
|
||||
**new_filters
|
||||
)
|
||||
return self.channel
|
||||
|
||||
def __str__(self):
|
||||
return "%s %s -> %s" % (
|
||||
|
@ -210,26 +163,49 @@ class Include(object):
|
|||
def __init__(self, routing, **kwargs):
|
||||
self.routing = Router.resolve_routing(routing)
|
||||
self.prefixes = {
|
||||
name: Router.normalise_re_arg(value)
|
||||
name: re.compile(Router.normalise_re_arg(value))
|
||||
for name, value in kwargs.items()
|
||||
}
|
||||
# Sanity check prefix regexes
|
||||
for name, value in self.prefixes.items():
|
||||
if not value.startswith("^"):
|
||||
raise ValueError("Include prefix for %s must start with the ^ character." % name)
|
||||
|
||||
def expand_routes(self):
|
||||
def match(self, message):
|
||||
"""
|
||||
Expands this Include into a list of routes, first recursively expanding
|
||||
and then adding on prefixes to filters if specified.
|
||||
Tries to match the message against our own prefixes, possibly modifying
|
||||
what we send to included things, then tries all included items.
|
||||
"""
|
||||
# First, expand our own subset of routes, to get a list of Route objects
|
||||
routes = []
|
||||
# Check our prefixes match. Do this against a copy of the message so
|
||||
# we can write back any changed values.
|
||||
message = message.copy()
|
||||
call_args = {}
|
||||
for name, prefix in self.prefixes.items():
|
||||
if name not in message:
|
||||
return None
|
||||
value = Router.normalise_re_arg(message[name])
|
||||
match = prefix.match(value)
|
||||
# Any match failure means we pass
|
||||
if match:
|
||||
call_args.update(match.groupdict())
|
||||
# Modify the message value to remove the part we matched on
|
||||
message[name] = value[match.end():]
|
||||
else:
|
||||
return None
|
||||
# Alright, if we got this far our prefixes match. Try all of our
|
||||
# included objects now.
|
||||
for entry in self.routing:
|
||||
routes.extend(entry.expand_routes())
|
||||
# Then, go through those and add any prefixes we have.
|
||||
routes = [route.add_prefixes(self.prefixes) for route in routes]
|
||||
return routes
|
||||
match = entry.match(message)
|
||||
if match is not None:
|
||||
call_args.update(match[1])
|
||||
return match[0], call_args
|
||||
# Nothing matched :(
|
||||
return None
|
||||
|
||||
def channel_names(self):
|
||||
"""
|
||||
Returns the channel names this route listens on
|
||||
"""
|
||||
result = set()
|
||||
for entry in self.routing:
|
||||
result.union(entry.channel_names())
|
||||
return result
|
||||
|
||||
|
||||
# Lowercase standard to match urls.py
|
||||
|
|
|
@ -24,7 +24,7 @@ chatroom_routing = [
|
|||
route("websocket.connect", consumer_3, path=r"^/mentions/$"),
|
||||
]
|
||||
|
||||
chatroom_routing_noprefix = [
|
||||
chatroom_routing_nolinestart = [
|
||||
route("websocket.connect", consumer_2, path=r"/chat/(?P<room>[^/]+)/$"),
|
||||
route("websocket.connect", consumer_3, path=r"/mentions/$"),
|
||||
]
|
||||
|
@ -207,6 +207,17 @@ class RoutingTests(SimpleTestCase):
|
|||
consumer=consumer_3,
|
||||
kwargs={"version": "1"},
|
||||
)
|
||||
# Check it works without the ^s too.
|
||||
router = Router([
|
||||
include("channels.tests.test_routing.chatroom_routing_nolinestart", path="/ws/v(?P<version>[0-9]+)"),
|
||||
])
|
||||
self.assertRoute(
|
||||
router,
|
||||
channel="websocket.connect",
|
||||
content={"path": "/ws/v2/chat/django/"},
|
||||
consumer=consumer_2,
|
||||
kwargs={"version": "2", "room": "django"},
|
||||
)
|
||||
|
||||
def test_positional_pattern(self):
|
||||
"""
|
||||
|
@ -217,20 +228,6 @@ class RoutingTests(SimpleTestCase):
|
|||
route("http.request", consumer_1, path=r"^/chat/([^/]+)/$"),
|
||||
])
|
||||
|
||||
def test_bad_include_prefix(self):
|
||||
"""
|
||||
Tests both failure cases of prefixes for includes - the include not
|
||||
starting with ^, and the included filter not starting with ^.
|
||||
"""
|
||||
with self.assertRaises(ValueError):
|
||||
Router([
|
||||
include("channels.tests.test_routing.chatroom_routing", path="foobar"),
|
||||
])
|
||||
with self.assertRaises(ValueError):
|
||||
Router([
|
||||
include("channels.tests.test_routing.chatroom_routing_noprefix", path="^/foobar/"),
|
||||
])
|
||||
|
||||
def test_mixed_unicode_bytes(self):
|
||||
"""
|
||||
Tests that having the message key be bytes and pattern unicode (or vice-versa)
|
||||
|
|
|
@ -504,12 +504,15 @@ routing our chat from above::
|
|||
include(http_routing),
|
||||
]
|
||||
|
||||
When Channels loads this routing, it appends any match keys together and
|
||||
flattens out the routing, so the ``path`` match for ``chat_connect`` becomes
|
||||
``^/chat/(?P<room>[a-zA-Z0-9_]+)/$``. If the include match
|
||||
or the route match doesn't have the ``^`` character, it will refuse to append them
|
||||
and error (you can still have matches without ``^`` in either, you just can't
|
||||
ask Channels to combine them).
|
||||
Channels will resolve the routing in order, short-circuiting around the
|
||||
includes if one or more of their matches fails. You don't have to start with
|
||||
the ``^`` symbol - we use Python's ``re.match`` function, which starts at the
|
||||
start of a line anyway - but it's considered good practice.
|
||||
|
||||
When an include matches part of a message value, it chops off the bit of the
|
||||
value it matched before passing it down to its routes or sub-includes, so you
|
||||
can put the same routing under multiple includes with different prefixes if
|
||||
you like.
|
||||
|
||||
Because these matches come through as keyword arguments, we could modify our
|
||||
consumer above to use a room based on URL rather than username::
|
||||
|
|
Loading…
Reference in New Issue
Block a user