Added APIClient

This commit is contained in:
Tom Christie 2013-06-28 17:50:30 +01:00
parent 7224b20d58
commit f585480ee1
3 changed files with 94 additions and 39 deletions

View File

@ -1,39 +1,54 @@
from rest_framework.compat import six, RequestFactory # Note that we use `DjangoRequestFactory` and `DjangoClient` names in order
# to make it harder for the user to import the wrong thing without realizing.
from django.conf import settings
from django.test.client import Client as DjangoClient
from rest_framework.compat import RequestFactory as DjangoRequestFactory
from rest_framework.compat import force_bytes_or_smart_bytes, six
from rest_framework.renderers import JSONRenderer, MultiPartRenderer from rest_framework.renderers import JSONRenderer, MultiPartRenderer
class APIRequestFactory(RequestFactory): class APIRequestFactory(DjangoRequestFactory):
renderer_classes = { renderer_classes = {
'json': JSONRenderer, 'json': JSONRenderer,
'form': MultiPartRenderer 'form': MultiPartRenderer
} }
default_format = 'form' default_format = 'form'
def __init__(self, format=None, **defaults): def _encode_data(self, data, format=None, content_type=None):
self.format = format or self.default_format """
super(APIRequestFactory, self).__init__(**defaults) Encode the data returning a two tuple of (bytes, content_type)
"""
def _encode_data(self, data, format, content_type):
if not data: if not data:
return ('', None) return ('', None)
format = format or self.format assert format is None or content_type is None, (
'You may not set both `format` and `content_type`.'
)
if content_type is None and data is not None: if content_type:
# Content type specified explicitly, treat data as a raw bytestring
ret = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET)
else:
# Use format and render the data into a bytestring
format = format or self.default_format
renderer = self.renderer_classes[format]() renderer = self.renderer_classes[format]()
data = renderer.render(data) ret = renderer.render(data)
# Determine the content-type header
# Determine the content-type header from the renderer
if ';' in renderer.media_type: if ';' in renderer.media_type:
content_type = renderer.media_type content_type = renderer.media_type
else: else:
content_type = "{0}; charset={1}".format( content_type = "{0}; charset={1}".format(
renderer.media_type, renderer.charset renderer.media_type, renderer.charset
) )
# Coerce text to bytes if required.
if isinstance(data, six.text_type):
data = bytes(data.encode(renderer.charset))
return data, content_type # Coerce text to bytes if required.
if isinstance(ret, six.text_type):
ret = bytes(ret.encode(renderer.charset))
return ret, content_type
def post(self, path, data=None, format=None, content_type=None, **extra): def post(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type) data, content_type = self._encode_data(data, format, content_type)
@ -46,3 +61,43 @@ class APIRequestFactory(RequestFactory):
def patch(self, path, data=None, format=None, content_type=None, **extra): def patch(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type) data, content_type = self._encode_data(data, format, content_type)
return self.generic('PATCH', path, data, content_type, **extra) return self.generic('PATCH', path, data, content_type, **extra)
def delete(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type)
return self.generic('DELETE', path, data, content_type, **extra)
def options(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type)
return self.generic('OPTIONS', path, data, content_type, **extra)
class APIClient(APIRequestFactory, DjangoClient):
def post(self, path, data=None, format=None, content_type=None, follow=False, **extra):
response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra)
if follow:
response = self._handle_redirects(response, **extra)
return response
def put(self, path, data=None, format=None, content_type=None, follow=False, **extra):
response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra)
if follow:
response = self._handle_redirects(response, **extra)
return response
def patch(self, path, data=None, format=None, content_type=None, follow=False, **extra):
response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra)
if follow:
response = self._handle_redirects(response, **extra)
return response
def delete(self, path, data=None, format=None, content_type=None, follow=False, **extra):
response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra)
if follow:
response = self._handle_redirects(response, **extra)
return response
def options(self, path, data=None, format=None, content_type=None, follow=False, **extra):
response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra)
if follow:
response = self._handle_redirects(response, **extra)
return response

View File

