add OAS securitySchemes and security objects

This commit is contained in:
Alan Crosswell 2020-09-03 20:38:25 -04:00
parent 2d52c9e8bc
commit e6bbae30a8
No known key found for this signature in database
GPG Key ID: 55819C8ADBD81C72
4 changed files with 219 additions and 1 deletions

View File

@ -389,6 +389,26 @@ differentiate between request and response objects.
By default returns `get_serializer()` but can be overridden to By default returns `get_serializer()` but can be overridden to
differentiate between request and response objects. differentiate between request and response objects.
#### `get_security_schemes()`
Generates the OpenAPI `securitySchemes` components based on:
- Your default `authentication_classes` (`settings.DEFAULT_AUTHENTICATION_CLASSES`)
- Per-view non-default `authentication_classes`
These are generated using the authentication classes' `openapi_security_scheme()` class method. If you
extend `BaseAuthentication` with your own authentication class, you can add this class method to return
the appropriate security scheme object.
#### `get_security_requirements()`
Root-level security requirements (the top-level `security` object) are generated based on the
default authentication classes. Operation-level security requirements are generated only if the given view's
`authentication_classes` differ from the defaults.
These are generated using the authentication classes' `openapi_security_requirement()` class
method. If you extended `BaseAuthentication` with your own authentication class, you can add this
class method to return the appropriate list of security requirements objects.
### `AutoSchema.__init__()` kwargs ### `AutoSchema.__init__()` kwargs
`AutoSchema` provides a number of `__init__()` kwargs that can be used for `AutoSchema` provides a number of `__init__()` kwargs that can be used for

View File

@ -49,6 +49,32 @@ class BaseAuthentication:
""" """
pass pass
#: Name of openapi security scheme. Override if you want to customize it.
openapi_security_scheme_name = None
@classmethod
def openapi_security_scheme(cls):
"""
Override this to return an Open API Specification `securityScheme object
<http://spec.openapis.org/oas/v3.0.3#security-scheme-object>`_
"""
return {}
@classmethod
def openapi_security_requirement(cls, view, method):
"""
Override this to return an Open API Specification `security requirement object
<http://spec.openapis.org/oas/v3.0.3#security-requirement-object>`_
:param view: used to find view attributes used by a permission class or None for root-level
:param method: used to distinguish among method-specific permissions or None for root-level
:return:list: [security requirement objects]
"""
# At this point, none of the built-in DRF authentication classes fill in the
# requirement list: OAuth2/OIDC are the only security types that currently uses the list
# (for scopes). See http://spec.openapis.org/oas/v3.0.3#patterned-fields-2.
return [{}]
class BasicAuthentication(BaseAuthentication): class BasicAuthentication(BaseAuthentication):
""" """
@ -108,6 +134,22 @@ class BasicAuthentication(BaseAuthentication):
def authenticate_header(self, request): def authenticate_header(self, request):
return 'Basic realm="%s"' % self.www_authenticate_realm return 'Basic realm="%s"' % self.www_authenticate_realm
openapi_security_scheme_name = 'basicAuth'
@classmethod
def openapi_security_scheme(cls):
return {
cls.openapi_security_scheme_name: {
'type': 'http',
'scheme': 'basic',
'description': 'Basic Authentication'
}
}
@classmethod
def openapi_security_requirement(cls, view, method):
return [{cls.openapi_security_scheme_name: []}]
class SessionAuthentication(BaseAuthentication): class SessionAuthentication(BaseAuthentication):
""" """
@ -147,6 +189,23 @@ class SessionAuthentication(BaseAuthentication):
# CSRF failed, bail with explicit error message # CSRF failed, bail with explicit error message
raise exceptions.PermissionDenied('CSRF Failed: %s' % reason) raise exceptions.PermissionDenied('CSRF Failed: %s' % reason)
openapi_security_scheme_name = 'sessionAuth'
@classmethod
def openapi_security_scheme(cls):
return {
cls.openapi_security_scheme_name: {
'type': 'apiKey',
'in': 'cookie',
'name': 'JSESSIONID',
'description': 'Session authentication'
}
}
@classmethod
def openapi_security_requirement(cls, view, method):
return [{cls.openapi_security_scheme_name: []}]
class TokenAuthentication(BaseAuthentication): class TokenAuthentication(BaseAuthentication):
""" """
@ -210,6 +269,23 @@ class TokenAuthentication(BaseAuthentication):
def authenticate_header(self, request): def authenticate_header(self, request):
return self.keyword return self.keyword
openapi_security_scheme_name = 'tokenAuth'
@classmethod
def openapi_security_scheme(cls):
return {
cls.openapi_security_scheme_name: {
'type': 'http',
'in': 'header',
'name': 'Authorization', # Authorization: token ...
'description': 'Token authentication'
}
}
@classmethod
def openapi_security_requirement(cls, view, method):
return [{cls.openapi_security_scheme_name: []}]
class RemoteUserAuthentication(BaseAuthentication): class RemoteUserAuthentication(BaseAuthentication):
""" """
@ -230,3 +306,20 @@ class RemoteUserAuthentication(BaseAuthentication):
user = authenticate(request=request, remote_user=request.META.get(self.header)) user = authenticate(request=request, remote_user=request.META.get(self.header))
if user and user.is_active: if user and user.is_active:
return (user, None) return (user, None)
openapi_security_scheme_name = 'remoteUserAuth'
@classmethod
def openapi_security_scheme(cls):
return {
cls.openapi_security_scheme_name: {
'type': 'http',
'in': 'header',
'name': 'REMOTE_USER',
'description': 'Remote User authentication'
}
}
@classmethod
def openapi_security_requirement(cls, view, method):
return [{cls.openapi_security_scheme_name: []}]

