mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-25 11:04:02 +03:00
Fix schema generation for ObtainAuthToken view. (#7211)
This commit is contained in:
parent
8aa8be7653
commit
609f708a27
|
@ -5,11 +5,19 @@ from rest_framework import serializers
|
||||||
|
|
||||||
|
|
||||||
class AuthTokenSerializer(serializers.Serializer):
|
class AuthTokenSerializer(serializers.Serializer):
|
||||||
username = serializers.CharField(label=_("Username"))
|
username = serializers.CharField(
|
||||||
|
label=_("Username"),
|
||||||
|
write_only=True
|
||||||
|
)
|
||||||
password = serializers.CharField(
|
password = serializers.CharField(
|
||||||
label=_("Password"),
|
label=_("Password"),
|
||||||
style={'input_type': 'password'},
|
style={'input_type': 'password'},
|
||||||
trim_whitespace=False
|
trim_whitespace=False,
|
||||||
|
write_only=True
|
||||||
|
)
|
||||||
|
token = serializers.CharField(
|
||||||
|
label=_("Token"),
|
||||||
|
read_only=True
|
||||||
)
|
)
|
||||||
|
|
||||||
def validate(self, attrs):
|
def validate(self, attrs):
|
||||||
|
|
|
@ -4,6 +4,7 @@ from rest_framework.authtoken.serializers import AuthTokenSerializer
|
||||||
from rest_framework.compat import coreapi, coreschema
|
from rest_framework.compat import coreapi, coreschema
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.schemas import ManualSchema
|
from rest_framework.schemas import ManualSchema
|
||||||
|
from rest_framework.schemas import coreapi as coreapi_schema
|
||||||
from rest_framework.views import APIView
|
from rest_framework.views import APIView
|
||||||
|
|
||||||
|
|
||||||
|
@ -13,7 +14,8 @@ class ObtainAuthToken(APIView):
|
||||||
parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,)
|
parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,)
|
||||||
renderer_classes = (renderers.JSONRenderer,)
|
renderer_classes = (renderers.JSONRenderer,)
|
||||||
serializer_class = AuthTokenSerializer
|
serializer_class = AuthTokenSerializer
|
||||||
if coreapi is not None and coreschema is not None:
|
|
||||||
|
if coreapi_schema.is_enabled():
|
||||||
schema = ManualSchema(
|
schema = ManualSchema(
|
||||||
fields=[
|
fields=[
|
||||||
coreapi.Field(
|
coreapi.Field(
|
||||||
|
@ -38,9 +40,19 @@ class ObtainAuthToken(APIView):
|
||||||
encoding="application/json",
|
encoding="application/json",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def get_serializer_context(self):
|
||||||
|
return {
|
||||||
|
'request': self.request,
|
||||||
|
'format': self.format_kwarg,
|
||||||
|
'view': self
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_serializer(self, *args, **kwargs):
|
||||||
|
kwargs['context'] = self.get_serializer_context()
|
||||||
|
return self.serializer_class(*args, **kwargs)
|
||||||
|
|
||||||
def post(self, request, *args, **kwargs):
|
def post(self, request, *args, **kwargs):
|
||||||
serializer = self.serializer_class(data=request.data,
|
serializer = self.get_serializer(data=request.data)
|
||||||
context={'request': request})
|
|
||||||
serializer.is_valid(raise_exception=True)
|
serializer.is_valid(raise_exception=True)
|
||||||
user = serializer.validated_data['user']
|
user = serializer.validated_data['user']
|
||||||
token, created = Token.objects.get_or_create(user=user)
|
token, created = Token.objects.get_or_create(user=user)
|
||||||
|
|
|
@ -7,6 +7,7 @@ from django.test import RequestFactory, TestCase, override_settings
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
|
|
||||||
from rest_framework import filters, generics, pagination, routers, serializers
|
from rest_framework import filters, generics, pagination, routers, serializers
|
||||||
|
from rest_framework.authtoken.views import obtain_auth_token
|
||||||
from rest_framework.compat import uritemplate
|
from rest_framework.compat import uritemplate
|
||||||
from rest_framework.parsers import JSONParser, MultiPartParser
|
from rest_framework.parsers import JSONParser, MultiPartParser
|
||||||
from rest_framework.renderers import JSONRenderer, OpenAPIRenderer
|
from rest_framework.renderers import JSONRenderer, OpenAPIRenderer
|
||||||
|
@ -995,16 +996,45 @@ class TestGenerator(TestCase):
|
||||||
patterns = [
|
patterns = [
|
||||||
url(r'^example/?$', views.ExampleGenericAPIViewModel.as_view()),
|
url(r'^example/?$', views.ExampleGenericAPIViewModel.as_view()),
|
||||||
]
|
]
|
||||||
|
|
||||||
generator = SchemaGenerator(patterns=patterns)
|
generator = SchemaGenerator(patterns=patterns)
|
||||||
|
|
||||||
request = create_request('/')
|
request = create_request('/')
|
||||||
schema = generator.get_schema(request=request)
|
schema = generator.get_schema(request=request)
|
||||||
|
|
||||||
print(schema)
|
print(schema)
|
||||||
|
|
||||||
assert 'components' in schema
|
assert 'components' in schema
|
||||||
assert 'schemas' in schema['components']
|
assert 'schemas' in schema['components']
|
||||||
assert 'ExampleModel' in schema['components']['schemas']
|
assert 'ExampleModel' in schema['components']['schemas']
|
||||||
|
|
||||||
|
def test_authtoken_serializer(self):
|
||||||
|
patterns = [
|
||||||
|
url(r'^api-token-auth/', obtain_auth_token)
|
||||||
|
]
|
||||||
|
generator = SchemaGenerator(patterns=patterns)
|
||||||
|
|
||||||
|
request = create_request('/')
|
||||||
|
schema = generator.get_schema(request=request)
|
||||||
|
|
||||||
|
print(schema)
|
||||||
|
|
||||||
|
route = schema['paths']['/api-token-auth/']['post']
|
||||||
|
body_schema = route['requestBody']['content']['application/json']['schema']
|
||||||
|
|
||||||
|
assert body_schema == {
|
||||||
|
'$ref': '#/components/schemas/AuthToken'
|
||||||
|
}
|
||||||
|
assert schema['components']['schemas']['AuthToken'] == {
|
||||||
|
'type': 'object',
|
||||||
|
'properties': {
|
||||||
|
'username': {'type': 'string', 'writeOnly': True},
|
||||||
|
'password': {'type': 'string', 'writeOnly': True},
|
||||||
|
'token': {'type': 'string', 'readOnly': True},
|
||||||
|
},
|
||||||
|
'required': ['username', 'password']
|
||||||
|
}
|
||||||
|
|
||||||
def test_component_name(self):
|
def test_component_name(self):
|
||||||
patterns = [
|
patterns = [
|
||||||
url(r'^example/?$', views.ExampleAutoSchemaComponentName.as_view()),
|
url(r'^example/?$', views.ExampleAutoSchemaComponentName.as_view()),
|
||||||
|
|
Loading…
Reference in New Issue
Block a user