mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-10 19:56:59 +03:00
Added APIClient.authenticate()
This commit is contained in:
parent
35022ca921
commit
664f8c6365
|
@ -576,7 +576,7 @@ class BrowsableAPIRenderer(BaseRenderer):
|
|||
|
||||
class MultiPartRenderer(BaseRenderer):
|
||||
media_type = 'multipart/form-data; boundary=BoUnDaRyStRiNg'
|
||||
format = 'form'
|
||||
format = 'multipart'
|
||||
charset = 'utf-8'
|
||||
BOUNDARY = 'BoUnDaRyStRiNg'
|
||||
|
||||
|
|
|
@ -64,6 +64,20 @@ def clone_request(request, method):
|
|||
return ret
|
||||
|
||||
|
||||
class ForcedAuthentication(object):
|
||||
"""
|
||||
This authentication class is used if the test client or request factory
|
||||
forcibly authenticated the request.
|
||||
"""
|
||||
|
||||
def __init__(self, force_user, force_token):
|
||||
self.force_user = force_user
|
||||
self.force_token = force_token
|
||||
|
||||
def authenticate(self, request):
|
||||
return (self.force_user, self.force_token)
|
||||
|
||||
|
||||
class Request(object):
|
||||
"""
|
||||
Wrapper allowing to enhance a standard `HttpRequest` instance.
|
||||
|
@ -98,6 +112,12 @@ class Request(object):
|
|||
self.parser_context['request'] = self
|
||||
self.parser_context['encoding'] = request.encoding or settings.DEFAULT_CHARSET
|
||||
|
||||
force_user = getattr(request, '_force_auth_user', None)
|
||||
force_token = getattr(request, '_force_auth_token', None)
|
||||
if (force_user is not None or force_token is not None):
|
||||
forced_auth = ForcedAuthentication(force_user, force_token)
|
||||
self.authenticators = (forced_auth,)
|
||||
|
||||
def _default_negotiator(self):
|
||||
return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS()
|
||||
|
||||
|
|
|
@ -5,6 +5,7 @@
|
|||
from __future__ import unicode_literals
|
||||
from django.conf import settings
|
||||
from django.test.client import Client as DjangoClient
|
||||
from django.test.client import ClientHandler
|
||||
from rest_framework.compat import RequestFactory as DjangoRequestFactory
|
||||
from rest_framework.compat import force_bytes_or_smart_bytes, six
|
||||
from rest_framework.renderers import JSONRenderer, MultiPartRenderer
|
||||
|
@ -13,9 +14,9 @@ from rest_framework.renderers import JSONRenderer, MultiPartRenderer
|
|||
class APIRequestFactory(DjangoRequestFactory):
|
||||
renderer_classes = {
|
||||
'json': JSONRenderer,
|
||||
'form': MultiPartRenderer
|
||||
'multipart': MultiPartRenderer
|
||||
}
|
||||
default_format = 'form'
|
||||
default_format = 'multipart'
|
||||
|
||||
def _encode_data(self, data, format=None, content_type=None):
|
||||
"""
|
||||
|
@ -74,14 +75,44 @@ class APIRequestFactory(DjangoRequestFactory):
|
|||
return self.generic('OPTIONS', path, data, content_type, **extra)
|
||||
|
||||
|
||||
class APIClient(APIRequestFactory, DjangoClient):
|
||||
class ForceAuthClientHandler(ClientHandler):
|
||||
"""
|
||||
A patched version of ClientHandler that can enforce authentication
|
||||
on the outgoing requests.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self._force_auth_user = None
|
||||
self._force_auth_token = None
|
||||
super(ForceAuthClientHandler, self).__init__(*args, **kwargs)
|
||||
|
||||
def force_authenticate(self, user=None, token=None):
|
||||
self._force_auth_user = user
|
||||
self._force_auth_token = token
|
||||
|
||||
def get_response(self, request):
|
||||
# This is the simplest place we can hook into to patch the
|
||||
# request object.
|
||||
request._force_auth_user = self._force_auth_user
|
||||
request._force_auth_token = self._force_auth_token
|
||||
return super(ForceAuthClientHandler, self).get_response(request)
|
||||
|
||||
|
||||
class APIClient(APIRequestFactory, DjangoClient):
|
||||
def __init__(self, enforce_csrf_checks=False, **defaults):
|
||||
# Note that our super call skips Client.__init__
|
||||
# since we don't need to instantiate a regular ClientHandler
|
||||
super(DjangoClient, self).__init__(**defaults)
|
||||
self.handler = ForceAuthClientHandler(enforce_csrf_checks)
|
||||
self.exc_info = None
|
||||
self._credentials = {}
|
||||
super(APIClient, self).__init__(*args, **kwargs)
|
||||
|
||||
def credentials(self, **kwargs):
|
||||
self._credentials = kwargs
|
||||
|
||||
def authenticate(self, user=None, token=None):
|
||||
self.handler.force_authenticate(user, token)
|
||||
|
||||
def get(self, path, data={}, follow=False, **extra):
|
||||
extra.update(self._credentials)
|
||||
response = super(APIClient, self).get(path, data=data, **extra)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
# -- coding: utf-8 --
|
||||
|
||||
from __future__ import unicode_literals
|
||||
from django.contrib.auth.models import User
|
||||
from django.test import TestCase
|
||||
from rest_framework.compat import patterns, url
|
||||
from rest_framework.decorators import api_view
|
||||
|
@ -8,10 +9,11 @@ from rest_framework.response import Response
|
|||
from rest_framework.test import APIClient
|
||||
|
||||
|
||||
@api_view(['GET'])
|
||||
@api_view(['GET', 'POST'])
|
||||
def mirror(request):
|
||||
return Response({
|
||||
'auth': request.META.get('HTTP_AUTHORIZATION', b'')
|
||||
'auth': request.META.get('HTTP_AUTHORIZATION', b''),
|
||||
'user': request.user.username
|
||||
})
|
||||
|
||||
|
||||
|
@ -27,6 +29,40 @@ class CheckTestClient(TestCase):
|
|||
self.client = APIClient()
|
||||
|
||||
def test_credentials(self):
|
||||
"""
|
||||
Setting `.credentials()` adds the required headers to each request.
|
||||
"""
|
||||
self.client.credentials(HTTP_AUTHORIZATION='example')
|
||||
for _ in range(0, 3):
|
||||
response = self.client.get('/view/')
|
||||
self.assertEqual(response.data['auth'], 'example')
|
||||
|
||||
def test_authenticate(self):
|
||||
"""
|
||||
Setting `.authenticate()` forcibly authenticates each request.
|
||||
"""
|
||||
user = User.objects.create_user('example', 'example@example.com')
|
||||
self.client.authenticate(user)
|
||||
response = self.client.get('/view/')
|
||||
self.assertEqual(response.data['user'], 'example')
|
||||
|
||||
def test_csrf_exempt_by_default(self):
|
||||
"""
|
||||
By default, the test client is CSRF exempt.
|
||||
"""
|
||||
User.objects.create_user('example', 'example@example.com', 'password')
|
||||
self.client.login(username='example', password='password')
|
||||
response = self.client.post('/view/')
|
||||
self.assertEqual(response.status_code, 200)
|
||||
|
||||
def test_explicitly_enforce_csrf_checks(self):
|
||||
"""
|
||||
The test client can enforce CSRF checks.
|
||||
"""
|
||||
client = APIClient(enforce_csrf_checks=True)
|
||||
User.objects.create_user('example', 'example@example.com', 'password')
|
||||
client.login(username='example', password='password')
|
||||
response = client.post('/view/')
|
||||
expected = {'detail': 'CSRF Failed: CSRF cookie not set.'}
|
||||
self.assertEqual(response.status_code, 403)
|
||||
self.assertEqual(response.data, expected)
|
||||
|
|
Loading…
Reference in New Issue
Block a user