diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 4eec70bda..64f110cd9 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -289,42 +289,42 @@ class DefaultRouter(SimpleRouter): self.root_renderers = list(api_settings.DEFAULT_RENDERER_CLASSES) super(DefaultRouter, self).__init__(*args, **kwargs) + def get_schema_root_view(self, api_urls=None): + """ + Return a schema root view. + """ + schema_renderers = self.schema_renderers + schema_generator = SchemaGenerator( + title=self.schema_title, + url=self.schema_url, + patterns=api_urls + ) + + class APISchemaView(views.APIView): + _ignore_model_permissions = True + renderer_classes = schema_renderers + + def get(self, request, *args, **kwargs): + schema = schema_generator.get_schema(request) + if schema is None: + raise exceptions.PermissionDenied() + return Response(schema) + + return APISchemaView.as_view() + def get_api_root_view(self, api_urls=None): """ - Return a view to use as the API root. + Return a basic root view. """ api_root_dict = OrderedDict() list_name = self.routes[0].name for prefix, viewset, basename in self.registry: api_root_dict[prefix] = list_name.format(basename=basename) - view_renderers = list(self.root_renderers) - schema_media_types = [] - - if api_urls and self.schema_title: - view_renderers += list(self.schema_renderers) - schema_generator = SchemaGenerator( - title=self.schema_title, - url=self.schema_url, - patterns=api_urls - ) - schema_media_types = [ - renderer.media_type - for renderer in self.schema_renderers - ] - - class APIRoot(views.APIView): + class APIRootView(views.APIView): _ignore_model_permissions = True - renderer_classes = view_renderers def get(self, request, *args, **kwargs): - if request.accepted_renderer.media_type in schema_media_types: - # Return a schema response. - schema = schema_generator.get_schema(request) - if schema is None: - raise exceptions.PermissionDenied() - return Response(schema) - # Return a plain {"name": "hyperlink"} response. ret = OrderedDict() namespace = request.resolver_match.namespace @@ -345,7 +345,7 @@ class DefaultRouter(SimpleRouter): return Response(ret) - return APIRoot.as_view() + return APIRootView.as_view() def get_urls(self): """ @@ -355,7 +355,10 @@ class DefaultRouter(SimpleRouter): urls = super(DefaultRouter, self).get_urls() if self.include_root_view: - view = self.get_api_root_view(api_urls=urls) + if self.schema_title: + view = self.get_schema_root_view(api_urls=urls) + else: + view = self.get_api_root_view(api_urls=urls) root_url = url(r'^$', view, name=self.root_view_name) urls.append(root_url)