Considerably improve routing code simplicity and shortcircuiting

This commit is contained in:
Andrew Godwin 2016-04-03 18:53:12 +02:00
parent 0e3c742a80
commit 5a22412c16
4 changed files with 79 additions and 92 deletions

View File

@ -1,4 +1,5 @@
from __future__ import unicode_literals
import copy
from .channel import Channel
@ -38,3 +39,13 @@ class Message(object):
def get(self, key, default=None):
return self.content.get(key, default)
def copy(self):
"""
Returns a safely content-mutable copy of this Message.
"""
return self.__class__(
copy.deepcopy(self.content),
self.channel.name,
self.channel_layer,
)

View File

@ -19,28 +19,17 @@ class Router(object):
"""
def __init__(self, routing):
# Resolve routing into a list if it's a dict or string
routing = self.resolve_routing(routing)
# Expand those entries recursively into a flat list of Routes
self.routing = []
for entry in routing:
try:
self.routing.extend(entry.expand_routes())
except AttributeError:
# It's not a valid route
raise ValueError("Encountered %r in routing config, which is not a valid route() or include()" % entry)
# Now go through that list and collect channel names into a set
self.channels = {
route.channel
for route in self.routing
}
# Use a blank include as the root item
self.root = Include(routing)
# Cache channel names
self.channels = self.root.channel_names()
def add_route(self, route):
"""
Adds a single raw Route to us at the end of the resolution list.
"""
self.routing.append(route)
self.channels.add(route.channel)
self.root.routing.append(route)
self.channels = self.root.channel_names()
def match(self, message):
"""
@ -50,11 +39,7 @@ class Router(object):
"""
# TODO: Maybe we can add some kind of caching in here if we can hash
# the message with only matchable keys faster than the search?
for route in self.routing:
match = route.match(message)
if match is not None:
return match
return None
return self.root.match(message)
def check_default(self, http_consumer=None):
"""
@ -152,43 +137,11 @@ class Route(object):
return None
return self.consumer, call_args
def expand_routes(self):
def channel_names(self):
"""
Expands this route into a list of just itself.
Returns the channel names this route listens on
"""
return [self]
def add_prefixes(self, prefixes):
"""
Returns a new Route with the given prefixes added to our filters.
"""
new_filters = {}
# Copy over our filters adding any prefixes
for name, value in self.filters.items():
if name in prefixes:
if not value.pattern.startswith("^"):
raise ValueError("Cannot add prefix for %s on %s as inner value does not start with ^" % (
name,
self,
))
if "$" in prefixes[name]:
raise ValueError("Cannot add prefix for %s on %s as prefix contains $ (end of line match)" % (
name,
self,
))
new_filters[name] = re.compile(prefixes[name] + value.pattern.lstrip("^"))
else:
new_filters[name] = value
# Now add any prefixes that are by themselves so they're still enforced
for name, prefix in prefixes.items():
if name not in new_filters:
new_filters[name] = prefix
# Return new copy
return self.__class__(
self.channel,
self.consumer,
**new_filters
)
return self.channel
def __str__(self):
return "%s %s -> %s" % (
@ -210,26 +163,49 @@ class Include(object):
def __init__(self, routing, **kwargs):
self.routing = Router.resolve_routing(routing)
self.prefixes = {
name: Router.normalise_re_arg(value)
name: re.compile(Router.normalise_re_arg(value))
for name, value in kwargs.items()
}
# Sanity check prefix regexes
for name, value in self.prefixes.items():
if not value.startswith("^"):
raise ValueError("Include prefix for %s must start with the ^ character." % name)
def expand_routes(self):
def match(self, message):
"""
Expands this Include into a list of routes, first recursively expanding
and then adding on prefixes to filters if specified.
Tries to match the message against our own prefixes, possibly modifying
what we send to included things, then tries all included items.
"""
# First, expand our own subset of routes, to get a list of Route objects
routes = []
# Check our prefixes match. Do this against a copy of the message so
# we can write back any changed values.
message = message.copy()
call_args = {}
for name, prefix in self.prefixes.items():
if name not in message:
return None
value = Router.normalise_re_arg(message[name])
match = prefix.match(value)
# Any match failure means we pass
if match:
call_args.update(match.groupdict())
# Modify the message value to remove the part we matched on
message[name] = value[match.end():]
else:
return None
# Alright, if we got this far our prefixes match. Try all of our
# included objects now.
for entry in self.routing:
routes.extend(entry.expand_routes())
# Then, go through those and add any prefixes we have.
routes = [route.add_prefixes(self.prefixes) for route in routes]
return routes
match = entry.match(message)
if match is not None:
call_args.update(match[1])
return match[0], call_args
# Nothing matched :(
return None
def channel_names(self):
"""
Returns the channel names this route listens on
"""
result = set()
for entry in self.routing:
result.union(entry.channel_names())
return result
# Lowercase standard to match urls.py

View File

@ -24,7 +24,7 @@ chatroom_routing = [
route("websocket.connect", consumer_3, path=r"^/mentions/$"),
]
chatroom_routing_noprefix = [
chatroom_routing_nolinestart = [
route("websocket.connect", consumer_2, path=r"/chat/(?P<room>[^/]+)/$"),
route("websocket.connect", consumer_3, path=r"/mentions/$"),
]
@ -207,6 +207,17 @@ class RoutingTests(SimpleTestCase):
consumer=consumer_3,
kwargs={"version": "1"},
)
# Check it works without the ^s too.
router = Router([
include("channels.tests.test_routing.chatroom_routing_nolinestart", path="/ws/v(?P<version>[0-9]+)"),
])
self.assertRoute(
router,
channel="websocket.connect",
content={"path": "/ws/v2/chat/django/"},
consumer=consumer_2,
kwargs={"version": "2", "room": "django"},
)
def test_positional_pattern(self):
"""
@ -217,20 +228,6 @@ class RoutingTests(SimpleTestCase):
route("http.request", consumer_1, path=r"^/chat/([^/]+)/$"),
])
def test_bad_include_prefix(self):
"""
Tests both failure cases of prefixes for includes - the include not
starting with ^, and the included filter not starting with ^.
"""
with self.assertRaises(ValueError):
Router([
include("channels.tests.test_routing.chatroom_routing", path="foobar"),
])
with self.assertRaises(ValueError):
Router([
include("channels.tests.test_routing.chatroom_routing_noprefix", path="^/foobar/"),
])
def test_mixed_unicode_bytes(self):
"""
Tests that having the message key be bytes and pattern unicode (or vice-versa)

View File

@ -504,12 +504,15 @@ routing our chat from above::
include(http_routing),
]
When Channels loads this routing, it appends any match keys together and
flattens out the routing, so the ``path`` match for ``chat_connect`` becomes
``^/chat/(?P<room>[a-zA-Z0-9_]+)/$``. If the include match
or the route match doesn't have the ``^`` character, it will refuse to append them
and error (you can still have matches without ``^`` in either, you just can't
ask Channels to combine them).
Channels will resolve the routing in order, short-circuiting around the
includes if one or more of their matches fails. You don't have to start with
the ``^`` symbol - we use Python's ``re.match`` function, which starts at the
start of a line anyway - but it's considered good practice.
When an include matches part of a message value, it chops off the bit of the
value it matched before passing it down to its routes or sub-includes, so you
can put the same routing under multiple includes with different prefixes if
you like.
Because these matches come through as keyword arguments, we could modify our
consumer above to use a room based on URL rather than username::