This commit is contained in:
Benjamin Dummer 2016-10-02 06:56:05 +00:00 committed by GitHub
commit a0b03b744b
4 changed files with 243 additions and 5 deletions

View File

@ -21,3 +21,5 @@ HTTP_HEADER_ENCODING = 'iso-8859-1'
# Default datetime input and output formats
ISO_8601 = 'iso-8601'
default_app_config = 'rest_framework.apps.RestFrameworkConfig'

22
rest_framework/apps.py Normal file
View File

@ -0,0 +1,22 @@
from django.apps import AppConfig
from django.conf import settings
from django.utils.translation import ugettext_lazy as _
class RestFrameworkConfig(AppConfig):
name = 'rest_framework'
verbose_name = _("REST Framework")
def ready(self):
"""
Try to auto-import any modules in INSTALLED_APPS.
This lets us evaluate all of the @router.route('...') calls.
"""
from rest_framework.compat import importlib
for app in settings.INSTALLED_APPS:
for mod_name in ['api']:
try:
importlib.import_module('%s.%s' % (app, mod_name))
except ImportError:
pass

View File

@ -23,6 +23,7 @@ from django.core.exceptions import ImproperlyConfigured
from django.core.urlresolvers import NoReverseMatch
from rest_framework import exceptions, renderers, views
from rest_framework.compat import six
from rest_framework.response import Response
from rest_framework.reverse import reverse
from rest_framework.schemas import SchemaGenerator
@ -53,6 +54,52 @@ def flatten(list_of_lists):
return itertools.chain(*list_of_lists)
class RouteTree:
"""
Stores a tree structure of routes.
"""
Node = namedtuple('Node', ['routes', 'value'])
def __init__(self):
self.routes = {}
def set(self, value, path, name):
"""
Set a value on a given path. Raises a KeyError if the value has already
been set or if the path doesn't exist.
"""
routes = self.routes
parent_path = []
for part in path:
if part not in routes:
on_path = (' on path %s' % parent_path) if len(parent_path) > 0 else ''
raise KeyError('Parent route "%s"%s was not registred' % (part, on_path))
parent_path.append(part)
routes = routes[part].routes
if name in routes:
on_path = (' on path %s' % path) if len(path) > 0 else ''
raise KeyError('Route "%s" already set%s' % (name, on_path))
routes[name] = self.Node({}, value)
def get(self, path, name):
"""
Get a vector of (name, value) for each entry on the path and name.
Raises KeyError if path or name is invalid.
"""
routes = self.routes
vect = []
for part in path + [name]:
node = routes[part]
vect.append((part, node.value))
routes = node.routes
return vect
class BaseRouter(object):
def __init__(self):
self.registry = []
@ -127,6 +174,32 @@ class SimpleRouter(BaseRouter):
self.trailing_slash = trailing_slash and '/' or ''
super(SimpleRouter, self).__init__()
def route(self, *full_path, **kwargs):
"""
ViewSet class decorator for automatically registering a route:
router = SimpleRouter()
@router.route('parent')
class ParentViewSet(ViewSet):
pass
@router.route('parent', 'child', base_name='children)
class ChildViewSet(ViewSet):
pass
"""
full_path = list(full_path)
assert len(full_path) > 0, 'Must provide a route prefix'
base_name = kwargs.pop('base_name', '-'.join(full_path))
def wrapper(viewset_cls):
self.register(wrapper._full_path, viewset_cls, wrapper._base_name)
return viewset_cls
wrapper._full_path = full_path
wrapper._base_name = base_name
return wrapper
def get_default_base_name(self, viewset):
"""
If `base_name` is not specified, attempt to automatically determine
@ -234,13 +307,70 @@ class SimpleRouter(BaseRouter):
lookup_value=lookup_value
)
def get_route_nodes_path(self, route_nodes):
"""
Given a vector of (prefix, viewset) tuples, get the resulting URL
regex. The URL kwarg name and lookup is determined by the prefix
and settings for each view.
Duplicate prefixes on the path will be appended with a number. For
example, the path:
[('users', UserView), ('users', UserView), ('friends', UserView)]
would result in:
/users/(?P<users_pk>[^\]+)/users/(?P<users_2_pk>[^/]+)/friends
"""
encountered = {}
parts = []
for i, node in enumerate(route_nodes):
prefix, viewset = node
if i == len(route_nodes) - 1:
parts.append(prefix)
else:
lookup_prefix = '%s_' % prefix
if prefix in encountered:
encountered[prefix] += 1
lookup_prefix = '%s_%s_' % (prefix, encountered[prefix])
else:
encountered[prefix] = 1
parts.append('%s/%s' % (prefix, self.get_lookup_regex(viewset, lookup_prefix)))
return '/'.join(parts)
def _process_registry(self):
"""
New-style routes use a list of strings for the prefix. Normalize old-
style string prefixes and then sort based on path length and prefix.
If a viewset was registered for all nodes that have nested nodes then
the result should construct a valid RouteTree when inserted in order.
"""
def normalize(entry):
prefix, viewset, base_name = entry
if isinstance(prefix, six.string_types):
prefix = [prefix]
return prefix, viewset, base_name
def sort(entry):
full_path = entry[0]
prefix, path = full_path[-1], full_path[:-1]
return len(path), prefix
return sorted(map(normalize, self.registry), key=sort)
def get_urls(self):
"""
Use the registered viewsets to generate a list of URL patterns.
"""
ret = []
urls = []
route_tree = RouteTree()
for prefix, viewset, basename in self.registry:
for full_path, viewset, basename in self._process_registry():
prefix, path = full_path[-1], full_path[:-1]
route_tree.set(viewset, path, prefix)
route_nodes = route_tree.get(path, prefix)
prefix = self.get_route_nodes_path(route_nodes)
lookup = self.get_lookup_regex(viewset)
routes = self.get_routes(viewset)
@ -260,9 +390,9 @@ class SimpleRouter(BaseRouter):
view = viewset.as_view(mapping, **route.initkwargs)
name = route.name.format(basename=basename)
ret.append(url(regex, view, name=name))
urls.append(url(regex, view, name=name))
return ret
return urls
class DefaultRouter(SimpleRouter):

View File

@ -10,7 +10,7 @@ from django.test import TestCase, override_settings
from rest_framework import permissions, serializers, viewsets
from rest_framework.decorators import detail_route, list_route
from rest_framework.response import Response
from rest_framework.routers import DefaultRouter, SimpleRouter
from rest_framework.routers import DefaultRouter, RouteTree, SimpleRouter
from rest_framework.test import APIRequestFactory
factory = APIRequestFactory()
@ -89,6 +89,53 @@ class BasicViewSet(viewsets.ViewSet):
return Response({'method': 'link2'})
class TestRouteTree(TestCase):
def setUp(self):
self.tree = RouteTree()
def test_set(self):
"""
Should set the value for the given name and path.
"""
self.tree.set(1, [], 'A')
self.tree.set(2, ['A'], 'B')
self.tree.set(3, ['A', 'B'], 'C')
root = self.tree.routes
self.assertEqual(list(root.keys()), ['A'])
self.assertEqual(root['A'].value, 1)
self.assertEqual(list(root['A'].routes.keys()), ['B'])
self.assertEqual(root['A'].routes['B'].value, 2)
self.assertEqual(list(root['A'].routes['B'].routes.keys()), ['C'])
self.assertEqual(root['A'].routes['B'].routes['C'].value, 3)
def test_set_invalid(self):
"""
A KeyError should be raised for an invalid name or path.
"""
self.tree.set(1, [], 'A')
self.tree.set(2, ['A'], 'B')
with self.assertRaises(KeyError):
self.tree.set(5, ['A', 'B', 'C', 'C'], 'E')
with self.assertRaises(KeyError):
self.tree.set(20, ['A'], 'B')
def test_get(self):
"""
Should return a list of (name, value) for each node on
the path.
"""
self.tree.set(1, [], 'A')
self.tree.set(2, ['A'], 'B')
self.tree.set(3, ['A', 'B'], 'C')
self.assertEqual(self.tree.get(['A', 'B'], 'C'), [
('A', 1), ('B', 2), ('C', 3)
])
class TestSimpleRouter(TestCase):
def setUp(self):
self.router = SimpleRouter()
@ -112,6 +159,43 @@ class TestSimpleRouter(TestCase):
for method in methods_map:
self.assertEqual(route.mapping[method], endpoint)
def test_route_nodes_path(self):
"""
Should return a full prefix for the given route_nodes. Duplicate parent
lookups should have a count appended.
"""
tree = RouteTree()
tree.set(NoteViewSet, [], 'first')
tree.set(NoteViewSet, ['first'], 'first')
tree.set(NoteViewSet, ['first', 'first'], 'last')
self.assertEqual(
self.router.get_route_nodes_path(tree.get([], 'first')),
r'first')
self.assertEqual(
self.router.get_route_nodes_path(tree.get(['first'], 'first')),
r'first/(?P<first_uuid>[^/.]+)/first')
self.assertEqual(
self.router.get_route_nodes_path(tree.get(['first', 'first'], 'last')),
r'first/(?P<first_uuid>[^/.]+)/first/(?P<first_2_uuid>[^/.]+)/last')
def test_route(self):
"""
Should allow registering ViewSets via a class decorator.
"""
@self.router.route('parent', 'child')
class ChildViewSet(NoteViewSet):
pass
@self.router.route('parent')
class ParentViewSet(NoteViewSet):
pass
self.assertEqual(self.router._process_registry(), [
(['parent'], ParentViewSet, 'parent'),
(['parent', 'child'], ChildViewSet, 'parent-child')
])
@override_settings(ROOT_URLCONF='tests.test_routers')
class TestRootView(TestCase):