diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 457af6c88..3fd2e2e1e 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -21,3 +21,5 @@ HTTP_HEADER_ENCODING = 'iso-8859-1' # Default datetime input and output formats ISO_8601 = 'iso-8601' + +default_app_config = 'rest_framework.apps.RestFrameworkConfig' diff --git a/rest_framework/apps.py b/rest_framework/apps.py new file mode 100644 index 000000000..39cd350e5 --- /dev/null +++ b/rest_framework/apps.py @@ -0,0 +1,22 @@ +from django.apps import AppConfig +from django.conf import settings +from django.utils.translation import ugettext_lazy as _ + + +class RestFrameworkConfig(AppConfig): + name = 'rest_framework' + verbose_name = _("REST Framework") + + def ready(self): + """ + Try to auto-import any modules in INSTALLED_APPS. + + This lets us evaluate all of the @router.route('...') calls. + """ + from rest_framework.compat import importlib + for app in settings.INSTALLED_APPS: + for mod_name in ['api']: + try: + importlib.import_module('%s.%s' % (app, mod_name)) + except ImportError: + pass diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 4eec70bda..3b900297c 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -23,6 +23,7 @@ from django.core.exceptions import ImproperlyConfigured from django.core.urlresolvers import NoReverseMatch from rest_framework import exceptions, renderers, views +from rest_framework.compat import six from rest_framework.response import Response from rest_framework.reverse import reverse from rest_framework.schemas import SchemaGenerator @@ -53,6 +54,52 @@ def flatten(list_of_lists): return itertools.chain(*list_of_lists) +class RouteTree: + """ + Stores a tree structure of routes. + """ + Node = namedtuple('Node', ['routes', 'value']) + + def __init__(self): + self.routes = {} + + def set(self, value, path, name): + """ + Set a value on a given path. Raises a KeyError if the value has already + been set or if the path doesn't exist. + """ + routes = self.routes + parent_path = [] + + for part in path: + if part not in routes: + on_path = (' on path %s' % parent_path) if len(parent_path) > 0 else '' + raise KeyError('Parent route "%s"%s was not registred' % (part, on_path)) + parent_path.append(part) + routes = routes[part].routes + + if name in routes: + on_path = (' on path %s' % path) if len(path) > 0 else '' + raise KeyError('Route "%s" already set%s' % (name, on_path)) + + routes[name] = self.Node({}, value) + + def get(self, path, name): + """ + Get a vector of (name, value) for each entry on the path and name. + Raises KeyError if path or name is invalid. + """ + routes = self.routes + vect = [] + + for part in path + [name]: + node = routes[part] + vect.append((part, node.value)) + routes = node.routes + + return vect + + class BaseRouter(object): def __init__(self): self.registry = [] @@ -127,6 +174,32 @@ class SimpleRouter(BaseRouter): self.trailing_slash = trailing_slash and '/' or '' super(SimpleRouter, self).__init__() + def route(self, *full_path, **kwargs): + """ + ViewSet class decorator for automatically registering a route: + + router = SimpleRouter() + + @router.route('parent') + class ParentViewSet(ViewSet): + pass + + @router.route('parent', 'child', base_name='children) + class ChildViewSet(ViewSet): + pass + """ + full_path = list(full_path) + assert len(full_path) > 0, 'Must provide a route prefix' + base_name = kwargs.pop('base_name', '-'.join(full_path)) + + def wrapper(viewset_cls): + self.register(wrapper._full_path, viewset_cls, wrapper._base_name) + return viewset_cls + + wrapper._full_path = full_path + wrapper._base_name = base_name + return wrapper + def get_default_base_name(self, viewset): """ If `base_name` is not specified, attempt to automatically determine @@ -234,13 +307,70 @@ class SimpleRouter(BaseRouter): lookup_value=lookup_value ) + def get_route_nodes_path(self, route_nodes): + """ + Given a vector of (prefix, viewset) tuples, get the resulting URL + regex. The URL kwarg name and lookup is determined by the prefix + and settings for each view. + + Duplicate prefixes on the path will be appended with a number. For + example, the path: + [('users', UserView), ('users', UserView), ('friends', UserView)] + would result in: + /users/(?P[^\]+)/users/(?P[^/]+)/friends + """ + encountered = {} + parts = [] + + for i, node in enumerate(route_nodes): + prefix, viewset = node + if i == len(route_nodes) - 1: + parts.append(prefix) + + else: + lookup_prefix = '%s_' % prefix + if prefix in encountered: + encountered[prefix] += 1 + lookup_prefix = '%s_%s_' % (prefix, encountered[prefix]) + else: + encountered[prefix] = 1 + parts.append('%s/%s' % (prefix, self.get_lookup_regex(viewset, lookup_prefix))) + + return '/'.join(parts) + + def _process_registry(self): + """ + New-style routes use a list of strings for the prefix. Normalize old- + style string prefixes and then sort based on path length and prefix. + + If a viewset was registered for all nodes that have nested nodes then + the result should construct a valid RouteTree when inserted in order. + """ + def normalize(entry): + prefix, viewset, base_name = entry + if isinstance(prefix, six.string_types): + prefix = [prefix] + return prefix, viewset, base_name + + def sort(entry): + full_path = entry[0] + prefix, path = full_path[-1], full_path[:-1] + return len(path), prefix + + return sorted(map(normalize, self.registry), key=sort) + def get_urls(self): """ Use the registered viewsets to generate a list of URL patterns. """ - ret = [] + urls = [] + route_tree = RouteTree() - for prefix, viewset, basename in self.registry: + for full_path, viewset, basename in self._process_registry(): + prefix, path = full_path[-1], full_path[:-1] + route_tree.set(viewset, path, prefix) + route_nodes = route_tree.get(path, prefix) + prefix = self.get_route_nodes_path(route_nodes) lookup = self.get_lookup_regex(viewset) routes = self.get_routes(viewset) @@ -260,9 +390,9 @@ class SimpleRouter(BaseRouter): view = viewset.as_view(mapping, **route.initkwargs) name = route.name.format(basename=basename) - ret.append(url(regex, view, name=name)) + urls.append(url(regex, view, name=name)) - return ret + return urls class DefaultRouter(SimpleRouter): diff --git a/tests/test_routers.py b/tests/test_routers.py index f45039f80..f0624a7a4 100644 --- a/tests/test_routers.py +++ b/tests/test_routers.py @@ -10,7 +10,7 @@ from django.test import TestCase, override_settings from rest_framework import permissions, serializers, viewsets from rest_framework.decorators import detail_route, list_route from rest_framework.response import Response -from rest_framework.routers import DefaultRouter, SimpleRouter +from rest_framework.routers import DefaultRouter, RouteTree, SimpleRouter from rest_framework.test import APIRequestFactory factory = APIRequestFactory() @@ -89,6 +89,53 @@ class BasicViewSet(viewsets.ViewSet): return Response({'method': 'link2'}) +class TestRouteTree(TestCase): + def setUp(self): + self.tree = RouteTree() + + def test_set(self): + """ + Should set the value for the given name and path. + """ + self.tree.set(1, [], 'A') + self.tree.set(2, ['A'], 'B') + self.tree.set(3, ['A', 'B'], 'C') + + root = self.tree.routes + self.assertEqual(list(root.keys()), ['A']) + self.assertEqual(root['A'].value, 1) + self.assertEqual(list(root['A'].routes.keys()), ['B']) + self.assertEqual(root['A'].routes['B'].value, 2) + self.assertEqual(list(root['A'].routes['B'].routes.keys()), ['C']) + self.assertEqual(root['A'].routes['B'].routes['C'].value, 3) + + def test_set_invalid(self): + """ + A KeyError should be raised for an invalid name or path. + """ + self.tree.set(1, [], 'A') + self.tree.set(2, ['A'], 'B') + + with self.assertRaises(KeyError): + self.tree.set(5, ['A', 'B', 'C', 'C'], 'E') + + with self.assertRaises(KeyError): + self.tree.set(20, ['A'], 'B') + + def test_get(self): + """ + Should return a list of (name, value) for each node on + the path. + """ + self.tree.set(1, [], 'A') + self.tree.set(2, ['A'], 'B') + self.tree.set(3, ['A', 'B'], 'C') + + self.assertEqual(self.tree.get(['A', 'B'], 'C'), [ + ('A', 1), ('B', 2), ('C', 3) + ]) + + class TestSimpleRouter(TestCase): def setUp(self): self.router = SimpleRouter() @@ -112,6 +159,43 @@ class TestSimpleRouter(TestCase): for method in methods_map: self.assertEqual(route.mapping[method], endpoint) + def test_route_nodes_path(self): + """ + Should return a full prefix for the given route_nodes. Duplicate parent + lookups should have a count appended. + """ + tree = RouteTree() + tree.set(NoteViewSet, [], 'first') + tree.set(NoteViewSet, ['first'], 'first') + tree.set(NoteViewSet, ['first', 'first'], 'last') + + self.assertEqual( + self.router.get_route_nodes_path(tree.get([], 'first')), + r'first') + self.assertEqual( + self.router.get_route_nodes_path(tree.get(['first'], 'first')), + r'first/(?P[^/.]+)/first') + self.assertEqual( + self.router.get_route_nodes_path(tree.get(['first', 'first'], 'last')), + r'first/(?P[^/.]+)/first/(?P[^/.]+)/last') + + def test_route(self): + """ + Should allow registering ViewSets via a class decorator. + """ + @self.router.route('parent', 'child') + class ChildViewSet(NoteViewSet): + pass + + @self.router.route('parent') + class ParentViewSet(NoteViewSet): + pass + + self.assertEqual(self.router._process_registry(), [ + (['parent'], ParentViewSet, 'parent'), + (['parent', 'child'], ChildViewSet, 'parent-child') + ]) + @override_settings(ROOT_URLCONF='tests.test_routers') class TestRootView(TestCase):