From 537df7a6adc7e5853481037bdd4acb352467ca11 Mon Sep 17 00:00:00 2001 From: Xavier Ordoquy Date: Tue, 7 Mar 2017 14:39:08 +0100 Subject: [PATCH] Extract APISchemaView and APIRootView out of the DefaultRouter. (#4707) --- rest_framework/routers.py | 91 +++++++++++++++++++++------------------ rest_framework/schemas.py | 50 ++++++++++++--------- 2 files changed, 78 insertions(+), 63 deletions(-) diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 4e3fbc4de..bdb1ab5aa 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -22,12 +22,11 @@ from collections import OrderedDict, namedtuple from django.conf.urls import url from django.core.exceptions import ImproperlyConfigured -from rest_framework import exceptions, renderers, views +from rest_framework import views from rest_framework.compat import NoReverseMatch -from rest_framework.renderers import BrowsableAPIRenderer from rest_framework.response import Response from rest_framework.reverse import reverse -from rest_framework.schemas import SchemaGenerator +from rest_framework.schemas import SchemaGenerator, SchemaView from rest_framework.settings import api_settings from rest_framework.urlpatterns import format_suffix_patterns @@ -276,6 +275,36 @@ class SimpleRouter(BaseRouter): return ret +class APIRootView(views.APIView): + """ + The default basic root view for DefaultRouter + """ + _ignore_model_permissions = True + exclude_from_schema = True + api_root_dict = None + + def get(self, request, *args, **kwargs): + # Return a plain {"name": "hyperlink"} response. + ret = OrderedDict() + namespace = request.resolver_match.namespace + for key, url_name in self.api_root_dict.items(): + if namespace: + url_name = namespace + ':' + url_name + try: + ret[key] = reverse( + url_name, + args=args, + kwargs=kwargs, + request=request, + format=kwargs.get('format', None) + ) + except NoReverseMatch: + # Don't bail out if eg. no list routes exist, only detail routes. + continue + + return Response(ret) + + class DefaultRouter(SimpleRouter): """ The default router extends the SimpleRouter, but also adds in a default @@ -284,7 +313,9 @@ class DefaultRouter(SimpleRouter): include_root_view = True include_format_suffixes = True root_view_name = 'api-root' - default_schema_renderers = [renderers.CoreJSONRenderer, BrowsableAPIRenderer] + default_schema_renderers = None + APIRootView = APIRootView + APISchemaView = SchemaView def __init__(self, *args, **kwargs): if 'schema_title' in kwargs: @@ -300,6 +331,14 @@ class DefaultRouter(SimpleRouter): self.schema_title = kwargs.pop('schema_title', None) self.schema_url = kwargs.pop('schema_url', None) self.schema_renderers = kwargs.pop('schema_renderers', self.default_schema_renderers) + if self.default_schema_renderers: + warnings.warn( + "The 'DefaultRouter.default_schema_renderers' is pending " + "deprecation. You should override " + "'DefaultRouter.APISchemaView' instead.", + PendingDeprecationWarning + ) + if 'root_renderers' in kwargs: self.root_renderers = kwargs.pop('root_renderers') else: @@ -310,25 +349,16 @@ class DefaultRouter(SimpleRouter): """ Return a schema root view. """ - schema_renderers = self.schema_renderers schema_generator = SchemaGenerator( title=self.schema_title, url=self.schema_url, patterns=api_urls ) - class APISchemaView(views.APIView): - _ignore_model_permissions = True - exclude_from_schema = True - renderer_classes = schema_renderers - - def get(self, request, *args, **kwargs): - schema = schema_generator.get_schema(request) - if schema is None: - raise exceptions.PermissionDenied() - return Response(schema) - - return APISchemaView.as_view() + return self.APISchemaView.as_view( + renderer_classes=self.schema_renderers, + schema_generator=schema_generator, + ) def get_api_root_view(self, api_urls=None): """ @@ -339,32 +369,7 @@ class DefaultRouter(SimpleRouter): for prefix, viewset, basename in self.registry: api_root_dict[prefix] = list_name.format(basename=basename) - class APIRootView(views.APIView): - _ignore_model_permissions = True - exclude_from_schema = True - - def get(self, request, *args, **kwargs): - # Return a plain {"name": "hyperlink"} response. - ret = OrderedDict() - namespace = request.resolver_match.namespace - for key, url_name in api_root_dict.items(): - if namespace: - url_name = namespace + ':' + url_name - try: - ret[key] = reverse( - url_name, - args=args, - kwargs=kwargs, - request=request, - format=kwargs.get('format', None) - ) - except NoReverseMatch: - # Don't bail out if eg. no list routes exist, only detail routes. - continue - - return Response(ret) - - return APIRootView.as_view() + return self.APIRootView.as_view(api_root_dict=api_root_dict) def get_urls(self): """ diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index bb7ad56cc..30e503b26 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -665,28 +665,38 @@ class SchemaGenerator(object): return named_path_components + [action] +class SchemaView(APIView): + _ignore_model_permissions = True + exclude_from_schema = True + renderer_classes = None + schema_generator = None + public = False + + def __init__(self, *args, **kwargs): + super(SchemaView, self).__init__(*args, **kwargs) + if self.renderer_classes is None: + if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES: + self.renderer_classes = [ + renderers.CoreJSONRenderer, + renderers.BrowsableAPIRenderer, + ] + else: + self.renderer_classes = [renderers.CoreJSONRenderer] + + def get(self, request, *args, **kwargs): + schema = self.schema_generator.get_schema(request, self.public) + if schema is None: + raise exceptions.PermissionDenied() + return Response(schema) + + def get_schema_view(title=None, url=None, description=None, urlconf=None, renderer_classes=None, public=False): """ Return a schema view. """ generator = SchemaGenerator(title=title, url=url, description=description, urlconf=urlconf) - if renderer_classes is None: - if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES: - rclasses = [renderers.CoreJSONRenderer, renderers.BrowsableAPIRenderer] - else: - rclasses = [renderers.CoreJSONRenderer] - else: - rclasses = renderer_classes - - class SchemaView(APIView): - _ignore_model_permissions = True - exclude_from_schema = True - renderer_classes = rclasses - - def get(self, request, *args, **kwargs): - schema = generator.get_schema(request, public) - if schema is None: - raise exceptions.PermissionDenied() - return Response(schema) - - return SchemaView.as_view() + return SchemaView.as_view( + renderer_classes=renderer_classes, + schema_generator=generator, + public=public, + )