mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-08-06 13:30:12 +03:00
Merge 4a32f6857f
into 1b882f7281
This commit is contained in:
commit
a0b03b744b
|
@ -21,3 +21,5 @@ HTTP_HEADER_ENCODING = 'iso-8859-1'
|
|||
|
||||
# Default datetime input and output formats
|
||||
ISO_8601 = 'iso-8601'
|
||||
|
||||
default_app_config = 'rest_framework.apps.RestFrameworkConfig'
|
||||
|
|
22
rest_framework/apps.py
Normal file
22
rest_framework/apps.py
Normal file
|
@ -0,0 +1,22 @@
|
|||
from django.apps import AppConfig
|
||||
from django.conf import settings
|
||||
from django.utils.translation import ugettext_lazy as _
|
||||
|
||||
|
||||
class RestFrameworkConfig(AppConfig):
|
||||
name = 'rest_framework'
|
||||
verbose_name = _("REST Framework")
|
||||
|
||||
def ready(self):
|
||||
"""
|
||||
Try to auto-import any modules in INSTALLED_APPS.
|
||||
|
||||
This lets us evaluate all of the @router.route('...') calls.
|
||||
"""
|
||||
from rest_framework.compat import importlib
|
||||
for app in settings.INSTALLED_APPS:
|
||||
for mod_name in ['api']:
|
||||
try:
|
||||
importlib.import_module('%s.%s' % (app, mod_name))
|
||||
except ImportError:
|
||||
pass
|
|
@ -23,6 +23,7 @@ from django.core.exceptions import ImproperlyConfigured
|
|||
from django.core.urlresolvers import NoReverseMatch
|
||||
|
||||
from rest_framework import exceptions, renderers, views
|
||||
from rest_framework.compat import six
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.reverse import reverse
|
||||
from rest_framework.schemas import SchemaGenerator
|
||||
|
@ -53,6 +54,52 @@ def flatten(list_of_lists):
|
|||
return itertools.chain(*list_of_lists)
|
||||
|
||||
|
||||
class RouteTree:
|
||||
"""
|
||||
Stores a tree structure of routes.
|
||||
"""
|
||||
Node = namedtuple('Node', ['routes', 'value'])
|
||||
|
||||
def __init__(self):
|
||||
self.routes = {}
|
||||
|
||||
def set(self, value, path, name):
|
||||
"""
|
||||
Set a value on a given path. Raises a KeyError if the value has already
|
||||
been set or if the path doesn't exist.
|
||||
"""
|
||||
routes = self.routes
|
||||
parent_path = []
|
||||
|
||||
for part in path:
|
||||
if part not in routes:
|
||||
on_path = (' on path %s' % parent_path) if len(parent_path) > 0 else ''
|
||||
raise KeyError('Parent route "%s"%s was not registred' % (part, on_path))
|
||||
parent_path.append(part)
|
||||
routes = routes[part].routes
|
||||
|
||||
if name in routes:
|
||||
on_path = (' on path %s' % path) if len(path) > 0 else ''
|
||||
raise KeyError('Route "%s" already set%s' % (name, on_path))
|
||||
|
||||
routes[name] = self.Node({}, value)
|
||||
|
||||
def get(self, path, name):
|
||||
"""
|
||||
Get a vector of (name, value) for each entry on the path and name.
|
||||
Raises KeyError if path or name is invalid.
|
||||
"""
|
||||
routes = self.routes
|
||||
vect = []
|
||||
|
||||
for part in path + [name]:
|
||||
node = routes[part]
|
||||
vect.append((part, node.value))
|
||||
routes = node.routes
|
||||
|
||||
return vect
|
||||
|
||||
|
||||
class BaseRouter(object):
|
||||
def __init__(self):
|
||||
self.registry = []
|
||||
|
@ -127,6 +174,32 @@ class SimpleRouter(BaseRouter):
|
|||
self.trailing_slash = trailing_slash and '/' or ''
|
||||
super(SimpleRouter, self).__init__()
|
||||
|
||||
def route(self, *full_path, **kwargs):
|
||||
"""
|
||||
ViewSet class decorator for automatically registering a route:
|
||||
|
||||
router = SimpleRouter()
|
||||
|
||||
@router.route('parent')
|
||||
class ParentViewSet(ViewSet):
|
||||
pass
|
||||
|
||||
@router.route('parent', 'child', base_name='children)
|
||||
class ChildViewSet(ViewSet):
|
||||
pass
|
||||
"""
|
||||
full_path = list(full_path)
|
||||
assert len(full_path) > 0, 'Must provide a route prefix'
|
||||
base_name = kwargs.pop('base_name', '-'.join(full_path))
|
||||
|
||||
def wrapper(viewset_cls):
|
||||
self.register(wrapper._full_path, viewset_cls, wrapper._base_name)
|
||||
return viewset_cls
|
||||
|
||||
wrapper._full_path = full_path
|
||||
wrapper._base_name = base_name
|
||||
return wrapper
|
||||
|
||||
def get_default_base_name(self, viewset):
|
||||
"""
|
||||
If `base_name` is not specified, attempt to automatically determine
|
||||
|
@ -234,13 +307,70 @@ class SimpleRouter(BaseRouter):
|
|||
lookup_value=lookup_value
|
||||
)
|
||||
|
||||
def get_route_nodes_path(self, route_nodes):
|
||||
"""
|
||||
Given a vector of (prefix, viewset) tuples, get the resulting URL
|
||||
regex. The URL kwarg name and lookup is determined by the prefix
|
||||
and settings for each view.
|
||||
|
||||
Duplicate prefixes on the path will be appended with a number. For
|
||||
example, the path:
|
||||
[('users', UserView), ('users', UserView), ('friends', UserView)]
|
||||
would result in:
|
||||
/users/(?P<users_pk>[^\]+)/users/(?P<users_2_pk>[^/]+)/friends
|
||||
"""
|
||||
encountered = {}
|
||||
parts = []
|
||||
|
||||
for i, node in enumerate(route_nodes):
|
||||
prefix, viewset = node
|
||||
if i == len(route_nodes) - 1:
|
||||
parts.append(prefix)
|
||||
|
||||
else:
|
||||
lookup_prefix = '%s_' % prefix
|
||||
if prefix in encountered:
|
||||
encountered[prefix] += 1
|
||||
lookup_prefix = '%s_%s_' % (prefix, encountered[prefix])
|
||||
else:
|
||||
encountered[prefix] = 1
|
||||
parts.append('%s/%s' % (prefix, self.get_lookup_regex(viewset, lookup_prefix)))
|
||||
|
||||
return '/'.join(parts)
|
||||
|
||||
def _process_registry(self):
|
||||
"""
|
||||
New-style routes use a list of strings for the prefix. Normalize old-
|
||||
style string prefixes and then sort based on path length and prefix.
|
||||
|
||||
If a viewset was registered for all nodes that have nested nodes then
|
||||
the result should construct a valid RouteTree when inserted in order.
|
||||
"""
|
||||
def normalize(entry):
|
||||
prefix, viewset, base_name = entry
|
||||
if isinstance(prefix, six.string_types):
|
||||
prefix = [prefix]
|
||||
return prefix, viewset, base_name
|
||||
|
||||
def sort(entry):
|
||||
full_path = entry[0]
|
||||
prefix, path = full_path[-1], full_path[:-1]
|
||||
return len(path), prefix
|
||||
|
||||
return sorted(map(normalize, self.registry), key=sort)
|
||||
|
||||
def get_urls(self):
|
||||
"""
|
||||
Use the registered viewsets to generate a list of URL patterns.
|
||||
"""
|
||||
ret = []
|
||||
urls = []
|
||||
route_tree = RouteTree()
|
||||
|
||||
for prefix, viewset, basename in self.registry:
|
||||
for full_path, viewset, basename in self._process_registry():
|
||||
prefix, path = full_path[-1], full_path[:-1]
|
||||
route_tree.set(viewset, path, prefix)
|
||||
route_nodes = route_tree.get(path, prefix)
|
||||
prefix = self.get_route_nodes_path(route_nodes)
|
||||
lookup = self.get_lookup_regex(viewset)
|
||||
routes = self.get_routes(viewset)
|
||||
|
||||
|
@ -260,9 +390,9 @@ class SimpleRouter(BaseRouter):
|
|||
|
||||
view = viewset.as_view(mapping, **route.initkwargs)
|
||||
name = route.name.format(basename=basename)
|
||||
ret.append(url(regex, view, name=name))
|
||||
urls.append(url(regex, view, name=name))
|
||||
|
||||
return ret
|
||||
return urls
|
||||
|
||||
|
||||
class DefaultRouter(SimpleRouter):
|
||||
|
|
|
@ -10,7 +10,7 @@ from django.test import TestCase, override_settings
|
|||
from rest_framework import permissions, serializers, viewsets
|
||||
from rest_framework.decorators import detail_route, list_route
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.routers import DefaultRouter, SimpleRouter
|
||||
from rest_framework.routers import DefaultRouter, RouteTree, SimpleRouter
|
||||
from rest_framework.test import APIRequestFactory
|
||||
|
||||
factory = APIRequestFactory()
|
||||
|
@ -89,6 +89,53 @@ class BasicViewSet(viewsets.ViewSet):
|
|||
return Response({'method': 'link2'})
|
||||
|
||||
|
||||
class TestRouteTree(TestCase):
|
||||
def setUp(self):
|
||||
self.tree = RouteTree()
|
||||
|
||||
def test_set(self):
|
||||
"""
|
||||
Should set the value for the given name and path.
|
||||
"""
|
||||
self.tree.set(1, [], 'A')
|
||||
self.tree.set(2, ['A'], 'B')
|
||||
self.tree.set(3, ['A', 'B'], 'C')
|
||||
|
||||
root = self.tree.routes
|
||||
self.assertEqual(list(root.keys()), ['A'])
|
||||
self.assertEqual(root['A'].value, 1)
|
||||
self.assertEqual(list(root['A'].routes.keys()), ['B'])
|
||||
self.assertEqual(root['A'].routes['B'].value, 2)
|
||||
self.assertEqual(list(root['A'].routes['B'].routes.keys()), ['C'])
|
||||
self.assertEqual(root['A'].routes['B'].routes['C'].value, 3)
|
||||
|
||||
def test_set_invalid(self):
|
||||
"""
|
||||
A KeyError should be raised for an invalid name or path.
|
||||
"""
|
||||
self.tree.set(1, [], 'A')
|
||||
self.tree.set(2, ['A'], 'B')
|
||||
|
||||
with self.assertRaises(KeyError):
|
||||
self.tree.set(5, ['A', 'B', 'C', 'C'], 'E')
|
||||
|
||||
with self.assertRaises(KeyError):
|
||||
self.tree.set(20, ['A'], 'B')
|
||||
|
||||
def test_get(self):
|
||||
"""
|
||||
Should return a list of (name, value) for each node on
|
||||
the path.
|
||||
"""
|
||||
self.tree.set(1, [], 'A')
|
||||
self.tree.set(2, ['A'], 'B')
|
||||
self.tree.set(3, ['A', 'B'], 'C')
|
||||
|
||||
self.assertEqual(self.tree.get(['A', 'B'], 'C'), [
|
||||
('A', 1), ('B', 2), ('C', 3)
|
||||
])
|
||||
|
||||
|
||||
class TestSimpleRouter(TestCase):
|
||||
def setUp(self):
|
||||
self.router = SimpleRouter()
|
||||
|
@ -112,6 +159,43 @@ class TestSimpleRouter(TestCase):
|
|||
for method in methods_map:
|
||||
self.assertEqual(route.mapping[method], endpoint)
|
||||
|
||||
def test_route_nodes_path(self):
|
||||
"""
|
||||
Should return a full prefix for the given route_nodes. Duplicate parent
|
||||
lookups should have a count appended.
|
||||
"""
|
||||
tree = RouteTree()
|
||||
tree.set(NoteViewSet, [], 'first')
|
||||
tree.set(NoteViewSet, ['first'], 'first')
|
||||
tree.set(NoteViewSet, ['first', 'first'], 'last')
|
||||
|
||||
self.assertEqual(
|
||||
self.router.get_route_nodes_path(tree.get([], 'first')),
|
||||
r'first')
|
||||
self.assertEqual(
|
||||
self.router.get_route_nodes_path(tree.get(['first'], 'first')),
|
||||
r'first/(?P<first_uuid>[^/.]+)/first')
|
||||
self.assertEqual(
|
||||
self.router.get_route_nodes_path(tree.get(['first', 'first'], 'last')),
|
||||
r'first/(?P<first_uuid>[^/.]+)/first/(?P<first_2_uuid>[^/.]+)/last')
|
||||
|
||||
def test_route(self):
|
||||
"""
|
||||
Should allow registering ViewSets via a class decorator.
|
||||
"""
|
||||
@self.router.route('parent', 'child')
|
||||
class ChildViewSet(NoteViewSet):
|
||||
pass
|
||||
|
||||
@self.router.route('parent')
|
||||
class ParentViewSet(NoteViewSet):
|
||||
pass
|
||||
|
||||
self.assertEqual(self.router._process_registry(), [
|
||||
(['parent'], ParentViewSet, 'parent'),
|
||||
(['parent', 'child'], ChildViewSet, 'parent-child')
|
||||
])
|
||||
|
||||
|
||||
@override_settings(ROOT_URLCONF='tests.test_routers')
|
||||
class TestRootView(TestCase):
|
||||
|
|
Loading…
Reference in New Issue
Block a user