mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-29 17:39:48 +03:00
Trying to add openapi security schemes
This commit is contained in:
parent
89ac0a1c7e
commit
18df4b5021
|
@ -9,6 +9,7 @@ from django.middleware.csrf import CsrfViewMiddleware
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
|
|
||||||
from rest_framework import HTTP_HEADER_ENCODING, exceptions
|
from rest_framework import HTTP_HEADER_ENCODING, exceptions
|
||||||
|
from rest_framework.exceptions import AuthenticationFailed
|
||||||
|
|
||||||
|
|
||||||
def get_authorization_header(request):
|
def get_authorization_header(request):
|
||||||
|
@ -204,6 +205,56 @@ class TokenAuthentication(BaseAuthentication):
|
||||||
return self.keyword
|
return self.keyword
|
||||||
|
|
||||||
|
|
||||||
|
class BearerAuthentication(BaseAuthentication):
|
||||||
|
"""
|
||||||
|
Base class for bearer authentications.
|
||||||
|
|
||||||
|
Clients should authenticate by passing the key in the "Authorization"
|
||||||
|
HTTP header, prepended with the string "Bearer ". For example:
|
||||||
|
|
||||||
|
Authorization: Bearer 401f7ac837da42b97f613d789819ff93537bee6a
|
||||||
|
"""
|
||||||
|
|
||||||
|
keyword = 'Bearer'
|
||||||
|
|
||||||
|
def authenticate_header(self, request):
|
||||||
|
return self.keyword
|
||||||
|
|
||||||
|
def get_header(self, request):
|
||||||
|
"""
|
||||||
|
Extracts the header containing the JSON web token from the given request.
|
||||||
|
"""
|
||||||
|
header = request.META.get('HTTP_AUTHORIZATION')
|
||||||
|
|
||||||
|
if isinstance(header, str):
|
||||||
|
# Work around django test client oddness
|
||||||
|
header = header.encode(HTTP_HEADER_ENCODING)
|
||||||
|
|
||||||
|
return header
|
||||||
|
|
||||||
|
def get_raw_token(self, request):
|
||||||
|
"""
|
||||||
|
Extracts an unvalidated JSON web token from the given "Authorization" header value.
|
||||||
|
"""
|
||||||
|
header = self.get_header(request)
|
||||||
|
if header is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
parts = header.split()
|
||||||
|
|
||||||
|
if parts[0] not in self.keyword.encode():
|
||||||
|
# Assume the header does not contain a JSON web token
|
||||||
|
return None
|
||||||
|
|
||||||
|
if len(parts) != 2:
|
||||||
|
raise AuthenticationFailed(
|
||||||
|
_('Authorization header must contain two space-delimited values'),
|
||||||
|
code='bad_authorization_header',
|
||||||
|
)
|
||||||
|
|
||||||
|
return parts[1]
|
||||||
|
|
||||||
|
|
||||||
class RemoteUserAuthentication(BaseAuthentication):
|
class RemoteUserAuthentication(BaseAuthentication):
|
||||||
"""
|
"""
|
||||||
REMOTE_USER authentication.
|
REMOTE_USER authentication.
|
||||||
|
|
|
@ -9,6 +9,9 @@ from django.db import models
|
||||||
from django.utils.encoding import force_str
|
from django.utils.encoding import force_str
|
||||||
|
|
||||||
from rest_framework import exceptions, serializers
|
from rest_framework import exceptions, serializers
|
||||||
|
from rest_framework.authentication import (
|
||||||
|
BasicAuthentication, BearerAuthentication
|
||||||
|
)
|
||||||
from rest_framework.compat import uritemplate
|
from rest_framework.compat import uritemplate
|
||||||
from rest_framework.fields import _UnvalidatedField, empty
|
from rest_framework.fields import _UnvalidatedField, empty
|
||||||
|
|
||||||
|
@ -42,8 +45,8 @@ class SchemaGenerator(BaseSchemaGenerator):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
for path, method, view in view_endpoints:
|
for path, method, view in view_endpoints:
|
||||||
if not self.has_view_permissions(path, method, view):
|
# if not self.has_view_permissions(path, method, view):
|
||||||
continue
|
# continue
|
||||||
operation = view.schema.get_operation(path, method)
|
operation = view.schema.get_operation(path, method)
|
||||||
# Normalise path for any provided mount url.
|
# Normalise path for any provided mount url.
|
||||||
if path.startswith('/'):
|
if path.startswith('/'):
|
||||||
|
@ -55,6 +58,31 @@ class SchemaGenerator(BaseSchemaGenerator):
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def get_security_schemes(self, paths):
|
||||||
|
security_schemes = {}
|
||||||
|
for path, method in paths.items():
|
||||||
|
for _, operation in method.items():
|
||||||
|
for security in operation['security']:
|
||||||
|
name = next(iter(security))
|
||||||
|
if name == 'BasicAuth':
|
||||||
|
security_schemes[name] = {
|
||||||
|
'type': 'http',
|
||||||
|
'scheme': 'basic'
|
||||||
|
}
|
||||||
|
elif name == 'BearerAuth':
|
||||||
|
security_schemes[name] = {
|
||||||
|
'type': 'http',
|
||||||
|
'scheme': 'bearer'
|
||||||
|
}
|
||||||
|
return security_schemes
|
||||||
|
|
||||||
|
def get_components(self, paths):
|
||||||
|
components = {
|
||||||
|
'securitySchemes': self.get_security_schemes(paths)
|
||||||
|
}
|
||||||
|
|
||||||
|
return components
|
||||||
|
|
||||||
def get_schema(self, request=None, public=False):
|
def get_schema(self, request=None, public=False):
|
||||||
"""
|
"""
|
||||||
Generate a OpenAPI schema.
|
Generate a OpenAPI schema.
|
||||||
|
@ -69,6 +97,7 @@ class SchemaGenerator(BaseSchemaGenerator):
|
||||||
'openapi': '3.0.2',
|
'openapi': '3.0.2',
|
||||||
'info': self.get_info(),
|
'info': self.get_info(),
|
||||||
'paths': paths,
|
'paths': paths,
|
||||||
|
'components': self.get_components(paths)
|
||||||
}
|
}
|
||||||
|
|
||||||
return schema
|
return schema
|
||||||
|
@ -102,6 +131,7 @@ class AutoSchema(ViewInspector):
|
||||||
if request_body:
|
if request_body:
|
||||||
operation['requestBody'] = request_body
|
operation['requestBody'] = request_body
|
||||||
operation['responses'] = self._get_responses(path, method)
|
operation['responses'] = self._get_responses(path, method)
|
||||||
|
operation['security'] = self._get_security(path, method)
|
||||||
|
|
||||||
return operation
|
return operation
|
||||||
|
|
||||||
|
@ -520,3 +550,12 @@ class AutoSchema(ViewInspector):
|
||||||
'description': ""
|
'description': ""
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def _get_security(self, path, method):
|
||||||
|
security = []
|
||||||
|
for auth_class in self.view.authentication_classes:
|
||||||
|
if issubclass(auth_class, BasicAuthentication):
|
||||||
|
security.append({'BasicAuth': []})
|
||||||
|
elif issubclass(auth_class, BearerAuthentication):
|
||||||
|
security.append({'BearerAuth': []})
|
||||||
|
return security
|
||||||
|
|
|
@ -95,6 +95,7 @@ class TestOperationIntrospection(TestCase):
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
'security': [{'BasicAuth': []}],
|
||||||
}
|
}
|
||||||
|
|
||||||
def test_path_with_id_parameter(self):
|
def test_path_with_id_parameter(self):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user