diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py index bb552f3e5..63e64d668 100644 --- a/rest_framework/authtoken/serializers.py +++ b/rest_framework/authtoken/serializers.py @@ -5,11 +5,19 @@ from rest_framework import serializers class AuthTokenSerializer(serializers.Serializer): - username = serializers.CharField(label=_("Username")) + username = serializers.CharField( + label=_("Username"), + write_only=True + ) password = serializers.CharField( label=_("Password"), style={'input_type': 'password'}, - trim_whitespace=False + trim_whitespace=False, + write_only=True + ) + token = serializers.CharField( + label=_("Token"), + read_only=True ) def validate(self, attrs): diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py index a8c751d51..50f9acbd9 100644 --- a/rest_framework/authtoken/views.py +++ b/rest_framework/authtoken/views.py @@ -4,6 +4,7 @@ from rest_framework.authtoken.serializers import AuthTokenSerializer from rest_framework.compat import coreapi, coreschema from rest_framework.response import Response from rest_framework.schemas import ManualSchema +from rest_framework.schemas import coreapi as coreapi_schema from rest_framework.views import APIView @@ -13,7 +14,8 @@ class ObtainAuthToken(APIView): parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,) renderer_classes = (renderers.JSONRenderer,) serializer_class = AuthTokenSerializer - if coreapi is not None and coreschema is not None: + + if coreapi_schema.is_enabled(): schema = ManualSchema( fields=[ coreapi.Field( @@ -38,9 +40,19 @@ class ObtainAuthToken(APIView): encoding="application/json", ) + def get_serializer_context(self): + return { + 'request': self.request, + 'format': self.format_kwarg, + 'view': self + } + + def get_serializer(self, *args, **kwargs): + kwargs['context'] = self.get_serializer_context() + return self.serializer_class(*args, **kwargs) + def post(self, request, *args, **kwargs): - serializer = self.serializer_class(data=request.data, - context={'request': request}) + serializer = self.get_serializer(data=request.data) serializer.is_valid(raise_exception=True) user = serializer.validated_data['user'] token, created = Token.objects.get_or_create(user=user) diff --git a/tests/schemas/test_openapi.py b/tests/schemas/test_openapi.py index 95101403a..35d676d6c 100644 --- a/tests/schemas/test_openapi.py +++ b/tests/schemas/test_openapi.py @@ -7,6 +7,7 @@ from django.test import RequestFactory, TestCase, override_settings from django.utils.translation import gettext_lazy as _ from rest_framework import filters, generics, pagination, routers, serializers +from rest_framework.authtoken.views import obtain_auth_token from rest_framework.compat import uritemplate from rest_framework.parsers import JSONParser, MultiPartParser from rest_framework.renderers import JSONRenderer, OpenAPIRenderer @@ -995,16 +996,45 @@ class TestGenerator(TestCase): patterns = [ url(r'^example/?$', views.ExampleGenericAPIViewModel.as_view()), ] + generator = SchemaGenerator(patterns=patterns) request = create_request('/') schema = generator.get_schema(request=request) print(schema) + assert 'components' in schema assert 'schemas' in schema['components'] assert 'ExampleModel' in schema['components']['schemas'] + def test_authtoken_serializer(self): + patterns = [ + url(r'^api-token-auth/', obtain_auth_token) + ] + generator = SchemaGenerator(patterns=patterns) + + request = create_request('/') + schema = generator.get_schema(request=request) + + print(schema) + + route = schema['paths']['/api-token-auth/']['post'] + body_schema = route['requestBody']['content']['application/json']['schema'] + + assert body_schema == { + '$ref': '#/components/schemas/AuthToken' + } + assert schema['components']['schemas']['AuthToken'] == { + 'type': 'object', + 'properties': { + 'username': {'type': 'string', 'writeOnly': True}, + 'password': {'type': 'string', 'writeOnly': True}, + 'token': {'type': 'string', 'readOnly': True}, + }, + 'required': ['username', 'password'] + } + def test_component_name(self): patterns = [ url(r'^example/?$', views.ExampleAutoSchemaComponentName.as_view()),