From 8f9f8b31d3a1d00ad0446865e9d737ccd6dc6448 Mon Sep 17 00:00:00 2001 From: Benjamin Date: Tue, 9 Aug 2016 01:41:22 -0400 Subject: [PATCH 1/6] Added RouteTree to routers, refactored SimpleRouter to support new-style and old-style route paths --- rest_framework/routers.py | 109 ++++++++++++++++++++++++++++++++++++-- 1 file changed, 105 insertions(+), 4 deletions(-) diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 4eec70bda..850bb408e 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -53,6 +53,50 @@ def flatten(list_of_lists): return itertools.chain(*list_of_lists) +class RouteTree: + """ + Stores a tree structure of routes. + """ + Node = namedtuple('Node', ['routes', 'value']) + + def __init__(self): + self.routes = {} + + def set(self, value, path, name): + """ + Set a value on a given path. Raises a KeyError if the value has already + been set or if the path doesn't exist. + """ + routes = self.routes + + for part in path: + if part not in routes: + # TODO: invalid path error message + raise KeyError('') + routes = routes[path].routes + + if name in routes: + # TODO: route name already exists error message + raise KeyError('') + + routes[name] = self.Node({}, value) + + def get(self, path, name): + """ + Get a vector of (name, value) for each entry on the path and name. + Raises KeyError if path or name is invalid. + """ + routes = self.routes + vect = [] + + for part in path + [name]: + node = routes[part] + vect.append((part, node.value)) + routes = node.routes + + return vect + + class BaseRouter(object): def __init__(self): self.registry = [] @@ -234,13 +278,70 @@ class SimpleRouter(BaseRouter): lookup_value=lookup_value ) + def get_route_nodes_path(self, route_nodes): + """ + Given a vector of (prefix, viewset) tuples, get the resulting URL + regex. The URL kwarg name and lookup is determined by the prefix + and settings for each view. + + Duplicate prefixes on the path will be appended with a number. For + example, the path: + [('users', UserView), ('users', UserView), ('friends', UserView)] + would result in: + /users/(?P[^\]+)/users/(?P[^/]+)/friends + """ + encountered = {} + parts = [] + + for i, node in enumerate(route_nodes): + prefix, viewset = node + if i == len(route_nodes) - 1: + parts.append(prefix) + + else: + lookup_prefix = '%s_' % prefix + if prefix in encountered: + encountered[prefix] += 1 + lookup_prefix = '%s_%s_' % (prefix, encountered[prefix]) + else: + encountered[prefix] = 1 + parts.append('%s/%s' % (prefix, self.get_lookup_regex(viewset, lookup_prefix))) + + return '/'.join(parts) + + def _process_registry(self): + """ + New-style routes use a list of strings for the prefix. Normalize old- + style string prefixes and then sort based on path length and prefix. + + If a viewset was registered for all nodes that have nested nodes then + the result should construct a valid RouteTree when inserted in order. + """ + def normalize(entry): + prefix, viewset, base_name = entry + if isinstance(prefix, str): + prefix = [prefix] + return prefix, viewset, base_name + + def sort(entry): + full_path = entry[0] + prefix, path = full_path[-1], full_path[:-1] + return len(path), prefix + + return sorted(map(normalize, self.registry), key=sort) + def get_urls(self): """ Use the registered viewsets to generate a list of URL patterns. """ - ret = [] + urls = [] + route_tree = RouteTree() - for prefix, viewset, basename in self.registry: + for full_path, viewset, basename in self._process_registry(): + prefix, path = full_path[-1], full_path[:-1] + route_tree.set(viewset, path, prefix) + route_nodes = route_tree.get(path, prefix) + prefix = self.get_route_nodes_path(route_nodes) lookup = self.get_lookup_regex(viewset) routes = self.get_routes(viewset) @@ -260,9 +361,9 @@ class SimpleRouter(BaseRouter): view = viewset.as_view(mapping, **route.initkwargs) name = route.name.format(basename=basename) - ret.append(url(regex, view, name=name)) + urls.append(url(regex, view, name=name)) - return ret + return urls class DefaultRouter(SimpleRouter): From 5cbd737af3b530a4489d441c4b43a3dfbef29230 Mon Sep 17 00:00:00 2001 From: Benjamin Dummer Date: Tue, 9 Aug 2016 02:16:40 -0400 Subject: [PATCH 2/6] Added tests for RouteTree, improved error messages, small bug fix --- rest_framework/routers.py | 12 ++++++---- tests/test_routers.py | 49 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 55 insertions(+), 6 deletions(-) diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 850bb408e..7bbfc12ec 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -68,16 +68,18 @@ class RouteTree: been set or if the path doesn't exist. """ routes = self.routes + parent_path = [] for part in path: if part not in routes: - # TODO: invalid path error message - raise KeyError('') - routes = routes[path].routes + on_path = (' on path %s' % parent_path) if len(parent_path) > 0 else '' + raise KeyError('Parent route "%s"%s was not registred' % (part, on_path)) + parent_path.append(part) + routes = routes[part].routes if name in routes: - # TODO: route name already exists error message - raise KeyError('') + on_path = (' on path %s' % path) if len(path) > 0 else '' + raise KeyError('Route "%s" already set%s' % (name, on_path)) routes[name] = self.Node({}, value) diff --git a/tests/test_routers.py b/tests/test_routers.py index f45039f80..351a0f61e 100644 --- a/tests/test_routers.py +++ b/tests/test_routers.py @@ -10,7 +10,7 @@ from django.test import TestCase, override_settings from rest_framework import permissions, serializers, viewsets from rest_framework.decorators import detail_route, list_route from rest_framework.response import Response -from rest_framework.routers import DefaultRouter, SimpleRouter +from rest_framework.routers import DefaultRouter, RouteTree, SimpleRouter from rest_framework.test import APIRequestFactory factory = APIRequestFactory() @@ -89,6 +89,53 @@ class BasicViewSet(viewsets.ViewSet): return Response({'method': 'link2'}) +class TestRouteTree(TestCase): + def setUp(self): + self.tree = RouteTree() + + def test_set(self): + """ + Should set the value for the given name and path. + """ + self.tree.set(1, [], 'A') + self.tree.set(2, ['A'], 'B') + self.tree.set(3, ['A', 'B'], 'C') + + root = self.tree.routes + self.assertEqual(list(root.keys()), ['A']) + self.assertEqual(root['A'].value, 1) + self.assertEqual(list(root['A'].routes.keys()), ['B']) + self.assertEqual(root['A'].routes['B'].value, 2) + self.assertEqual(list(root['A'].routes['B'].routes.keys()), ['C']) + self.assertEqual(root['A'].routes['B'].routes['C'].value, 3) + + def test_set_invalid(self): + """ + A KeyError should be raised for an invalid name or path. + """ + self.tree.set(1, [], 'A') + self.tree.set(2, ['A'], 'B') + + with self.assertRaises(KeyError): + self.tree.set(5, ['A', 'B', 'C', 'C'], 'E') + + with self.assertRaises(KeyError): + self.tree.set(20, ['A'], 'B') + + def test_get(self): + """ + Should return a list of (name, value) for each node on + the path. + """ + self.tree.set(1, [], 'A') + self.tree.set(2, ['A'], 'B') + self.tree.set(3, ['A', 'B'], 'C') + + self.assertEqual(self.tree.get(['A', 'B'], 'C'), [ + ('A', 1), ('B', 2), ('C', 3) + ]) + + class TestSimpleRouter(TestCase): def setUp(self): self.router = SimpleRouter() From 6a7d3937dabde2eb0b5a35977ef98691b5f73a98 Mon Sep 17 00:00:00 2001 From: Benjamin Dummer Date: Tue, 9 Aug 2016 02:40:14 -0400 Subject: [PATCH 3/6] Completed tests for routers module --- tests/test_routers.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/tests/test_routers.py b/tests/test_routers.py index 351a0f61e..5d4bc692e 100644 --- a/tests/test_routers.py +++ b/tests/test_routers.py @@ -159,6 +159,26 @@ class TestSimpleRouter(TestCase): for method in methods_map: self.assertEqual(route.mapping[method], endpoint) + def test_route_nodes_path(self): + """ + Should return a full prefix for the given route_nodes. Duplicate parent + lookups should have a count appended. + """ + tree = RouteTree() + tree.set(NoteViewSet, [], 'first') + tree.set(NoteViewSet, ['first'], 'first') + tree.set(NoteViewSet, ['first', 'first'], 'last') + + self.assertEqual( + self.router.get_route_nodes_path(tree.get([], 'first')), + r'first') + self.assertEqual( + self.router.get_route_nodes_path(tree.get(['first'], 'first')), + r'first/(?P[^/.]+)/first') + self.assertEqual( + self.router.get_route_nodes_path(tree.get(['first', 'first'], 'last')), + r'first/(?P[^/.]+)/first/(?P[^/.]+)/last') + @override_settings(ROOT_URLCONF='tests.test_routers') class TestRootView(TestCase): From b914f5fb41bca6c66fbf393834e3dc512ee8cc22 Mon Sep 17 00:00:00 2001 From: Benjamin Dummer Date: Tue, 9 Aug 2016 03:05:35 -0400 Subject: [PATCH 4/6] Added @router.route() decorator and tests --- rest_framework/__init__.py | 2 ++ rest_framework/apps.py | 23 +++++++++++++++++++++++ rest_framework/routers.py | 26 ++++++++++++++++++++++++++ tests/test_routers.py | 17 +++++++++++++++++ 4 files changed, 68 insertions(+) create mode 100644 rest_framework/apps.py diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 19f83ecab..31338d7c5 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -21,3 +21,5 @@ HTTP_HEADER_ENCODING = 'iso-8859-1' # Default datetime input and output formats ISO_8601 = 'iso-8601' + +default_app_config = 'rest_framework.apps.RestFrameworkConfig' diff --git a/rest_framework/apps.py b/rest_framework/apps.py new file mode 100644 index 000000000..be9bd809f --- /dev/null +++ b/rest_framework/apps.py @@ -0,0 +1,23 @@ +from importlib import import_module + +from django.apps import AppConfig +from django.conf import settings +from django.utils.translation import ugettext_lazy as _ + + +class RestFrameworkConfig(AppConfig): + name = 'rest_framework' + verbose_name = _("REST Framework") + + def ready(self): + """ + Try to auto-import any modules in INSTALLED_APPS. + + This lets us evaluate all of the @router.route('...') calls. + """ + for app in settings.INSTALLED_APPS: + for mod_name in ['api']: + try: + import_module('%s.%s' % (app, mod_name)) + except ImportError: + pass diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 7bbfc12ec..5b231a2f4 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -173,6 +173,32 @@ class SimpleRouter(BaseRouter): self.trailing_slash = trailing_slash and '/' or '' super(SimpleRouter, self).__init__() + def route(self, *full_path, **kwargs): + """ + ViewSet class decorator for automatically registering a route: + + router = SimpleRouter() + + @router.route('parent') + class ParentViewSet(ViewSet): + pass + + @router.route('parent', 'child', base_name='children) + class ChildViewSet(ViewSet): + pass + """ + full_path = list(full_path) + assert len(full_path) > 0, 'Must provide a route prefix' + base_name = kwargs.pop('base_name', '-'.join(full_path)) + + def wrapper(viewset_cls): + self.register(wrapper._full_path, viewset_cls, wrapper._base_name) + return viewset_cls + + wrapper._full_path = full_path + wrapper._base_name = base_name + return wrapper + def get_default_base_name(self, viewset): """ If `base_name` is not specified, attempt to automatically determine diff --git a/tests/test_routers.py b/tests/test_routers.py index 5d4bc692e..f0624a7a4 100644 --- a/tests/test_routers.py +++ b/tests/test_routers.py @@ -179,6 +179,23 @@ class TestSimpleRouter(TestCase): self.router.get_route_nodes_path(tree.get(['first', 'first'], 'last')), r'first/(?P[^/.]+)/first/(?P[^/.]+)/last') + def test_route(self): + """ + Should allow registering ViewSets via a class decorator. + """ + @self.router.route('parent', 'child') + class ChildViewSet(NoteViewSet): + pass + + @self.router.route('parent') + class ParentViewSet(NoteViewSet): + pass + + self.assertEqual(self.router._process_registry(), [ + (['parent'], ParentViewSet, 'parent'), + (['parent', 'child'], ChildViewSet, 'parent-child') + ]) + @override_settings(ROOT_URLCONF='tests.test_routers') class TestRootView(TestCase): From aa36e6722f7acbd5792372879a89145c3aa187ec Mon Sep 17 00:00:00 2001 From: Benjamin Dummer Date: Tue, 9 Aug 2016 03:50:40 -0400 Subject: [PATCH 5/6] Python version compat adjustments --- rest_framework/apps.py | 6 +++--- rest_framework/routers.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/rest_framework/apps.py b/rest_framework/apps.py index be9bd809f..8f53af005 100644 --- a/rest_framework/apps.py +++ b/rest_framework/apps.py @@ -1,9 +1,9 @@ -from importlib import import_module - from django.apps import AppConfig from django.conf import settings from django.utils.translation import ugettext_lazy as _ +from rest_framework.compat import importlib + class RestFrameworkConfig(AppConfig): name = 'rest_framework' @@ -18,6 +18,6 @@ class RestFrameworkConfig(AppConfig): for app in settings.INSTALLED_APPS: for mod_name in ['api']: try: - import_module('%s.%s' % (app, mod_name)) + importlib.import_module('%s.%s' % (app, mod_name)) except ImportError: pass diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 5b231a2f4..3b900297c 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -23,6 +23,7 @@ from django.core.exceptions import ImproperlyConfigured from django.core.urlresolvers import NoReverseMatch from rest_framework import exceptions, renderers, views +from rest_framework.compat import six from rest_framework.response import Response from rest_framework.reverse import reverse from rest_framework.schemas import SchemaGenerator @@ -347,7 +348,7 @@ class SimpleRouter(BaseRouter): """ def normalize(entry): prefix, viewset, base_name = entry - if isinstance(prefix, str): + if isinstance(prefix, six.string_types): prefix = [prefix] return prefix, viewset, base_name From 4a32f6857fa1ae86ee94d6829eefbdc249c52b1c Mon Sep 17 00:00:00 2001 From: Benjamin Dummer Date: Tue, 9 Aug 2016 03:55:32 -0400 Subject: [PATCH 6/6] Issue with newer Django versions --- rest_framework/apps.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/rest_framework/apps.py b/rest_framework/apps.py index 8f53af005..39cd350e5 100644 --- a/rest_framework/apps.py +++ b/rest_framework/apps.py @@ -2,8 +2,6 @@ from django.apps import AppConfig from django.conf import settings from django.utils.translation import ugettext_lazy as _ -from rest_framework.compat import importlib - class RestFrameworkConfig(AppConfig): name = 'rest_framework' @@ -15,6 +13,7 @@ class RestFrameworkConfig(AppConfig): This lets us evaluate all of the @router.route('...') calls. """ + from rest_framework.compat import importlib for app in settings.INSTALLED_APPS: for mod_name in ['api']: try: