diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index 688deec88..28f5dc8a1 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -65,44 +65,52 @@ class SchemaGenerator(object): urls = import_module(urlconf) else: urls = urlconf - patterns = urls.urlpatterns + self.patterns = urls.urlpatterns elif patterns is None and urlconf is None: urls = import_module(settings.ROOT_URLCONF) - patterns = urls.urlpatterns + self.patterns = urls.urlpatterns + else: + self.patterns = patterns if url and not url.endswith('/'): url += '/' self.title = title self.url = url - self.endpoints = self.get_api_endpoints(patterns) + self.endpoints = None def get_schema(self, request=None): - if request is None: - endpoints = self.endpoints - else: - # Filter the list of endpoints to only include those that - # the user has permission on. - endpoints = [] - for key, link, callback in self.endpoints: - method = link.action.upper() - view = callback.cls() + if self.endpoints is None: + self.endpoints = self.get_api_endpoints(self.patterns) + + links = [] + for key, path, method, callback in self.endpoints: + view = callback.cls() + for attr, val in getattr(callback, 'initkwargs', {}).items(): + setattr(view, attr, val) + view.args = () + view.kwargs = {} + view.format_kwarg = None + + if request is not None: view.request = clone_request(request, method) - view.format_kwarg = None try: view.check_permissions(view.request) except exceptions.APIException: - pass - else: - endpoints.append((key, link, callback)) + continue + else: + view.request = None - if not endpoints: + link = self.get_link(path, method, callback, view) + links.append((key, link)) + + if not link: return None # Generate the schema content structure, from the endpoints. # ('users', 'list'), Link -> {'users': {'list': Link()}} content = {} - for key, link, callback in endpoints: + for key, link in links: insert_into(content, key, link) # Return the schema document. @@ -122,8 +130,7 @@ class SchemaGenerator(object): if self.should_include_endpoint(path, callback): for method in self.get_allowed_methods(callback): key = self.get_key(path, method, callback) - link = self.get_link(path, method, callback) - endpoint = (key, link, callback) + endpoint = (key, path, method, callback) api_endpoints.append(endpoint) elif isinstance(pattern, RegexURLResolver): @@ -190,14 +197,10 @@ class SchemaGenerator(object): # Methods for generating each individual `Link` instance... - def get_link(self, path, method, callback): + def get_link(self, path, method, callback, view): """ Return a `coreapi.Link` instance for the given endpoint. """ - view = callback.cls() - for attr, val in getattr(callback, 'initkwargs', {}).items(): - setattr(view, attr, val) - fields = self.get_path_fields(path, method, callback, view) fields += self.get_serializer_fields(path, method, callback, view) fields += self.get_pagination_fields(path, method, callback, view) @@ -260,20 +263,18 @@ class SchemaGenerator(object): if method not in ('PUT', 'PATCH', 'POST'): return [] - if not hasattr(view, 'get_serializer_class'): + if not hasattr(view, 'get_serializer'): return [] - fields = [] - - serializer_class = view.get_serializer_class() - serializer = serializer_class() + serializer = view.get_serializer() if isinstance(serializer, serializers.ListSerializer): - return coreapi.Field(name='data', location='body', required=True) + return [coreapi.Field(name='data', location='body', required=True)] if not isinstance(serializer, serializers.Serializer): return [] + fields = [] for field in serializer.fields.values(): if field.read_only: continue diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 6c02c9d23..d8c0f2209 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -43,6 +43,10 @@ class ExampleViewSet(ModelViewSet): def custom_action(self, request, pk): return super(ExampleSerializer, self).retrieve(self, request) + def get_serializer(self, *args, **kwargs): + assert self.request + return super(ExampleViewSet, self).get_serializer(*args, **kwargs) + class ExampleView(APIView): permission_classes = [permissions.IsAuthenticatedOrReadOnly]