mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-07-27 08:29:59 +03:00
add OAS securitySchemes and security objects
This commit is contained in:
parent
2d52c9e8bc
commit
e6bbae30a8
|
@ -389,6 +389,26 @@ differentiate between request and response objects.
|
|||
By default returns `get_serializer()` but can be overridden to
|
||||
differentiate between request and response objects.
|
||||
|
||||
#### `get_security_schemes()`
|
||||
|
||||
Generates the OpenAPI `securitySchemes` components based on:
|
||||
- Your default `authentication_classes` (`settings.DEFAULT_AUTHENTICATION_CLASSES`)
|
||||
- Per-view non-default `authentication_classes`
|
||||
|
||||
These are generated using the authentication classes' `openapi_security_scheme()` class method. If you
|
||||
extend `BaseAuthentication` with your own authentication class, you can add this class method to return
|
||||
the appropriate security scheme object.
|
||||
|
||||
#### `get_security_requirements()`
|
||||
|
||||
Root-level security requirements (the top-level `security` object) are generated based on the
|
||||
default authentication classes. Operation-level security requirements are generated only if the given view's
|
||||
`authentication_classes` differ from the defaults.
|
||||
|
||||
These are generated using the authentication classes' `openapi_security_requirement()` class
|
||||
method. If you extended `BaseAuthentication` with your own authentication class, you can add this
|
||||
class method to return the appropriate list of security requirements objects.
|
||||
|
||||
### `AutoSchema.__init__()` kwargs
|
||||
|
||||
`AutoSchema` provides a number of `__init__()` kwargs that can be used for
|
||||
|
|
|
@ -49,6 +49,32 @@ class BaseAuthentication:
|
|||
"""
|
||||
pass
|
||||
|
||||
#: Name of openapi security scheme. Override if you want to customize it.
|
||||
openapi_security_scheme_name = None
|
||||
|
||||
@classmethod
|
||||
def openapi_security_scheme(cls):
|
||||
"""
|
||||
Override this to return an Open API Specification `securityScheme object
|
||||
<http://spec.openapis.org/oas/v3.0.3#security-scheme-object>`_
|
||||
"""
|
||||
return {}
|
||||
|
||||
@classmethod
|
||||
def openapi_security_requirement(cls, view, method):
|
||||
"""
|
||||
Override this to return an Open API Specification `security requirement object
|
||||
<http://spec.openapis.org/oas/v3.0.3#security-requirement-object>`_
|
||||
|
||||
:param view: used to find view attributes used by a permission class or None for root-level
|
||||
:param method: used to distinguish among method-specific permissions or None for root-level
|
||||
:return:list: [security requirement objects]
|
||||
"""
|
||||
# At this point, none of the built-in DRF authentication classes fill in the
|
||||
# requirement list: OAuth2/OIDC are the only security types that currently uses the list
|
||||
# (for scopes). See http://spec.openapis.org/oas/v3.0.3#patterned-fields-2.
|
||||
return [{}]
|
||||
|
||||
|
||||
class BasicAuthentication(BaseAuthentication):
|
||||
"""
|
||||
|
@ -108,6 +134,22 @@ class BasicAuthentication(BaseAuthentication):
|
|||
def authenticate_header(self, request):
|
||||
return 'Basic realm="%s"' % self.www_authenticate_realm
|
||||
|
||||
openapi_security_scheme_name = 'basicAuth'
|
||||
|
||||
@classmethod
|
||||
def openapi_security_scheme(cls):
|
||||
return {
|
||||
cls.openapi_security_scheme_name: {
|
||||
'type': 'http',
|
||||
'scheme': 'basic',
|
||||
'description': 'Basic Authentication'
|
||||
}
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def openapi_security_requirement(cls, view, method):
|
||||
return [{cls.openapi_security_scheme_name: []}]
|
||||
|
||||
|
||||
class SessionAuthentication(BaseAuthentication):
|
||||
"""
|
||||
|
@ -147,6 +189,23 @@ class SessionAuthentication(BaseAuthentication):
|
|||
# CSRF failed, bail with explicit error message
|
||||
raise exceptions.PermissionDenied('CSRF Failed: %s' % reason)
|
||||
|
||||
openapi_security_scheme_name = 'sessionAuth'
|
||||
|
||||
@classmethod
|
||||
def openapi_security_scheme(cls):
|
||||
return {
|
||||
cls.openapi_security_scheme_name: {
|
||||
'type': 'apiKey',
|
||||
'in': 'cookie',
|
||||
'name': 'JSESSIONID',
|
||||
'description': 'Session authentication'
|
||||
}
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def openapi_security_requirement(cls, view, method):
|
||||
return [{cls.openapi_security_scheme_name: []}]
|
||||
|
||||
|
||||
class TokenAuthentication(BaseAuthentication):
|
||||
"""
|
||||
|
@ -210,6 +269,23 @@ class TokenAuthentication(BaseAuthentication):
|
|||
def authenticate_header(self, request):
|
||||
return self.keyword
|
||||
|
||||
openapi_security_scheme_name = 'tokenAuth'
|
||||
|
||||
@classmethod
|
||||
def openapi_security_scheme(cls):
|
||||
return {
|
||||
cls.openapi_security_scheme_name: {
|
||||
'type': 'http',
|
||||
'in': 'header',
|
||||
'name': 'Authorization', # Authorization: token ...
|
||||
'description': 'Token authentication'
|
||||
}
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def openapi_security_requirement(cls, view, method):
|
||||
return [{cls.openapi_security_scheme_name: []}]
|
||||
|
||||
|
||||
class RemoteUserAuthentication(BaseAuthentication):
|
||||
"""
|
||||
|
@ -230,3 +306,20 @@ class RemoteUserAuthentication(BaseAuthentication):
|
|||
user = authenticate(request=request, remote_user=request.META.get(self.header))
|
||||
if user and user.is_active:
|
||||
return (user, None)
|
||||
|
||||
openapi_security_scheme_name = 'remoteUserAuth'
|
||||
|
||||
@classmethod
|
||||
def openapi_security_scheme(cls):
|
||||
return {
|
||||
cls.openapi_security_scheme_name: {
|
||||
'type': 'http',
|
||||
'in': 'header',
|
||||
'name': 'REMOTE_USER',
|
||||
'description': 'Remote User authentication'
|
||||
}
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def openapi_security_requirement(cls, view, method):
|
||||
return [{cls.openapi_security_scheme_name: []}]
|
||||
|
|
|
@ -70,6 +70,14 @@ class SchemaGenerator(BaseSchemaGenerator):
|
|||
"""
|
||||
self._initialise_endpoints()
|
||||
components_schemas = {}
|
||||
security_schemes_schemas = {}
|
||||
root_security_requirements = []
|
||||
|
||||
if api_settings.DEFAULT_AUTHENTICATION_CLASSES:
|
||||
for auth_class in api_settings.DEFAULT_AUTHENTICATION_CLASSES:
|
||||
req = auth_class.openapi_security_requirement(None, None)
|
||||
if req:
|
||||
root_security_requirements += req
|
||||
|
||||
# Iterate endpoints generating per method path operations.
|
||||
paths = {}
|
||||
|
@ -80,6 +88,7 @@ class SchemaGenerator(BaseSchemaGenerator):
|
|||
|
||||
operation = view.schema.get_operation(path, method)
|
||||
components = view.schema.get_components(path, method)
|
||||
|
||||
for k in components.keys():
|
||||
if k not in components_schemas:
|
||||
continue
|
||||
|
@ -89,6 +98,16 @@ class SchemaGenerator(BaseSchemaGenerator):
|
|||
|
||||
components_schemas.update(components)
|
||||
|
||||
security_schemes = view.schema.get_security_schemes(path, method)
|
||||
for k in security_schemes.keys():
|
||||
if k not in security_schemes_schemas:
|
||||
continue
|
||||
if security_schemes_schemas[k] == security_schemes[k]:
|
||||
continue
|
||||
warnings.warn('Security scheme component "{}" has been overriden with a different '
|
||||
'value.'.format(k))
|
||||
security_schemes_schemas.update(security_schemes)
|
||||
|
||||
# Normalise path for any provided mount url.
|
||||
if path.startswith('/'):
|
||||
path = path[1:]
|
||||
|
@ -111,6 +130,14 @@ class SchemaGenerator(BaseSchemaGenerator):
|
|||
'schemas': components_schemas
|
||||
}
|
||||
|
||||
if len(security_schemes_schemas) > 0:
|
||||
if 'components' not in schema:
|
||||
schema['components'] = {}
|
||||
schema['components']['securitySchemes'] = security_schemes_schemas
|
||||
|
||||
if len(root_security_requirements) > 0:
|
||||
schema['security'] = root_security_requirements
|
||||
|
||||
return schema
|
||||
|
||||
# View Inspectors
|
||||
|
@ -146,6 +173,9 @@ class AutoSchema(ViewInspector):
|
|||
|
||||
operation['operationId'] = self.get_operation_id(path, method)
|
||||
operation['description'] = self.get_description(path, method)
|
||||
security = self.get_security_requirements(path, method)
|
||||
if security is not None:
|
||||
operation['security'] = security
|
||||
|
||||
parameters = []
|
||||
parameters += self.get_path_parameters(path, method)
|
||||
|
@ -713,6 +743,34 @@ class AutoSchema(ViewInspector):
|
|||
|
||||
return [path.split('/')[0].replace('_', '-')]
|
||||
|
||||
def get_security_schemes(self, path, method):
|
||||
"""
|
||||
Get components.schemas.securitySchemes required by this path.
|
||||
returns dict of securitySchemes.
|
||||
"""
|
||||
schemes = {}
|
||||
for auth_class in self.view.authentication_classes:
|
||||
if hasattr(auth_class, 'openapi_security_scheme'):
|
||||
schemes.update(auth_class.openapi_security_scheme())
|
||||
return schemes
|
||||
|
||||
def get_security_requirements(self, path, method):
|
||||
"""
|
||||
Get Security Requirement Object list for this operation.
|
||||
Returns a list of security requirement objects based on the view's authentication classes
|
||||
unless this view's authentication classes are the same as the root-level defaults.
|
||||
"""
|
||||
# references the securityScheme names described above in get_security_schemes()
|
||||
security = []
|
||||
if self.view.authentication_classes == api_settings.DEFAULT_AUTHENTICATION_CLASSES:
|
||||
return None
|
||||
for auth_class in self.view.authentication_classes:
|
||||
if hasattr(auth_class, 'openapi_security_requirement'):
|
||||
req = auth_class.openapi_security_requirement(self.view, method)
|
||||
if req:
|
||||
security += req
|
||||
return security
|
||||
|
||||
def _get_path_parameters(self, path, method):
|
||||
warnings.warn(
|
||||
"Method `_get_path_parameters()` has been renamed to `get_path_parameters()`. "
|
||||
|
|
|
@ -8,6 +8,7 @@ from django.urls import path
|
|||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from rest_framework import filters, generics, pagination, routers, serializers
|
||||
from rest_framework.authentication import TokenAuthentication
|
||||
from rest_framework.authtoken.views import obtain_auth_token
|
||||
from rest_framework.compat import uritemplate
|
||||
from rest_framework.parsers import JSONParser, MultiPartParser
|
||||
|
@ -1235,5 +1236,51 @@ class TestGenerator(TestCase):
|
|||
]
|
||||
generator = SchemaGenerator(patterns=url_patterns)
|
||||
schema = generator.get_schema(request=create_request('/'))
|
||||
assert 'components' not in schema
|
||||
assert 'schemas' not in schema['components']
|
||||
assert 'content' not in schema['paths']['/example/']['delete']['responses']['204']
|
||||
|
||||
def test_default_root_security_schemes(self):
|
||||
patterns = [
|
||||
path('^example/?$', views.ExampleAutoSchemaComponentName.as_view()),
|
||||
]
|
||||
|
||||
generator = SchemaGenerator(patterns=patterns)
|
||||
|
||||
request = create_request('/')
|
||||
schema = generator.get_schema(request=request)
|
||||
assert 'security' in schema
|
||||
assert {'sessionAuth': []} in schema['security']
|
||||
assert {'basicAuth': []} in schema['security']
|
||||
assert 'security' not in schema['paths']['/example/']['get']
|
||||
|
||||
@override_settings(REST_FRAMEWORK={'DEFAULT_AUTHENTICATION_CLASSES': None})
|
||||
def test_no_default_root_security_schemes(self):
|
||||
patterns = [
|
||||
path('^example/?$', views.ExampleAutoSchemaComponentName.as_view()),
|
||||
]
|
||||
|
||||
generator = SchemaGenerator(patterns=patterns)
|
||||
|
||||
request = create_request('/')
|
||||
schema = generator.get_schema(request=request)
|
||||
assert 'security' not in schema
|
||||
|
||||
def test_operation_security_schemes(self):
|
||||
class MyExample(views.ExampleAutoSchemaComponentName):
|
||||
authentication_classes = [TokenAuthentication]
|
||||
|
||||
patterns = [
|
||||
path('^example/?$', MyExample.as_view()),
|
||||
]
|
||||
|
||||
generator = SchemaGenerator(patterns=patterns)
|
||||
|
||||
request = create_request('/')
|
||||
schema = generator.get_schema(request=request)
|
||||
assert 'security' in schema
|
||||
assert {'sessionAuth': []} in schema['security']
|
||||
assert {'basicAuth': []} in schema['security']
|
||||
get_operation = schema['paths']['/example/']['get']
|
||||
assert 'security' in get_operation
|
||||
assert {'tokenAuth': []} in get_operation['security']
|
||||
assert len(get_operation['security']) == 1
|
||||
|
|
Loading…
Reference in New Issue
Block a user