From 18df4b5021e2ffc3e230d6d8f5a4d27d7d4f46ec Mon Sep 17 00:00:00 2001 From: Yann Savary Date: Thu, 5 Sep 2019 22:12:09 +0200 Subject: [PATCH] Trying to add openapi security schemes --- rest_framework/authentication.py | 51 +++++++++++++++++++++++++++++++ rest_framework/schemas/openapi.py | 43 ++++++++++++++++++++++++-- tests/schemas/test_openapi.py | 1 + 3 files changed, 93 insertions(+), 2 deletions(-) diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 1e30728d3..fddb82896 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -9,6 +9,7 @@ from django.middleware.csrf import CsrfViewMiddleware from django.utils.translation import gettext_lazy as _ from rest_framework import HTTP_HEADER_ENCODING, exceptions +from rest_framework.exceptions import AuthenticationFailed def get_authorization_header(request): @@ -204,6 +205,56 @@ class TokenAuthentication(BaseAuthentication): return self.keyword +class BearerAuthentication(BaseAuthentication): + """ + Base class for bearer authentications. + + Clients should authenticate by passing the key in the "Authorization" + HTTP header, prepended with the string "Bearer ". For example: + + Authorization: Bearer 401f7ac837da42b97f613d789819ff93537bee6a + """ + + keyword = 'Bearer' + + def authenticate_header(self, request): + return self.keyword + + def get_header(self, request): + """ + Extracts the header containing the JSON web token from the given request. + """ + header = request.META.get('HTTP_AUTHORIZATION') + + if isinstance(header, str): + # Work around django test client oddness + header = header.encode(HTTP_HEADER_ENCODING) + + return header + + def get_raw_token(self, request): + """ + Extracts an unvalidated JSON web token from the given "Authorization" header value. + """ + header = self.get_header(request) + if header is None: + return None + + parts = header.split() + + if parts[0] not in self.keyword.encode(): + # Assume the header does not contain a JSON web token + return None + + if len(parts) != 2: + raise AuthenticationFailed( + _('Authorization header must contain two space-delimited values'), + code='bad_authorization_header', + ) + + return parts[1] + + class RemoteUserAuthentication(BaseAuthentication): """ REMOTE_USER authentication. diff --git a/rest_framework/schemas/openapi.py b/rest_framework/schemas/openapi.py index ac846bf80..92efbe0ed 100644 --- a/rest_framework/schemas/openapi.py +++ b/rest_framework/schemas/openapi.py @@ -9,6 +9,9 @@ from django.db import models from django.utils.encoding import force_str from rest_framework import exceptions, serializers +from rest_framework.authentication import ( + BasicAuthentication, BearerAuthentication +) from rest_framework.compat import uritemplate from rest_framework.fields import _UnvalidatedField, empty @@ -42,8 +45,8 @@ class SchemaGenerator(BaseSchemaGenerator): return None for path, method, view in view_endpoints: - if not self.has_view_permissions(path, method, view): - continue + # if not self.has_view_permissions(path, method, view): + # continue operation = view.schema.get_operation(path, method) # Normalise path for any provided mount url. if path.startswith('/'): @@ -55,6 +58,31 @@ class SchemaGenerator(BaseSchemaGenerator): return result + def get_security_schemes(self, paths): + security_schemes = {} + for path, method in paths.items(): + for _, operation in method.items(): + for security in operation['security']: + name = next(iter(security)) + if name == 'BasicAuth': + security_schemes[name] = { + 'type': 'http', + 'scheme': 'basic' + } + elif name == 'BearerAuth': + security_schemes[name] = { + 'type': 'http', + 'scheme': 'bearer' + } + return security_schemes + + def get_components(self, paths): + components = { + 'securitySchemes': self.get_security_schemes(paths) + } + + return components + def get_schema(self, request=None, public=False): """ Generate a OpenAPI schema. @@ -69,6 +97,7 @@ class SchemaGenerator(BaseSchemaGenerator): 'openapi': '3.0.2', 'info': self.get_info(), 'paths': paths, + 'components': self.get_components(paths) } return schema @@ -102,6 +131,7 @@ class AutoSchema(ViewInspector): if request_body: operation['requestBody'] = request_body operation['responses'] = self._get_responses(path, method) + operation['security'] = self._get_security(path, method) return operation @@ -520,3 +550,12 @@ class AutoSchema(ViewInspector): 'description': "" } } + + def _get_security(self, path, method): + security = [] + for auth_class in self.view.authentication_classes: + if issubclass(auth_class, BasicAuthentication): + security.append({'BasicAuth': []}) + elif issubclass(auth_class, BearerAuthentication): + security.append({'BearerAuth': []}) + return security diff --git a/tests/schemas/test_openapi.py b/tests/schemas/test_openapi.py index d9375585b..e50607750 100644 --- a/tests/schemas/test_openapi.py +++ b/tests/schemas/test_openapi.py @@ -95,6 +95,7 @@ class TestOperationIntrospection(TestCase): }, }, }, + 'security': [{'BasicAuth': []}], } def test_path_with_id_parameter(self):