Fix schema generation for ObtainAuthToken view. (#7211)

This commit is contained in:
Martin Desrumaux 2020-03-03 13:27:34 +01:00 committed by GitHub
parent 8aa8be7653
commit 609f708a27
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 55 additions and 5 deletions

View File

@ -5,11 +5,19 @@ from rest_framework import serializers
class AuthTokenSerializer(serializers.Serializer): class AuthTokenSerializer(serializers.Serializer):
username = serializers.CharField(label=_("Username")) username = serializers.CharField(
label=_("Username"),
write_only=True
)
password = serializers.CharField( password = serializers.CharField(
label=_("Password"), label=_("Password"),
style={'input_type': 'password'}, style={'input_type': 'password'},
trim_whitespace=False trim_whitespace=False,
write_only=True
)
token = serializers.CharField(
label=_("Token"),
read_only=True
) )
def validate(self, attrs): def validate(self, attrs):

View File

@ -4,6 +4,7 @@ from rest_framework.authtoken.serializers import AuthTokenSerializer
from rest_framework.compat import coreapi, coreschema from rest_framework.compat import coreapi, coreschema
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.schemas import ManualSchema from rest_framework.schemas import ManualSchema
from rest_framework.schemas import coreapi as coreapi_schema
from rest_framework.views import APIView from rest_framework.views import APIView
@ -13,7 +14,8 @@ class ObtainAuthToken(APIView):
parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,) parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,)
renderer_classes = (renderers.JSONRenderer,) renderer_classes = (renderers.JSONRenderer,)
serializer_class = AuthTokenSerializer serializer_class = AuthTokenSerializer
if coreapi is not None and coreschema is not None:
if coreapi_schema.is_enabled():
schema = ManualSchema( schema = ManualSchema(
fields=[ fields=[
coreapi.Field( coreapi.Field(
@ -38,9 +40,19 @@ class ObtainAuthToken(APIView):
encoding="application/json", encoding="application/json",
) )
def get_serializer_context(self):
return {
'request': self.request,
'format': self.format_kwarg,
'view': self
}
def get_serializer(self, *args, **kwargs):
kwargs['context'] = self.get_serializer_context()
return self.serializer_class(*args, **kwargs)
def post(self, request, *args, **kwargs): def post(self, request, *args, **kwargs):
serializer = self.serializer_class(data=request.data, serializer = self.get_serializer(data=request.data)
context={'request': request})
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
user = serializer.validated_data['user'] user = serializer.validated_data['user']
token, created = Token.objects.get_or_create(user=user) token, created = Token.objects.get_or_create(user=user)

View File

@ -7,6 +7,7 @@ from django.test import RequestFactory, TestCase, override_settings
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from rest_framework import filters, generics, pagination, routers, serializers from rest_framework import filters, generics, pagination, routers, serializers
from rest_framework.authtoken.views import obtain_auth_token
from rest_framework.compat import uritemplate from rest_framework.compat import uritemplate
from rest_framework.parsers import JSONParser, MultiPartParser from rest_framework.parsers import JSONParser, MultiPartParser
from rest_framework.renderers import JSONRenderer, OpenAPIRenderer from rest_framework.renderers import JSONRenderer, OpenAPIRenderer
@ -995,16 +996,45 @@ class TestGenerator(TestCase):
patterns = [ patterns = [
url(r'^example/?$', views.ExampleGenericAPIViewModel.as_view()), url(r'^example/?$', views.ExampleGenericAPIViewModel.as_view()),
] ]
generator = SchemaGenerator(patterns=patterns) generator = SchemaGenerator(patterns=patterns)
request = create_request('/') request = create_request('/')
schema = generator.get_schema(request=request) schema = generator.get_schema(request=request)
print(schema) print(schema)
assert 'components' in schema assert 'components' in schema
assert 'schemas' in schema['components'] assert 'schemas' in schema['components']
assert 'ExampleModel' in schema['components']['schemas'] assert 'ExampleModel' in schema['components']['schemas']
def test_authtoken_serializer(self):
patterns = [
url(r'^api-token-auth/', obtain_auth_token)
]
generator = SchemaGenerator(patterns=patterns)
request = create_request('/')
schema = generator.get_schema(request=request)
print(schema)
route = schema['paths']['/api-token-auth/']['post']
body_schema = route['requestBody']['content']['application/json']['schema']
assert body_schema == {
'$ref': '#/components/schemas/AuthToken'
}
assert schema['components']['schemas']['AuthToken'] == {
'type': 'object',
'properties': {
'username': {'type': 'string', 'writeOnly': True},
'password': {'type': 'string', 'writeOnly': True},
'token': {'type': 'string', 'readOnly': True},
},
'required': ['username', 'password']
}
def test_component_name(self): def test_component_name(self):
patterns = [ patterns = [
url(r'^example/?$', views.ExampleAutoSchemaComponentName.as_view()), url(r'^example/?$', views.ExampleAutoSchemaComponentName.as_view()),