Addeded 'APITestClient.credentials()'

This commit is contained in:
Tom Christie 2013-06-29 08:05:08 +01:00
parent f585480ee1
commit 90bc07f3f1
2 changed files with 61 additions and 0 deletions

View File

@ -1,5 +1,8 @@
# -- coding: utf-8 --
# Note that we use `DjangoRequestFactory` and `DjangoClient` names in order
# to make it harder for the user to import the wrong thing without realizing.
from __future__ import unicode_literals
from django.conf import settings
from django.test.client import Client as DjangoClient
from rest_framework.compat import RequestFactory as DjangoRequestFactory
@ -72,31 +75,57 @@ class APIRequestFactory(DjangoRequestFactory):
class APIClient(APIRequestFactory, DjangoClient):
def __init__(self, *args, **kwargs):
self._credentials = {}
super(APIClient, self).__init__(*args, **kwargs)
def credentials(self, **kwargs):
self._credentials = kwargs
def get(self, path, data={}, follow=False, **extra):
extra.update(self._credentials)
response = super(APIClient, self).get(path, data=data, **extra)
if follow:
response = self._handle_redirects(response, **extra)
return response
def head(self, path, data={}, follow=False, **extra):
extra.update(self._credentials)
response = super(APIClient, self).head(path, data=data, **extra)
if follow:
response = self._handle_redirects(response, **extra)
return response
def post(self, path, data=None, format=None, content_type=None, follow=False, **extra):
extra.update(self._credentials)
response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra)
if follow:
response = self._handle_redirects(response, **extra)
return response
def put(self, path, data=None, format=None, content_type=None, follow=False, **extra):
extra.update(self._credentials)
response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra)
if follow:
response = self._handle_redirects(response, **extra)
return response
def patch(self, path, data=None, format=None, content_type=None, follow=False, **extra):
extra.update(self._credentials)
response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra)
if follow:
response = self._handle_redirects(response, **extra)
return response
def delete(self, path, data=None, format=None, content_type=None, follow=False, **extra):
extra.update(self._credentials)
response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra)
if follow:
response = self._handle_redirects(response, **extra)
return response
def options(self, path, data=None, format=None, content_type=None, follow=False, **extra):
extra.update(self._credentials)
response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra)
if follow:
response = self._handle_redirects(response, **extra)

View File

@ -0,0 +1,32 @@
# -- coding: utf-8 --
from __future__ import unicode_literals
from django.test import TestCase
from rest_framework.compat import patterns, url
from rest_framework.decorators import api_view
from rest_framework.response import Response
from rest_framework.test import APIClient
@api_view(['GET'])
def mirror(request):
return Response({
'auth': request.META.get('HTTP_AUTHORIZATION', b'')
})
urlpatterns = patterns('',
url(r'^view/$', mirror),
)
class CheckTestClient(TestCase):
urls = 'rest_framework.tests.test_testing'
def setUp(self):
self.client = APIClient()
def test_credentials(self):
self.client.credentials(HTTP_AUTHORIZATION='example')
response = self.client.get('/view/')
self.assertEqual(response.data['auth'], 'example')