@ -1,7 +1,7 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.http import HttpResponse from django.http import HttpResponse
from django.test import Client, TestCase from django.test import TestCase
from django.utils import unittest from django.utils import unittest
from rest_framework import HTTP_HEADER_ENCODING from rest_framework import HTTP_HEADER_ENCODING
from rest_framework import exceptions from rest_framework import exceptions
@ -21,12 +21,11 @@ from rest_framework.authtoken.models import Token
from rest_framework.compat import patterns, url, include from rest_framework.compat import patterns, url, include
from rest_framework.compat import oauth2_provider, oauth2_provider_models, oauth2_provider_scope from rest_framework.compat import oauth2_provider, oauth2_provider_models, oauth2_provider_scope
from rest_framework.compat import oauth, oauth_provider from rest_framework.compat import oauth, oauth_provider
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory, APIClient
from rest_framework.views import APIView from rest_framework.views import APIView
import base64 import base64
import time import time
import datetime import datetime
import json
factory = APIRequestFactory() factory = APIRequestFactory()
@ -68,7 +67,7 @@ class BasicAuthTests(TestCase):
urls = 'rest_framework.tests.test_authentication' urls = 'rest_framework.tests.test_authentication'
def setUp(self): def setUp(self):
self.csrf_client = Client(enforce_csrf_checks=True) self.csrf_client = APIClient(enforce_csrf_checks=True)
self.username = 'john' self.username = 'john'
self.email = 'lennon@thebeatles.com' self.email = 'lennon@thebeatles.com'
self.password = 'password' self.password = 'password'
@ -87,7 +86,7 @@ class BasicAuthTests(TestCase):
credentials = ('%s:%s' % (self.username, self.password)) credentials = ('%s:%s' % (self.username, self.password))
base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING) base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING)
auth = 'Basic %s' % base64_credentials auth = 'Basic %s' % base64_credentials
response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_post_form_failing_basic_auth(self): def test_post_form_failing_basic_auth(self):
@ -97,7 +96,7 @@ class BasicAuthTests(TestCase):
def test_post_json_failing_basic_auth(self): def test_post_json_failing_basic_auth(self):
"""Ensure POSTing json over basic auth without correct credentials fails""" """Ensure POSTing json over basic auth without correct credentials fails"""
response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json') response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json')
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"') self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"')
@ -107,8 +106,8 @@ class SessionAuthTests(TestCase):
urls = 'rest_framework.tests.test_authentication' urls = 'rest_framework.tests.test_authentication'
def setUp(self): def setUp(self):
self.csrf_client = Client(enforce_csrf_checks=True) self.csrf_client = APIClient(enforce_csrf_checks=True)
self.non_csrf_client = Client(enforce_csrf_checks=False) self.non_csrf_client = APIClient(enforce_csrf_checks=False)
self.username = 'john' self.username = 'john'
self.email = 'lennon@thebeatles.com' self.email = 'lennon@thebeatles.com'
self.password = 'password' self.password = 'password'
@ -154,7 +153,7 @@ class TokenAuthTests(TestCase):
urls = 'rest_framework.tests.test_authentication' urls = 'rest_framework.tests.test_authentication'
def setUp(self): def setUp(self):
self.csrf_client = Client(enforce_csrf_checks=True) self.csrf_client = APIClient(enforce_csrf_checks=True)
self.username = 'john' self.username = 'john'
self.email = 'lennon@thebeatles.com' self.email = 'lennon@thebeatles.com'
self.password = 'password' self.password = 'password'
@ -172,7 +171,7 @@ class TokenAuthTests(TestCase):
def test_post_json_passing_token_auth(self): def test_post_json_passing_token_auth(self):
"""Ensure POSTing form over token auth with correct credentials passes and does not require CSRF""" """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF"""
auth = "Token " + self.key auth = "Token " + self.key
response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) response = self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
def test_post_form_failing_token_auth(self): def test_post_form_failing_token_auth(self):
@ -182,7 +181,7 @@ class TokenAuthTests(TestCase):
def test_post_json_failing_token_auth(self): def test_post_json_failing_token_auth(self):
"""Ensure POSTing json over token auth without correct credentials fails""" """Ensure POSTing json over token auth without correct credentials fails"""
response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json') response = self.csrf_client.post('/token/', {'example': 'example'}, format='json')
self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED)
def test_token_has_auto_assigned_key_if_none_provided(self): def test_token_has_auto_assigned_key_if_none_provided(self):
@ -193,33 +192,33 @@ class TokenAuthTests(TestCase):
def test_token_login_json(self): def test_token_login_json(self):
"""Ensure token login view using JSON POST works.""" """Ensure token login view using JSON POST works."""
client = Client(enforce_csrf_checks=True) client = APIClient(enforce_csrf_checks=True)
response = client.post('/auth-token/', response = client.post('/auth-token/',
json.dumps({'username': self.username, 'password': self.password}), 'application/json') {'username': self.username, 'password': self.password}, format='json')
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(json.loads(response.content.decode('ascii'))['token'], self.key) self.assertEqual(response.data['token'], self.key)
def test_token_login_json_bad_creds(self): def test_token_login_json_bad_creds(self):
"""Ensure token login view using JSON POST fails if bad credentials are used.""" """Ensure token login view using JSON POST fails if bad credentials are used."""
client = Client(enforce_csrf_checks=True) client = APIClient(enforce_csrf_checks=True)
response = client.post('/auth-token/', response = client.post('/auth-token/',
json.dumps({'username': self.username, 'password': "badpass"}), 'application/json') {'username': self.username, 'password': "badpass"}, format='json')
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
def test_token_login_json_missing_fields(self): def test_token_login_json_missing_fields(self):
"""Ensure token login view using JSON POST fails if missing fields.""" """Ensure token login view using JSON POST fails if missing fields."""
client = Client(enforce_csrf_checks=True) client = APIClient(enforce_csrf_checks=True)
response = client.post('/auth-token/', response = client.post('/auth-token/',
json.dumps({'username': self.username}), 'application/json') {'username': self.username}, format='json')
self.assertEqual(response.status_code, 400) self.assertEqual(response.status_code, 400)
def test_token_login_form(self): def test_token_login_form(self):
"""Ensure token login view using form POST works.""" """Ensure token login view using form POST works."""
client = Client(enforce_csrf_checks=True) client = APIClient(enforce_csrf_checks=True)
response = client.post('/auth-token/', response = client.post('/auth-token/',
{'username': self.username, 'password': self.password}) {'username': self.username, 'password': self.password})
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(json.loads(response.content.decode('ascii'))['token'], self.key) self.assertEqual(response.data['token'], self.key)
class IncorrectCredentialsTests(TestCase): class IncorrectCredentialsTests(TestCase):
@ -256,7 +255,7 @@ class OAuthTests(TestCase):
self.consts = consts self.consts = consts
self.csrf_client = Client(enforce_csrf_checks=True) self.csrf_client = APIClient(enforce_csrf_checks=True)
self.username = 'john' self.username = 'john'
self.email = 'lennon@thebeatles.com' self.email = 'lennon@thebeatles.com'
self.password = 'password' self.password = 'password'
@ -470,12 +469,13 @@ class OAuthTests(TestCase):
response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth) response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401) self.assertEqual(response.status_code, 401)
class OAuth2Tests(TestCase): class OAuth2Tests(TestCase):
"""OAuth 2.0 authentication""" """OAuth 2.0 authentication"""
urls = 'rest_framework.tests.test_authentication' urls = 'rest_framework.tests.test_authentication'
def setUp(self): def setUp(self):
self.csrf_client = Client(enforce_csrf_checks=True) self.csrf_client = APIClient(enforce_csrf_checks=True)
self.username = 'john' self.username = 'john'
self.email = 'lennon@thebeatles.com' self.email = 'lennon@thebeatles.com'
self.password = 'password' self.password = 'password'