View File

@ -70,6 +70,14 @@ class SchemaGenerator(BaseSchemaGenerator):
""" """
self._initialise_endpoints() self._initialise_endpoints()
components_schemas = {} components_schemas = {}
security_schemes_schemas = {}
root_security_requirements = []
if api_settings.DEFAULT_AUTHENTICATION_CLASSES:
for auth_class in api_settings.DEFAULT_AUTHENTICATION_CLASSES:
req = auth_class.openapi_security_requirement(None, None)
if req:
root_security_requirements += req
# Iterate endpoints generating per method path operations. # Iterate endpoints generating per method path operations.
paths = {} paths = {}
@ -80,6 +88,7 @@ class SchemaGenerator(BaseSchemaGenerator):
operation = view.schema.get_operation(path, method) operation = view.schema.get_operation(path, method)
components = view.schema.get_components(path, method) components = view.schema.get_components(path, method)
for k in components.keys(): for k in components.keys():
if k not in components_schemas: if k not in components_schemas:
continue continue
@ -89,6 +98,16 @@ class SchemaGenerator(BaseSchemaGenerator):
components_schemas.update(components) components_schemas.update(components)
security_schemes = view.schema.get_security_schemes(path, method)
for k in security_schemes.keys():
if k not in security_schemes_schemas:
continue
if security_schemes_schemas[k] == security_schemes[k]:
continue
warnings.warn('Security scheme component "{}" has been overriden with a different '
'value.'.format(k))
security_schemes_schemas.update(security_schemes)
# Normalise path for any provided mount url. # Normalise path for any provided mount url.
if path.startswith('/'): if path.startswith('/'):
path = path[1:] path = path[1:]
@ -111,6 +130,14 @@ class SchemaGenerator(BaseSchemaGenerator):
'schemas': components_schemas 'schemas': components_schemas
} }
if len(security_schemes_schemas) > 0:
if 'components' not in schema:
schema['components'] = {}
schema['components']['securitySchemes'] = security_schemes_schemas
if len(root_security_requirements) > 0:
schema['security'] = root_security_requirements
return schema return schema
# View Inspectors # View Inspectors
@ -146,6 +173,9 @@ class AutoSchema(ViewInspector):
operation['operationId'] = self.get_operation_id(path, method) operation['operationId'] = self.get_operation_id(path, method)
operation['description'] = self.get_description(path, method) operation['description'] = self.get_description(path, method)
security = self.get_security_requirements(path, method)
if security is not None:
operation['security'] = security
parameters = [] parameters = []
parameters += self.get_path_parameters(path, method) parameters += self.get_path_parameters(path, method)
@ -713,6 +743,34 @@ class AutoSchema(ViewInspector):
return [path.split('/')[0].replace('_', '-')] return [path.split('/')[0].replace('_', '-')]
def get_security_schemes(self, path, method):
"""
Get components.schemas.securitySchemes required by this path.
returns dict of securitySchemes.
"""
schemes = {}
for auth_class in self.view.authentication_classes:
if hasattr(auth_class, 'openapi_security_scheme'):
schemes.update(auth_class.openapi_security_scheme())
return schemes
def get_security_requirements(self, path, method):
"""
Get Security Requirement Object list for this operation.
Returns a list of security requirement objects based on the view's authentication classes
unless this view's authentication classes are the same as the root-level defaults.
"""
# references the securityScheme names described above in get_security_schemes()
security = []
if self.view.authentication_classes == api_settings.DEFAULT_AUTHENTICATION_CLASSES:
return None
for auth_class in self.view.authentication_classes:
if hasattr(auth_class, 'openapi_security_requirement'):
req = auth_class.openapi_security_requirement(self.view, method)
if req:
security += req
return security
def _get_path_parameters(self, path, method): def _get_path_parameters(self, path, method):
warnings.warn( warnings.warn(
"Method `_get_path_parameters()` has been renamed to `get_path_parameters()`. " "Method `_get_path_parameters()` has been renamed to `get_path_parameters()`. "

View File

@ -8,6 +8,7 @@ from django.urls import path
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework import filters, generics, pagination, routers, serializers from rest_framework import filters, generics, pagination, routers, serializers
from rest_framework.authentication import TokenAuthentication
from rest_framework.authtoken.views import obtain_auth_token from rest_framework.authtoken.views import obtain_auth_token
from rest_framework.compat import uritemplate from rest_framework.compat import uritemplate
from rest_framework.parsers import JSONParser, MultiPartParser from rest_framework.parsers import JSONParser, MultiPartParser
@ -1235,5 +1236,51 @@ class TestGenerator(TestCase):
] ]
generator = SchemaGenerator(patterns=url_patterns) generator = SchemaGenerator(patterns=url_patterns)
schema = generator.get_schema(request=create_request('/')) schema = generator.get_schema(request=create_request('/'))
assert 'components' not in schema assert 'schemas' not in schema['components']
assert 'content' not in schema['paths']['/example/']['delete']['responses']['204'] assert 'content' not in schema['paths']['/example/']['delete']['responses']['204']
def test_default_root_security_schemes(self):
patterns = [
path('^example/?$', views.ExampleAutoSchemaComponentName.as_view()),
]
generator = SchemaGenerator(patterns=patterns)
request = create_request('/')
schema = generator.get_schema(request=request)
assert 'security' in schema
assert {'sessionAuth': []} in schema['security']
assert {'basicAuth': []} in schema['security']
assert 'security' not in schema['paths']['/example/']['get']
@override_settings(REST_FRAMEWORK={'DEFAULT_AUTHENTICATION_CLASSES': None})
def test_no_default_root_security_schemes(self):
patterns = [
path('^example/?$', views.ExampleAutoSchemaComponentName.as_view()),
]
generator = SchemaGenerator(patterns=patterns)
request = create_request('/')
schema = generator.get_schema(request=request)
assert 'security' not in schema
def test_operation_security_schemes(self):
class MyExample(views.ExampleAutoSchemaComponentName):
authentication_classes = [TokenAuthentication]
patterns = [
path('^example/?$', MyExample.as_view()),
]
generator = SchemaGenerator(patterns=patterns)
request = create_request('/')
schema = generator.get_schema(request=request)
assert 'security' in schema
assert {'sessionAuth': []} in schema['security']
assert {'basicAuth': []} in schema['security']
get_operation = schema['paths']['/example/']['get']
assert 'security' in get_operation
assert {'tokenAuth': []} in get_operation['security']
assert len(get_operation['security']) == 1