diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 19f83ecab..31338d7c5 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -21,3 +21,5 @@ HTTP_HEADER_ENCODING = 'iso-8859-1' # Default datetime input and output formats ISO_8601 = 'iso-8601' + +default_app_config = 'rest_framework.apps.RestFrameworkConfig' diff --git a/rest_framework/apps.py b/rest_framework/apps.py new file mode 100644 index 000000000..be9bd809f --- /dev/null +++ b/rest_framework/apps.py @@ -0,0 +1,23 @@ +from importlib import import_module + +from django.apps import AppConfig +from django.conf import settings +from django.utils.translation import ugettext_lazy as _ + + +class RestFrameworkConfig(AppConfig): + name = 'rest_framework' + verbose_name = _("REST Framework") + + def ready(self): + """ + Try to auto-import any modules in INSTALLED_APPS. + + This lets us evaluate all of the @router.route('...') calls. + """ + for app in settings.INSTALLED_APPS: + for mod_name in ['api']: + try: + import_module('%s.%s' % (app, mod_name)) + except ImportError: + pass diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 7bbfc12ec..5b231a2f4 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -173,6 +173,32 @@ class SimpleRouter(BaseRouter): self.trailing_slash = trailing_slash and '/' or '' super(SimpleRouter, self).__init__() + def route(self, *full_path, **kwargs): + """ + ViewSet class decorator for automatically registering a route: + + router = SimpleRouter() + + @router.route('parent') + class ParentViewSet(ViewSet): + pass + + @router.route('parent', 'child', base_name='children) + class ChildViewSet(ViewSet): + pass + """ + full_path = list(full_path) + assert len(full_path) > 0, 'Must provide a route prefix' + base_name = kwargs.pop('base_name', '-'.join(full_path)) + + def wrapper(viewset_cls): + self.register(wrapper._full_path, viewset_cls, wrapper._base_name) + return viewset_cls + + wrapper._full_path = full_path + wrapper._base_name = base_name + return wrapper + def get_default_base_name(self, viewset): """ If `base_name` is not specified, attempt to automatically determine diff --git a/tests/test_routers.py b/tests/test_routers.py index 5d4bc692e..f0624a7a4 100644 --- a/tests/test_routers.py +++ b/tests/test_routers.py @@ -179,6 +179,23 @@ class TestSimpleRouter(TestCase): self.router.get_route_nodes_path(tree.get(['first', 'first'], 'last')), r'first/(?P[^/.]+)/first/(?P[^/.]+)/last') + def test_route(self): + """ + Should allow registering ViewSets via a class decorator. + """ + @self.router.route('parent', 'child') + class ChildViewSet(NoteViewSet): + pass + + @self.router.route('parent') + class ParentViewSet(NoteViewSet): + pass + + self.assertEqual(self.router._process_registry(), [ + (['parent'], ParentViewSet, 'parent'), + (['parent', 'child'], ChildViewSet, 'parent-child') + ]) + @override_settings(ROOT_URLCONF='tests.test_routers') class TestRootView(TestCase):