diff --git a/rest_framework/authtoken/views.py b/rest_framework/authtoken/views.py index 6254d2f7f..e21856194 100644 --- a/rest_framework/authtoken/views.py +++ b/rest_framework/authtoken/views.py @@ -3,6 +3,9 @@ from rest_framework.authtoken.models import Token from rest_framework.authtoken.serializers import AuthTokenSerializer from rest_framework.response import Response from rest_framework.views import APIView +from rest_framework.schemas import ManualSchema +import coreapi +import coreschema class ObtainAuthToken(APIView): @@ -11,6 +14,29 @@ class ObtainAuthToken(APIView): parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,) renderer_classes = (renderers.JSONRenderer,) serializer_class = AuthTokenSerializer + schema = ManualSchema( + fields=[ + coreapi.Field( + name="username", + required=True, + location='form', + schema=coreschema.String( + title="Username", + description="Valid username for authentication", + ), + ), + coreapi.Field( + name="password", + required=True, + location='form', + schema=coreschema.String( + title="Password", + description="Valid password for authentication", + ), + ), + ], + encoding="application/json", + ) def post(self, request, *args, **kwargs): serializer = self.serializer_class(data=request.data, @@ -20,5 +46,4 @@ class ObtainAuthToken(APIView): token, created = Token.objects.get_or_create(user=user) return Response({'token': token.key}) - obtain_auth_token = ObtainAuthToken.as_view() diff --git a/rest_framework/schemas/inspectors.py b/rest_framework/schemas/inspectors.py index 86c6cb71a..171b88b0b 100644 --- a/rest_framework/schemas/inspectors.py +++ b/rest_framework/schemas/inspectors.py @@ -445,7 +445,7 @@ class ManualSchema(ViewInspector): Allows providing a list of coreapi.Fields, plus an optional description. """ - def __init__(self, fields, description=''): + def __init__(self, fields, description='', encoding=None): """ Parameters: @@ -455,6 +455,7 @@ class ManualSchema(ViewInspector): assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances" self._fields = fields self._description = description + self._encoding = encoding def get_link(self, path, method, base_url): @@ -464,7 +465,7 @@ class ManualSchema(ViewInspector): return coreapi.Link( url=urlparse.urljoin(base_url, path), action=method.lower(), - encoding=None, + encoding=self._encoding, fields=self._fields, description=self._description )