View File

@ -5,7 +5,7 @@ from __future__ import unicode_literals
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.contrib.auth import authenticate, login, logout from django.contrib.auth import authenticate, login, logout
from django.contrib.sessions.middleware import SessionMiddleware from django.contrib.sessions.middleware import SessionMiddleware
from django.test import TestCase, Client from django.test import TestCase
from rest_framework import status from rest_framework import status
from rest_framework.authentication import SessionAuthentication from rest_framework.authentication import SessionAuthentication
from rest_framework.compat import patterns from rest_framework.compat import patterns
@ -18,7 +18,7 @@ from rest_framework.parsers import (
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory, APIClient
from rest_framework.views import APIView from rest_framework.views import APIView
from rest_framework.compat import six from rest_framework.compat import six
import json import json
@ -248,7 +248,7 @@ class TestContentParsingWithAuthentication(TestCase):
urls = 'rest_framework.tests.test_request' urls = 'rest_framework.tests.test_request'
def setUp(self): def setUp(self):
self.csrf_client = Client(enforce_csrf_checks=True) self.csrf_client = APIClient(enforce_csrf_checks=True)
self.username = 'john' self.username = 'john'
self.email = 'lennon@thebeatles.com' self.email = 'lennon@thebeatles.com'
self.password = 'password' self.password = 'password'