Trying to add openapi security schemes

This commit is contained in:
Yann Savary 2019-09-05 22:12:09 +02:00
parent 89ac0a1c7e
commit 18df4b5021
3 changed files with 93 additions and 2 deletions

View File

@ -9,6 +9,7 @@ from django.middleware.csrf import CsrfViewMiddleware
from django.utils.translation import gettext_lazy as _
from rest_framework import HTTP_HEADER_ENCODING, exceptions
from rest_framework.exceptions import AuthenticationFailed
def get_authorization_header(request):
@ -204,6 +205,56 @@ class TokenAuthentication(BaseAuthentication):
return self.keyword
class BearerAuthentication(BaseAuthentication):
"""
Base class for bearer authentications.
Clients should authenticate by passing the key in the "Authorization"
HTTP header, prepended with the string "Bearer ". For example:
Authorization: Bearer 401f7ac837da42b97f613d789819ff93537bee6a
"""
keyword = 'Bearer'
def authenticate_header(self, request):
return self.keyword
def get_header(self, request):
"""
Extracts the header containing the JSON web token from the given request.
"""
header = request.META.get('HTTP_AUTHORIZATION')
if isinstance(header, str):
# Work around django test client oddness
header = header.encode(HTTP_HEADER_ENCODING)
return header
def get_raw_token(self, request):
"""
Extracts an unvalidated JSON web token from the given "Authorization" header value.
"""
header = self.get_header(request)
if header is None:
return None
parts = header.split()
if parts[0] not in self.keyword.encode():
# Assume the header does not contain a JSON web token
return None
if len(parts) != 2:
raise AuthenticationFailed(
_('Authorization header must contain two space-delimited values'),
code='bad_authorization_header',
)
return parts[1]
class RemoteUserAuthentication(BaseAuthentication):
"""
REMOTE_USER authentication.

View File

@ -9,6 +9,9 @@ from django.db import models
from django.utils.encoding import force_str
from rest_framework import exceptions, serializers
from rest_framework.authentication import (
BasicAuthentication, BearerAuthentication
)
from rest_framework.compat import uritemplate
from rest_framework.fields import _UnvalidatedField, empty
@ -42,8 +45,8 @@ class SchemaGenerator(BaseSchemaGenerator):
return None
for path, method, view in view_endpoints:
if not self.has_view_permissions(path, method, view):
continue
# if not self.has_view_permissions(path, method, view):
# continue
operation = view.schema.get_operation(path, method)
# Normalise path for any provided mount url.
if path.startswith('/'):
@ -55,6 +58,31 @@ class SchemaGenerator(BaseSchemaGenerator):
return result
def get_security_schemes(self, paths):
security_schemes = {}
for path, method in paths.items():
for _, operation in method.items():
for security in operation['security']:
name = next(iter(security))
if name == 'BasicAuth':
security_schemes[name] = {
'type': 'http',
'scheme': 'basic'
}
elif name == 'BearerAuth':
security_schemes[name] = {
'type': 'http',
'scheme': 'bearer'
}
return security_schemes
def get_components(self, paths):
components = {
'securitySchemes': self.get_security_schemes(paths)
}
return components
def get_schema(self, request=None, public=False):
"""
Generate a OpenAPI schema.
@ -69,6 +97,7 @@ class SchemaGenerator(BaseSchemaGenerator):
'openapi': '3.0.2',
'info': self.get_info(),
'paths': paths,
'components': self.get_components(paths)
}
return schema
@ -102,6 +131,7 @@ class AutoSchema(ViewInspector):
if request_body:
operation['requestBody'] = request_body
operation['responses'] = self._get_responses(path, method)
operation['security'] = self._get_security(path, method)
return operation
@ -520,3 +550,12 @@ class AutoSchema(ViewInspector):
'description': ""
}
}
def _get_security(self, path, method):
security = []
for auth_class in self.view.authentication_classes:
if issubclass(auth_class, BasicAuthentication):
security.append({'BasicAuth': []})
elif issubclass(auth_class, BearerAuthentication):
security.append({'BearerAuth': []})
return security

View File

@ -95,6 +95,7 @@ class TestOperationIntrospection(TestCase):
},
},
},
'security': [{'BasicAuth': []}],
}
def test_path_with_id_parameter(self):