mirror of
https://github.com/encode/django-rest-framework.git
synced 2025-03-25 04:14:23 +03:00
Added APIClient.authenticate()
This commit is contained in:
parent
35022ca921
commit
664f8c6365
|
@ -576,7 +576,7 @@ class BrowsableAPIRenderer(BaseRenderer):
|
||||||
|
|
||||||
class MultiPartRenderer(BaseRenderer):
|
class MultiPartRenderer(BaseRenderer):
|
||||||
media_type = 'multipart/form-data; boundary=BoUnDaRyStRiNg'
|
media_type = 'multipart/form-data; boundary=BoUnDaRyStRiNg'
|
||||||
format = 'form'
|
format = 'multipart'
|
||||||
charset = 'utf-8'
|
charset = 'utf-8'
|
||||||
BOUNDARY = 'BoUnDaRyStRiNg'
|
BOUNDARY = 'BoUnDaRyStRiNg'
|
||||||
|
|
||||||
|
|
|
@ -64,6 +64,20 @@ def clone_request(request, method):
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
class ForcedAuthentication(object):
|
||||||
|
"""
|
||||||
|
This authentication class is used if the test client or request factory
|
||||||
|
forcibly authenticated the request.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, force_user, force_token):
|
||||||
|
self.force_user = force_user
|
||||||
|
self.force_token = force_token
|
||||||
|
|
||||||
|
def authenticate(self, request):
|
||||||
|
return (self.force_user, self.force_token)
|
||||||
|
|
||||||
|
|
||||||
class Request(object):
|
class Request(object):
|
||||||
"""
|
"""
|
||||||
Wrapper allowing to enhance a standard `HttpRequest` instance.
|
Wrapper allowing to enhance a standard `HttpRequest` instance.
|
||||||
|
@ -98,6 +112,12 @@ class Request(object):
|
||||||
self.parser_context['request'] = self
|
self.parser_context['request'] = self
|
||||||
self.parser_context['encoding'] = request.encoding or settings.DEFAULT_CHARSET
|
self.parser_context['encoding'] = request.encoding or settings.DEFAULT_CHARSET
|
||||||
|
|
||||||
|
force_user = getattr(request, '_force_auth_user', None)
|
||||||
|
force_token = getattr(request, '_force_auth_token', None)
|
||||||
|
if (force_user is not None or force_token is not None):
|
||||||
|
forced_auth = ForcedAuthentication(force_user, force_token)
|
||||||
|
self.authenticators = (forced_auth,)
|
||||||
|
|
||||||
def _default_negotiator(self):
|
def _default_negotiator(self):
|
||||||
return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS()
|
return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS()
|
||||||
|
|
||||||
|
|
|
@ -5,6 +5,7 @@
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.test.client import Client as DjangoClient
|
from django.test.client import Client as DjangoClient
|
||||||
|
from django.test.client import ClientHandler
|
||||||
from rest_framework.compat import RequestFactory as DjangoRequestFactory
|
from rest_framework.compat import RequestFactory as DjangoRequestFactory
|
||||||
from rest_framework.compat import force_bytes_or_smart_bytes, six
|
from rest_framework.compat import force_bytes_or_smart_bytes, six
|
||||||
from rest_framework.renderers import JSONRenderer, MultiPartRenderer
|
from rest_framework.renderers import JSONRenderer, MultiPartRenderer
|
||||||
|
@ -13,9 +14,9 @@ from rest_framework.renderers import JSONRenderer, MultiPartRenderer
|
||||||
class APIRequestFactory(DjangoRequestFactory):
|
class APIRequestFactory(DjangoRequestFactory):
|
||||||
renderer_classes = {
|
renderer_classes = {
|
||||||
'json': JSONRenderer,
|
'json': JSONRenderer,
|
||||||
'form': MultiPartRenderer
|
'multipart': MultiPartRenderer
|
||||||
}
|
}
|
||||||
default_format = 'form'
|
default_format = 'multipart'
|
||||||
|
|
||||||
def _encode_data(self, data, format=None, content_type=None):
|
def _encode_data(self, data, format=None, content_type=None):
|
||||||
"""
|
"""
|
||||||
|
@ -74,14 +75,44 @@ class APIRequestFactory(DjangoRequestFactory):
|
||||||
return self.generic('OPTIONS', path, data, content_type, **extra)
|
return self.generic('OPTIONS', path, data, content_type, **extra)
|
||||||
|
|
||||||
|
|
||||||
class APIClient(APIRequestFactory, DjangoClient):
|
class ForceAuthClientHandler(ClientHandler):
|
||||||
|
"""
|
||||||
|
A patched version of ClientHandler that can enforce authentication
|
||||||
|
on the outgoing requests.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
|
self._force_auth_user = None
|
||||||
|
self._force_auth_token = None
|
||||||
|
super(ForceAuthClientHandler, self).__init__(*args, **kwargs)
|
||||||
|
|
||||||
|
def force_authenticate(self, user=None, token=None):
|
||||||
|
self._force_auth_user = user
|
||||||
|
self._force_auth_token = token
|
||||||
|
|
||||||
|
def get_response(self, request):
|
||||||
|
# This is the simplest place we can hook into to patch the
|
||||||
|
# request object.
|
||||||
|
request._force_auth_user = self._force_auth_user
|
||||||
|
request._force_auth_token = self._force_auth_token
|
||||||
|
return super(ForceAuthClientHandler, self).get_response(request)
|
||||||
|
|
||||||
|
|
||||||
|
class APIClient(APIRequestFactory, DjangoClient):
|
||||||
|
def __init__(self, enforce_csrf_checks=False, **defaults):
|
||||||
|
# Note that our super call skips Client.__init__
|
||||||
|
# since we don't need to instantiate a regular ClientHandler
|
||||||
|
super(DjangoClient, self).__init__(**defaults)
|
||||||
|
self.handler = ForceAuthClientHandler(enforce_csrf_checks)
|
||||||
|
self.exc_info = None
|
||||||
self._credentials = {}
|
self._credentials = {}
|
||||||
super(APIClient, self).__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
def credentials(self, **kwargs):
|
def credentials(self, **kwargs):
|
||||||
self._credentials = kwargs
|
self._credentials = kwargs
|
||||||
|
|
||||||
|
def authenticate(self, user=None, token=None):
|
||||||
|
self.handler.force_authenticate(user, token)
|
||||||
|
|
||||||
def get(self, path, data={}, follow=False, **extra):
|
def get(self, path, data={}, follow=False, **extra):
|
||||||
extra.update(self._credentials)
|
extra.update(self._credentials)
|
||||||
response = super(APIClient, self).get(path, data=data, **extra)
|
response = super(APIClient, self).get(path, data=data, **extra)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
# -- coding: utf-8 --
|
# -- coding: utf-8 --
|
||||||
|
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
from django.contrib.auth.models import User
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from rest_framework.compat import patterns, url
|
from rest_framework.compat import patterns, url
|
||||||
from rest_framework.decorators import api_view
|
from rest_framework.decorators import api_view
|
||||||
|
@ -8,10 +9,11 @@ from rest_framework.response import Response
|
||||||
from rest_framework.test import APIClient
|
from rest_framework.test import APIClient
|
||||||
|
|
||||||
|
|
||||||
@api_view(['GET'])
|
@api_view(['GET', 'POST'])
|
||||||
def mirror(request):
|
def mirror(request):
|
||||||
return Response({
|
return Response({
|
||||||
'auth': request.META.get('HTTP_AUTHORIZATION', b'')
|
'auth': request.META.get('HTTP_AUTHORIZATION', b''),
|
||||||
|
'user': request.user.username
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@ -27,6 +29,40 @@ class CheckTestClient(TestCase):
|
||||||
self.client = APIClient()
|
self.client = APIClient()
|
||||||
|
|
||||||
def test_credentials(self):
|
def test_credentials(self):
|
||||||
|
"""
|
||||||
|
Setting `.credentials()` adds the required headers to each request.
|
||||||
|
"""
|
||||||
self.client.credentials(HTTP_AUTHORIZATION='example')
|
self.client.credentials(HTTP_AUTHORIZATION='example')
|
||||||
|
for _ in range(0, 3):
|
||||||
|
response = self.client.get('/view/')
|
||||||
|
self.assertEqual(response.data['auth'], 'example')
|
||||||
|
|
||||||
|
def test_authenticate(self):
|
||||||
|
"""
|
||||||
|
Setting `.authenticate()` forcibly authenticates each request.
|
||||||
|
"""
|
||||||
|
user = User.objects.create_user('example', 'example@example.com')
|
||||||
|
self.client.authenticate(user)
|
||||||
response = self.client.get('/view/')
|
response = self.client.get('/view/')
|
||||||
self.assertEqual(response.data['auth'], 'example')
|
self.assertEqual(response.data['user'], 'example')
|
||||||
|
|
||||||
|
def test_csrf_exempt_by_default(self):
|
||||||
|
"""
|
||||||
|
By default, the test client is CSRF exempt.
|
||||||
|
"""
|
||||||
|
User.objects.create_user('example', 'example@example.com', 'password')
|
||||||
|
self.client.login(username='example', password='password')
|
||||||
|
response = self.client.post('/view/')
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
|
def test_explicitly_enforce_csrf_checks(self):
|
||||||
|
"""
|
||||||
|
The test client can enforce CSRF checks.
|
||||||
|
"""
|
||||||
|
client = APIClient(enforce_csrf_checks=True)
|
||||||
|
User.objects.create_user('example', 'example@example.com', 'password')
|
||||||
|
client.login(username='example', password='password')
|
||||||
|
response = client.post('/view/')
|
||||||
|
expected = {'detail': 'CSRF Failed: CSRF cookie not set.'}
|
||||||
|
self.assertEqual(response.status_code, 403)
|
||||||
|
self.assertEqual(response.data, expected)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user