mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-28 17:09:59 +03:00
Trying to add openapi security schemes
This commit is contained in:
parent
89ac0a1c7e
commit
18df4b5021
|
@ -9,6 +9,7 @@ from django.middleware.csrf import CsrfViewMiddleware
|
|||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from rest_framework import HTTP_HEADER_ENCODING, exceptions
|
||||
from rest_framework.exceptions import AuthenticationFailed
|
||||
|
||||
|
||||
def get_authorization_header(request):
|
||||
|
@ -204,6 +205,56 @@ class TokenAuthentication(BaseAuthentication):
|
|||
return self.keyword
|
||||
|
||||
|
||||
class BearerAuthentication(BaseAuthentication):
|
||||
"""
|
||||
Base class for bearer authentications.
|
||||
|
||||
Clients should authenticate by passing the key in the "Authorization"
|
||||
HTTP header, prepended with the string "Bearer ". For example:
|
||||
|
||||
Authorization: Bearer 401f7ac837da42b97f613d789819ff93537bee6a
|
||||
"""
|
||||
|
||||
keyword = 'Bearer'
|
||||
|
||||
def authenticate_header(self, request):
|
||||
return self.keyword
|
||||
|
||||
def get_header(self, request):
|
||||
"""
|
||||
Extracts the header containing the JSON web token from the given request.
|
||||
"""
|
||||
header = request.META.get('HTTP_AUTHORIZATION')
|
||||
|
||||
if isinstance(header, str):
|
||||
# Work around django test client oddness
|
||||
header = header.encode(HTTP_HEADER_ENCODING)
|
||||
|
||||
return header
|
||||
|
||||
def get_raw_token(self, request):
|
||||
"""
|
||||
Extracts an unvalidated JSON web token from the given "Authorization" header value.
|
||||
"""
|
||||
header = self.get_header(request)
|
||||
if header is None:
|
||||
return None
|
||||
|
||||
parts = header.split()
|
||||
|
||||
if parts[0] not in self.keyword.encode():
|
||||
# Assume the header does not contain a JSON web token
|
||||
return None
|
||||
|
||||
if len(parts) != 2:
|
||||
raise AuthenticationFailed(
|
||||
_('Authorization header must contain two space-delimited values'),
|
||||
code='bad_authorization_header',
|
||||
)
|
||||
|
||||
return parts[1]
|
||||
|
||||
|
||||
class RemoteUserAuthentication(BaseAuthentication):
|
||||
"""
|
||||
REMOTE_USER authentication.
|
||||
|
|
|
@ -9,6 +9,9 @@ from django.db import models
|
|||
from django.utils.encoding import force_str
|
||||
|
||||
from rest_framework import exceptions, serializers
|
||||
from rest_framework.authentication import (
|
||||
BasicAuthentication, BearerAuthentication
|
||||
)
|
||||
from rest_framework.compat import uritemplate
|
||||
from rest_framework.fields import _UnvalidatedField, empty
|
||||
|
||||
|
@ -42,8 +45,8 @@ class SchemaGenerator(BaseSchemaGenerator):
|
|||
return None
|
||||
|
||||
for path, method, view in view_endpoints:
|
||||
if not self.has_view_permissions(path, method, view):
|
||||
continue
|
||||
# if not self.has_view_permissions(path, method, view):
|
||||
# continue
|
||||
operation = view.schema.get_operation(path, method)
|
||||
# Normalise path for any provided mount url.
|
||||
if path.startswith('/'):
|
||||
|
@ -55,6 +58,31 @@ class SchemaGenerator(BaseSchemaGenerator):
|
|||
|
||||
return result
|
||||
|
||||
def get_security_schemes(self, paths):
|
||||
security_schemes = {}
|
||||
for path, method in paths.items():
|
||||
for _, operation in method.items():
|
||||
for security in operation['security']:
|
||||
name = next(iter(security))
|
||||
if name == 'BasicAuth':
|
||||
security_schemes[name] = {
|
||||
'type': 'http',
|
||||
'scheme': 'basic'
|
||||
}
|
||||
elif name == 'BearerAuth':
|
||||
security_schemes[name] = {
|
||||
'type': 'http',
|
||||
'scheme': 'bearer'
|
||||
}
|
||||
return security_schemes
|
||||
|
||||
def get_components(self, paths):
|
||||
components = {
|
||||
'securitySchemes': self.get_security_schemes(paths)
|
||||
}
|
||||
|
||||
return components
|
||||
|
||||
def get_schema(self, request=None, public=False):
|
||||
"""
|
||||
Generate a OpenAPI schema.
|
||||
|
@ -69,6 +97,7 @@ class SchemaGenerator(BaseSchemaGenerator):
|
|||
'openapi': '3.0.2',
|
||||
'info': self.get_info(),
|
||||
'paths': paths,
|
||||
'components': self.get_components(paths)
|
||||
}
|
||||
|
||||
return schema
|
||||
|
@ -102,6 +131,7 @@ class AutoSchema(ViewInspector):
|
|||
if request_body:
|
||||
operation['requestBody'] = request_body
|
||||
operation['responses'] = self._get_responses(path, method)
|
||||
operation['security'] = self._get_security(path, method)
|
||||
|
||||
return operation
|
||||
|
||||
|
@ -520,3 +550,12 @@ class AutoSchema(ViewInspector):
|
|||
'description': ""
|
||||
}
|
||||
}
|
||||
|
||||
def _get_security(self, path, method):
|
||||
security = []
|
||||
for auth_class in self.view.authentication_classes:
|
||||
if issubclass(auth_class, BasicAuthentication):
|
||||
security.append({'BasicAuth': []})
|
||||
elif issubclass(auth_class, BearerAuthentication):
|
||||
security.append({'BearerAuth': []})
|
||||
return security
|
||||
|
|
|
@ -95,6 +95,7 @@ class TestOperationIntrospection(TestCase):
|
|||
},
|
||||
},
|
||||
},
|
||||
'security': [{'BasicAuth': []}],
|
||||
}
|
||||
|
||||
def test_path_with_id_parameter(self):
|
||||
|
|
Loading…
Reference in New Issue
Block a user