Added APIClient.authenticate()

This commit is contained in:
Tom Christie 2013-06-29 21:02:58 +01:00
parent 35022ca921
commit 664f8c6365
4 changed files with 95 additions and 8 deletions

View File

@ -576,7 +576,7 @@ class BrowsableAPIRenderer(BaseRenderer):
class MultiPartRenderer(BaseRenderer):
media_type = 'multipart/form-data; boundary=BoUnDaRyStRiNg'
format = 'form'
format = 'multipart'
charset = 'utf-8'
BOUNDARY = 'BoUnDaRyStRiNg'

View File

@ -64,6 +64,20 @@ def clone_request(request, method):
return ret
class ForcedAuthentication(object):
"""
This authentication class is used if the test client or request factory
forcibly authenticated the request.
"""
def __init__(self, force_user, force_token):
self.force_user = force_user
self.force_token = force_token
def authenticate(self, request):
return (self.force_user, self.force_token)
class Request(object):
"""
Wrapper allowing to enhance a standard `HttpRequest` instance.
@ -98,6 +112,12 @@ class Request(object):
self.parser_context['request'] = self
self.parser_context['encoding'] = request.encoding or settings.DEFAULT_CHARSET
force_user = getattr(request, '_force_auth_user', None)
force_token = getattr(request, '_force_auth_token', None)
if (force_user is not None or force_token is not None):
forced_auth = ForcedAuthentication(force_user, force_token)
self.authenticators = (forced_auth,)
def _default_negotiator(self):
return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS()

View File

@ -5,6 +5,7 @@
from __future__ import unicode_literals
from django.conf import settings
from django.test.client import Client as DjangoClient
from django.test.client import ClientHandler
from rest_framework.compat import RequestFactory as DjangoRequestFactory
from rest_framework.compat import force_bytes_or_smart_bytes, six
from rest_framework.renderers import JSONRenderer, MultiPartRenderer
@ -13,9 +14,9 @@ from rest_framework.renderers import JSONRenderer, MultiPartRenderer
class APIRequestFactory(DjangoRequestFactory):
renderer_classes = {
'json': JSONRenderer,
'form': MultiPartRenderer
'multipart': MultiPartRenderer
}
default_format = 'form'
default_format = 'multipart'
def _encode_data(self, data, format=None, content_type=None):
"""
@ -74,14 +75,44 @@ class APIRequestFactory(DjangoRequestFactory):
return self.generic('OPTIONS', path, data, content_type, **extra)
class APIClient(APIRequestFactory, DjangoClient):
class ForceAuthClientHandler(ClientHandler):
"""
A patched version of ClientHandler that can enforce authentication
on the outgoing requests.
"""
def __init__(self, *args, **kwargs):
self._force_auth_user = None
self._force_auth_token = None
super(ForceAuthClientHandler, self).__init__(*args, **kwargs)
def force_authenticate(self, user=None, token=None):
self._force_auth_user = user
self._force_auth_token = token
def get_response(self, request):
# This is the simplest place we can hook into to patch the
# request object.
request._force_auth_user = self._force_auth_user
request._force_auth_token = self._force_auth_token
return super(ForceAuthClientHandler, self).get_response(request)
class APIClient(APIRequestFactory, DjangoClient):
def __init__(self, enforce_csrf_checks=False, **defaults):
# Note that our super call skips Client.__init__
# since we don't need to instantiate a regular ClientHandler
super(DjangoClient, self).__init__(**defaults)
self.handler = ForceAuthClientHandler(enforce_csrf_checks)
self.exc_info = None
self._credentials = {}
super(APIClient, self).__init__(*args, **kwargs)
def credentials(self, **kwargs):
self._credentials = kwargs
def authenticate(self, user=None, token=None):
self.handler.force_authenticate(user, token)
def get(self, path, data={}, follow=False, **extra):
extra.update(self._credentials)
response = super(APIClient, self).get(path, data=data, **extra)

View File

@ -1,6 +1,7 @@
# -- coding: utf-8 --
from __future__ import unicode_literals
from django.contrib.auth.models import User
from django.test import TestCase
from rest_framework.compat import patterns, url
from rest_framework.decorators import api_view
@ -8,10 +9,11 @@ from rest_framework.response import Response
from rest_framework.test import APIClient
@api_view(['GET'])
@api_view(['GET', 'POST'])
def mirror(request):
return Response({
'auth': request.META.get('HTTP_AUTHORIZATION', b'')
'auth': request.META.get('HTTP_AUTHORIZATION', b''),
'user': request.user.username
})
@ -27,6 +29,40 @@ class CheckTestClient(TestCase):
self.client = APIClient()
def test_credentials(self):
"""
Setting `.credentials()` adds the required headers to each request.
"""
self.client.credentials(HTTP_AUTHORIZATION='example')
for _ in range(0, 3):
response = self.client.get('/view/')
self.assertEqual(response.data['auth'], 'example')
def test_authenticate(self):
"""
Setting `.authenticate()` forcibly authenticates each request.
"""
user = User.objects.create_user('example', 'example@example.com')
self.client.authenticate(user)
response = self.client.get('/view/')
self.assertEqual(response.data['user'], 'example')
def test_csrf_exempt_by_default(self):
"""
By default, the test client is CSRF exempt.
"""
User.objects.create_user('example', 'example@example.com', 'password')
self.client.login(username='example', password='password')
response = self.client.post('/view/')
self.assertEqual(response.status_code, 200)
def test_explicitly_enforce_csrf_checks(self):
"""
The test client can enforce CSRF checks.
"""
client = APIClient(enforce_csrf_checks=True)
User.objects.create_user('example', 'example@example.com', 'password')
client.login(username='example', password='password')
response = client.post('/view/')
expected = {'detail': 'CSRF Failed: CSRF cookie not set.'}
self.assertEqual(response.status_code, 403)
self.assertEqual(response.data, expected)