From 7224b20d58ceee22abc987980ab646ab8cb2d8dc Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 28 Jun 2013 17:17:39 +0100 Subject: [PATCH 01/12] Added APIRequestFactory --- rest_framework/compat.py | 38 +++++++++++- rest_framework/renderers.py | 11 ++++ rest_framework/response.py | 2 +- rest_framework/test.py | 48 +++++++++++++++ rest_framework/tests/test_authentication.py | 6 +- rest_framework/tests/test_decorators.py | 11 ++-- rest_framework/tests/test_filters.py | 4 +- rest_framework/tests/test_generics.py | 60 ++++++++----------- .../tests/test_hyperlinkedserializers.py | 11 ++-- rest_framework/tests/test_negotiation.py | 4 +- rest_framework/tests/test_pagination.py | 6 +- rest_framework/tests/test_permissions.py | 35 +++++------ .../tests/test_relations_hyperlink.py | 4 +- rest_framework/tests/test_renderers.py | 8 +-- rest_framework/tests/test_request.py | 15 +---- rest_framework/tests/test_reverse.py | 4 +- rest_framework/tests/test_routers.py | 5 +- rest_framework/tests/test_throttling.py | 6 +- rest_framework/tests/test_urlpatterns.py | 4 +- rest_framework/tests/test_validation.py | 8 +-- rest_framework/tests/test_views.py | 6 +- 21 files changed, 180 insertions(+), 116 deletions(-) create mode 100644 rest_framework/test.py diff --git a/rest_framework/compat.py b/rest_framework/compat.py index cb1228465..6f7447add 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -8,6 +8,7 @@ from __future__ import unicode_literals import django from django.core.exceptions import ImproperlyConfigured +from django.conf import settings # Try to import six from Django, fallback to included `six`. try: @@ -83,7 +84,6 @@ def get_concrete_model(model_cls): # Django 1.5 add support for custom auth user model if django.VERSION >= (1, 5): - from django.conf import settings AUTH_USER_MODEL = settings.AUTH_USER_MODEL else: AUTH_USER_MODEL = 'auth.User' @@ -436,6 +436,42 @@ except ImportError: return force_text(url) +# RequestFactory only provide `generic` from 1.5 onwards + +from django.test.client import RequestFactory as DjangoRequestFactory +from django.test.client import FakePayload +try: + # In 1.5 the test client uses force_bytes + from django.utils.encoding import force_bytes_or_smart_bytes +except ImportError: + # In 1.3 and 1.4 the test client just uses smart_str + from django.utils.encoding import smart_str as force_bytes_or_smart_bytes + + +class RequestFactory(DjangoRequestFactory): + def generic(self, method, path, + data='', content_type='application/octet-stream', **extra): + parsed = urlparse.urlparse(path) + data = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET) + r = { + 'PATH_INFO': self._get_path(parsed), + 'QUERY_STRING': force_text(parsed[4]), + 'REQUEST_METHOD': str(method), + } + if data: + r.update({ + 'CONTENT_LENGTH': len(data), + 'CONTENT_TYPE': str(content_type), + 'wsgi.input': FakePayload(data), + }) + elif django.VERSION <= (1, 4): + # For 1.3 we need an empty WSGI payload + r.update({ + 'wsgi.input': FakePayload('') + }) + r.update(extra) + return self.request(**r) + # Markdown is optional try: import markdown diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 8b2428ad8..d7a7ef297 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -14,6 +14,7 @@ from django import forms from django.core.exceptions import ImproperlyConfigured from django.http.multipartparser import parse_header from django.template import RequestContext, loader, Template +from django.test.client import encode_multipart from django.utils.xmlutils import SimplerXMLGenerator from rest_framework.compat import StringIO from rest_framework.compat import six @@ -571,3 +572,13 @@ class BrowsableAPIRenderer(BaseRenderer): response.status_code = status.HTTP_200_OK return ret + + +class MultiPartRenderer(BaseRenderer): + media_type = 'multipart/form-data; boundary=BoUnDaRyStRiNg' + format = 'form' + charset = 'utf-8' + BOUNDARY = 'BoUnDaRyStRiNg' + + def render(self, data, accepted_media_type=None, renderer_context=None): + return encode_multipart(self.BOUNDARY, data) diff --git a/rest_framework/response.py b/rest_framework/response.py index 5877c8a3e..c4b2aaa66 100644 --- a/rest_framework/response.py +++ b/rest_framework/response.py @@ -50,7 +50,7 @@ class Response(SimpleTemplateResponse): charset = renderer.charset content_type = self.content_type - if content_type is None and charset is not None: + if content_type is None and charset is not None and ';' not in media_type: content_type = "{0}; charset={1}".format(media_type, charset) elif content_type is None: content_type = media_type diff --git a/rest_framework/test.py b/rest_framework/test.py new file mode 100644 index 000000000..92281cafc --- /dev/null +++ b/rest_framework/test.py @@ -0,0 +1,48 @@ +from rest_framework.compat import six, RequestFactory +from rest_framework.renderers import JSONRenderer, MultiPartRenderer + + +class APIRequestFactory(RequestFactory): + renderer_classes = { + 'json': JSONRenderer, + 'form': MultiPartRenderer + } + default_format = 'form' + + def __init__(self, format=None, **defaults): + self.format = format or self.default_format + super(APIRequestFactory, self).__init__(**defaults) + + def _encode_data(self, data, format, content_type): + if not data: + return ('', None) + + format = format or self.format + + if content_type is None and data is not None: + renderer = self.renderer_classes[format]() + data = renderer.render(data) + # Determine the content-type header + if ';' in renderer.media_type: + content_type = renderer.media_type + else: + content_type = "{0}; charset={1}".format( + renderer.media_type, renderer.charset + ) + # Coerce text to bytes if required. + if isinstance(data, six.text_type): + data = bytes(data.encode(renderer.charset)) + + return data, content_type + + def post(self, path, data=None, format=None, content_type=None, **extra): + data, content_type = self._encode_data(data, format, content_type) + return self.generic('POST', path, data, content_type, **extra) + + def put(self, path, data=None, format=None, content_type=None, **extra): + data, content_type = self._encode_data(data, format, content_type) + return self.generic('PUT', path, data, content_type, **extra) + + def patch(self, path, data=None, format=None, content_type=None, **extra): + data, content_type = self._encode_data(data, format, content_type) + return self.generic('PATCH', path, data, content_type, **extra) diff --git a/rest_framework/tests/test_authentication.py b/rest_framework/tests/test_authentication.py index 6a50be064..f2c51c68f 100644 --- a/rest_framework/tests/test_authentication.py +++ b/rest_framework/tests/test_authentication.py @@ -21,14 +21,14 @@ from rest_framework.authtoken.models import Token from rest_framework.compat import patterns, url, include from rest_framework.compat import oauth2_provider, oauth2_provider_models, oauth2_provider_scope from rest_framework.compat import oauth, oauth_provider -from rest_framework.tests.utils import RequestFactory +from rest_framework.test import APIRequestFactory from rest_framework.views import APIView -import json import base64 import time import datetime +import json -factory = RequestFactory() +factory = APIRequestFactory() class MockView(APIView): diff --git a/rest_framework/tests/test_decorators.py b/rest_framework/tests/test_decorators.py index 1016fed3f..195f0ba3e 100644 --- a/rest_framework/tests/test_decorators.py +++ b/rest_framework/tests/test_decorators.py @@ -1,12 +1,13 @@ from __future__ import unicode_literals from django.test import TestCase from rest_framework import status +from rest_framework.authentication import BasicAuthentication +from rest_framework.parsers import JSONParser +from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response from rest_framework.renderers import JSONRenderer -from rest_framework.parsers import JSONParser -from rest_framework.authentication import BasicAuthentication +from rest_framework.test import APIRequestFactory from rest_framework.throttling import UserRateThrottle -from rest_framework.permissions import IsAuthenticated from rest_framework.views import APIView from rest_framework.decorators import ( api_view, @@ -17,13 +18,11 @@ from rest_framework.decorators import ( permission_classes, ) -from rest_framework.tests.utils import RequestFactory - class DecoratorTestCase(TestCase): def setUp(self): - self.factory = RequestFactory() + self.factory = APIRequestFactory() def _finalize_response(self, request, response, *args, **kwargs): response.request = request diff --git a/rest_framework/tests/test_filters.py b/rest_framework/tests/test_filters.py index aaed62478..c9d9e7ffa 100644 --- a/rest_framework/tests/test_filters.py +++ b/rest_framework/tests/test_filters.py @@ -4,13 +4,13 @@ from decimal import Decimal from django.db import models from django.core.urlresolvers import reverse from django.test import TestCase -from django.test.client import RequestFactory from django.utils import unittest from rest_framework import generics, serializers, status, filters from rest_framework.compat import django_filters, patterns, url +from rest_framework.test import APIRequestFactory from rest_framework.tests.models import BasicModel -factory = RequestFactory() +factory = APIRequestFactory() class FilterableItem(models.Model): diff --git a/rest_framework/tests/test_generics.py b/rest_framework/tests/test_generics.py index 37734195a..1550880b5 100644 --- a/rest_framework/tests/test_generics.py +++ b/rest_framework/tests/test_generics.py @@ -3,12 +3,11 @@ from django.db import models from django.shortcuts import get_object_or_404 from django.test import TestCase from rest_framework import generics, renderers, serializers, status -from rest_framework.tests.utils import RequestFactory +from rest_framework.test import APIRequestFactory from rest_framework.tests.models import BasicModel, Comment, SlugBasedModel from rest_framework.compat import six -import json -factory = RequestFactory() +factory = APIRequestFactory() class RootView(generics.ListCreateAPIView): @@ -71,9 +70,8 @@ class TestRootView(TestCase): """ POST requests to ListCreateAPIView should create a new object. """ - content = {'text': 'foobar'} - request = factory.post('/', json.dumps(content), - content_type='application/json') + data = {'text': 'foobar'} + request = factory.post('/', data, format='json') with self.assertNumQueries(1): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -85,9 +83,8 @@ class TestRootView(TestCase): """ PUT requests to ListCreateAPIView should not be allowed """ - content = {'text': 'foobar'} - request = factory.put('/', json.dumps(content), - content_type='application/json') + data = {'text': 'foobar'} + request = factory.put('/', data, format='json') with self.assertNumQueries(0): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) @@ -148,9 +145,8 @@ class TestRootView(TestCase): """ POST requests to create a new object should not be able to set the id. """ - content = {'id': 999, 'text': 'foobar'} - request = factory.post('/', json.dumps(content), - content_type='application/json') + data = {'id': 999, 'text': 'foobar'} + request = factory.post('/', data, format='json') with self.assertNumQueries(1): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -189,9 +185,8 @@ class TestInstanceView(TestCase): """ POST requests to RetrieveUpdateDestroyAPIView should not be allowed """ - content = {'text': 'foobar'} - request = factory.post('/', json.dumps(content), - content_type='application/json') + data = {'text': 'foobar'} + request = factory.post('/', data, format='json') with self.assertNumQueries(0): response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_405_METHOD_NOT_ALLOWED) @@ -201,9 +196,8 @@ class TestInstanceView(TestCase): """ PUT requests to RetrieveUpdateDestroyAPIView should update an object. """ - content = {'text': 'foobar'} - request = factory.put('/1', json.dumps(content), - content_type='application/json') + data = {'text': 'foobar'} + request = factory.put('/1', data, format='json') with self.assertNumQueries(2): response = self.view(request, pk='1').render() self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -215,9 +209,8 @@ class TestInstanceView(TestCase): """ PATCH requests to RetrieveUpdateDestroyAPIView should update an object. """ - content = {'text': 'foobar'} - request = factory.patch('/1', json.dumps(content), - content_type='application/json') + data = {'text': 'foobar'} + request = factory.patch('/1', data, format='json') with self.assertNumQueries(2): response = self.view(request, pk=1).render() @@ -293,9 +286,8 @@ class TestInstanceView(TestCase): """ PUT requests to create a new object should not be able to set the id. """ - content = {'id': 999, 'text': 'foobar'} - request = factory.put('/1', json.dumps(content), - content_type='application/json') + data = {'id': 999, 'text': 'foobar'} + request = factory.put('/1', data, format='json') with self.assertNumQueries(2): response = self.view(request, pk=1).render() self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -309,9 +301,8 @@ class TestInstanceView(TestCase): if it does not currently exist. """ self.objects.get(id=1).delete() - content = {'text': 'foobar'} - request = factory.put('/1', json.dumps(content), - content_type='application/json') + data = {'text': 'foobar'} + request = factory.put('/1', data, format='json') with self.assertNumQueries(3): response = self.view(request, pk=1).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -324,10 +315,9 @@ class TestInstanceView(TestCase): PUT requests to RetrieveUpdateDestroyAPIView should create an object at the requested url if it doesn't exist. """ - content = {'text': 'foobar'} + data = {'text': 'foobar'} # pk fields can not be created on demand, only the database can set the pk for a new object - request = factory.put('/5', json.dumps(content), - content_type='application/json') + request = factory.put('/5', data, format='json') with self.assertNumQueries(3): response = self.view(request, pk=5).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -339,9 +329,8 @@ class TestInstanceView(TestCase): PUT requests to RetrieveUpdateDestroyAPIView should create an object at the requested url if possible, else return HTTP_403_FORBIDDEN error-response. """ - content = {'text': 'foobar'} - request = factory.put('/test_slug', json.dumps(content), - content_type='application/json') + data = {'text': 'foobar'} + request = factory.put('/test_slug', data, format='json') with self.assertNumQueries(2): response = self.slug_based_view(request, slug='test_slug').render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) @@ -415,9 +404,8 @@ class TestCreateModelWithAutoNowAddField(TestCase): https://github.com/tomchristie/django-rest-framework/issues/285 """ - content = {'email': 'foobar@example.com', 'content': 'foobar'} - request = factory.post('/', json.dumps(content), - content_type='application/json') + data = {'email': 'foobar@example.com', 'content': 'foobar'} + request = factory.post('/', data, format='json') response = self.view(request).render() self.assertEqual(response.status_code, status.HTTP_201_CREATED) created = self.objects.get(id=1) diff --git a/rest_framework/tests/test_hyperlinkedserializers.py b/rest_framework/tests/test_hyperlinkedserializers.py index 129600cb4..61e613d75 100644 --- a/rest_framework/tests/test_hyperlinkedserializers.py +++ b/rest_framework/tests/test_hyperlinkedserializers.py @@ -1,12 +1,15 @@ from __future__ import unicode_literals import json from django.test import TestCase -from django.test.client import RequestFactory from rest_framework import generics, status, serializers from rest_framework.compat import patterns, url -from rest_framework.tests.models import Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, Album, Photo, OptionalRelationModel +from rest_framework.test import APIRequestFactory +from rest_framework.tests.models import ( + Anchor, BasicModel, ManyToManyModel, BlogPost, BlogPostComment, + Album, Photo, OptionalRelationModel +) -factory = RequestFactory() +factory = APIRequestFactory() class BlogPostCommentSerializer(serializers.ModelSerializer): @@ -21,7 +24,7 @@ class BlogPostCommentSerializer(serializers.ModelSerializer): class PhotoSerializer(serializers.Serializer): description = serializers.CharField() - album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), slug_field='title', slug_url_kwarg='title') + album_url = serializers.HyperlinkedRelatedField(source='album', view_name='album-detail', queryset=Album.objects.all(), lookup_field='title', slug_url_kwarg='title') def restore_object(self, attrs, instance=None): return Photo(**attrs) diff --git a/rest_framework/tests/test_negotiation.py b/rest_framework/tests/test_negotiation.py index 7f84827f0..04b89eb60 100644 --- a/rest_framework/tests/test_negotiation.py +++ b/rest_framework/tests/test_negotiation.py @@ -1,12 +1,12 @@ from __future__ import unicode_literals from django.test import TestCase -from django.test.client import RequestFactory from rest_framework.negotiation import DefaultContentNegotiation from rest_framework.request import Request from rest_framework.renderers import BaseRenderer +from rest_framework.test import APIRequestFactory -factory = RequestFactory() +factory = APIRequestFactory() class MockJSONRenderer(BaseRenderer): diff --git a/rest_framework/tests/test_pagination.py b/rest_framework/tests/test_pagination.py index e538a78e5..85d4640ea 100644 --- a/rest_framework/tests/test_pagination.py +++ b/rest_framework/tests/test_pagination.py @@ -4,13 +4,13 @@ from decimal import Decimal from django.db import models from django.core.paginator import Paginator from django.test import TestCase -from django.test.client import RequestFactory from django.utils import unittest from rest_framework import generics, status, pagination, filters, serializers from rest_framework.compat import django_filters +from rest_framework.test import APIRequestFactory from rest_framework.tests.models import BasicModel -factory = RequestFactory() +factory = APIRequestFactory() class FilterableItem(models.Model): @@ -369,7 +369,7 @@ class TestCustomPaginationSerializer(TestCase): self.page = paginator.page(1) def test_custom_pagination_serializer(self): - request = RequestFactory().get('/foobar') + request = APIRequestFactory().get('/foobar') serializer = CustomPaginationSerializer( instance=self.page, context={'request': request} diff --git a/rest_framework/tests/test_permissions.py b/rest_framework/tests/test_permissions.py index 6caaf65b0..e2cca3808 100644 --- a/rest_framework/tests/test_permissions.py +++ b/rest_framework/tests/test_permissions.py @@ -3,11 +3,10 @@ from django.contrib.auth.models import User, Permission from django.db import models from django.test import TestCase from rest_framework import generics, status, permissions, authentication, HTTP_HEADER_ENCODING -from rest_framework.tests.utils import RequestFactory +from rest_framework.test import APIRequestFactory import base64 -import json -factory = RequestFactory() +factory = APIRequestFactory() class BasicModel(models.Model): @@ -56,15 +55,13 @@ class ModelPermissionsIntegrationTests(TestCase): BasicModel(text='foo').save() def test_has_create_permissions(self): - request = factory.post('/', json.dumps({'text': 'foobar'}), - content_type='application/json', + request = factory.post('/', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.permitted_credentials) response = root_view(request, pk=1) self.assertEqual(response.status_code, status.HTTP_201_CREATED) def test_has_put_permissions(self): - request = factory.put('/1', json.dumps({'text': 'foobar'}), - content_type='application/json', + request = factory.put('/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.permitted_credentials) response = instance_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -75,15 +72,13 @@ class ModelPermissionsIntegrationTests(TestCase): self.assertEqual(response.status_code, status.HTTP_204_NO_CONTENT) def test_does_not_have_create_permissions(self): - request = factory.post('/', json.dumps({'text': 'foobar'}), - content_type='application/json', + request = factory.post('/', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.disallowed_credentials) response = root_view(request, pk=1) self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) def test_does_not_have_put_permissions(self): - request = factory.put('/1', json.dumps({'text': 'foobar'}), - content_type='application/json', + request = factory.put('/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.disallowed_credentials) response = instance_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) @@ -95,28 +90,26 @@ class ModelPermissionsIntegrationTests(TestCase): def test_has_put_as_create_permissions(self): # User only has update permissions - should be able to update an entity. - request = factory.put('/1', json.dumps({'text': 'foobar'}), - content_type='application/json', + request = factory.put('/1', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.updateonly_credentials) response = instance_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) # But if PUTing to a new entity, permission should be denied. - request = factory.put('/2', json.dumps({'text': 'foobar'}), - content_type='application/json', + request = factory.put('/2', {'text': 'foobar'}, format='json', HTTP_AUTHORIZATION=self.updateonly_credentials) response = instance_view(request, pk='2') self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) def test_options_permitted(self): - request = factory.options('/', content_type='application/json', + request = factory.options('/', HTTP_AUTHORIZATION=self.permitted_credentials) response = root_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertIn('actions', response.data) self.assertEqual(list(response.data['actions'].keys()), ['POST']) - request = factory.options('/1', content_type='application/json', + request = factory.options('/1', HTTP_AUTHORIZATION=self.permitted_credentials) response = instance_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) @@ -124,26 +117,26 @@ class ModelPermissionsIntegrationTests(TestCase): self.assertEqual(list(response.data['actions'].keys()), ['PUT']) def test_options_disallowed(self): - request = factory.options('/', content_type='application/json', + request = factory.options('/', HTTP_AUTHORIZATION=self.disallowed_credentials) response = root_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertNotIn('actions', response.data) - request = factory.options('/1', content_type='application/json', + request = factory.options('/1', HTTP_AUTHORIZATION=self.disallowed_credentials) response = instance_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertNotIn('actions', response.data) def test_options_updateonly(self): - request = factory.options('/', content_type='application/json', + request = factory.options('/', HTTP_AUTHORIZATION=self.updateonly_credentials) response = root_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertNotIn('actions', response.data) - request = factory.options('/1', content_type='application/json', + request = factory.options('/1', HTTP_AUTHORIZATION=self.updateonly_credentials) response = instance_view(request, pk='1') self.assertEqual(response.status_code, status.HTTP_200_OK) diff --git a/rest_framework/tests/test_relations_hyperlink.py b/rest_framework/tests/test_relations_hyperlink.py index 2ca7f4f2b..3c4d39af6 100644 --- a/rest_framework/tests/test_relations_hyperlink.py +++ b/rest_framework/tests/test_relations_hyperlink.py @@ -1,15 +1,15 @@ from __future__ import unicode_literals from django.test import TestCase -from django.test.client import RequestFactory from rest_framework import serializers from rest_framework.compat import patterns, url +from rest_framework.test import APIRequestFactory from rest_framework.tests.models import ( BlogPost, ManyToManyTarget, ManyToManySource, ForeignKeyTarget, ForeignKeySource, NullableForeignKeySource, OneToOneTarget, NullableOneToOneSource ) -factory = RequestFactory() +factory = APIRequestFactory() request = factory.get('/') # Just to ensure we have a request in the serializer context diff --git a/rest_framework/tests/test_renderers.py b/rest_framework/tests/test_renderers.py index 95b597411..df6f4aa63 100644 --- a/rest_framework/tests/test_renderers.py +++ b/rest_framework/tests/test_renderers.py @@ -4,19 +4,17 @@ from __future__ import unicode_literals from decimal import Decimal from django.core.cache import cache from django.test import TestCase -from django.test.client import RequestFactory from django.utils import unittest from django.utils.translation import ugettext_lazy as _ from rest_framework import status, permissions -from rest_framework.compat import yaml, etree, patterns, url, include +from rest_framework.compat import yaml, etree, patterns, url, include, six, StringIO from rest_framework.response import Response from rest_framework.views import APIView from rest_framework.renderers import BaseRenderer, JSONRenderer, YAMLRenderer, \ XMLRenderer, JSONPRenderer, BrowsableAPIRenderer, UnicodeJSONRenderer from rest_framework.parsers import YAMLParser, XMLParser from rest_framework.settings import api_settings -from rest_framework.compat import StringIO -from rest_framework.compat import six +from rest_framework.test import APIRequestFactory import datetime import pickle import re @@ -121,7 +119,7 @@ class POSTDeniedView(APIView): class DocumentingRendererTests(TestCase): def test_only_permitted_forms_are_displayed(self): view = POSTDeniedView.as_view() - request = RequestFactory().get('/') + request = APIRequestFactory().get('/') response = view(request).render() self.assertNotContains(response, '>POST<') self.assertContains(response, '>PUT<') diff --git a/rest_framework/tests/test_request.py b/rest_framework/tests/test_request.py index a5c5e84ce..8d64d79f2 100644 --- a/rest_framework/tests/test_request.py +++ b/rest_framework/tests/test_request.py @@ -6,7 +6,6 @@ from django.contrib.auth.models import User from django.contrib.auth import authenticate, login, logout from django.contrib.sessions.middleware import SessionMiddleware from django.test import TestCase, Client -from django.test.client import RequestFactory from rest_framework import status from rest_framework.authentication import SessionAuthentication from rest_framework.compat import patterns @@ -19,12 +18,13 @@ from rest_framework.parsers import ( from rest_framework.request import Request from rest_framework.response import Response from rest_framework.settings import api_settings +from rest_framework.test import APIRequestFactory from rest_framework.views import APIView from rest_framework.compat import six import json -factory = RequestFactory() +factory = APIRequestFactory() class PlainTextParser(BaseParser): @@ -116,16 +116,7 @@ class TestContentParsing(TestCase): Ensure request.DATA returns content for PUT request with form content. """ data = {'qwerty': 'uiop'} - - from django import VERSION - - if VERSION >= (1, 5): - from django.test.client import MULTIPART_CONTENT, BOUNDARY, encode_multipart - request = Request(factory.put('/', encode_multipart(BOUNDARY, data), - content_type=MULTIPART_CONTENT)) - else: - request = Request(factory.put('/', data)) - + request = Request(factory.put('/', data)) request.parsers = (FormParser(), MultiPartParser()) self.assertEqual(list(request.DATA.items()), list(data.items())) diff --git a/rest_framework/tests/test_reverse.py b/rest_framework/tests/test_reverse.py index 93ef56377..690a30b11 100644 --- a/rest_framework/tests/test_reverse.py +++ b/rest_framework/tests/test_reverse.py @@ -1,10 +1,10 @@ from __future__ import unicode_literals from django.test import TestCase -from django.test.client import RequestFactory from rest_framework.compat import patterns, url from rest_framework.reverse import reverse +from rest_framework.test import APIRequestFactory -factory = RequestFactory() +factory = APIRequestFactory() def null_view(request): diff --git a/rest_framework/tests/test_routers.py b/rest_framework/tests/test_routers.py index d375f4a8c..5fcccb741 100644 --- a/rest_framework/tests/test_routers.py +++ b/rest_framework/tests/test_routers.py @@ -1,15 +1,15 @@ from __future__ import unicode_literals from django.db import models from django.test import TestCase -from django.test.client import RequestFactory from django.core.exceptions import ImproperlyConfigured from rest_framework import serializers, viewsets, permissions from rest_framework.compat import include, patterns, url from rest_framework.decorators import link, action from rest_framework.response import Response from rest_framework.routers import SimpleRouter, DefaultRouter +from rest_framework.test import APIRequestFactory -factory = RequestFactory() +factory = APIRequestFactory() urlpatterns = patterns('',) @@ -193,6 +193,7 @@ class TestActionKeywordArgs(TestCase): {'permission_classes': [permissions.AllowAny]} ) + class TestActionAppliedToExistingRoute(TestCase): """ Ensure `@action` decorator raises an except when applied diff --git a/rest_framework/tests/test_throttling.py b/rest_framework/tests/test_throttling.py index d35d37092..19bc691ae 100644 --- a/rest_framework/tests/test_throttling.py +++ b/rest_framework/tests/test_throttling.py @@ -5,7 +5,7 @@ from __future__ import unicode_literals from django.test import TestCase from django.contrib.auth.models import User from django.core.cache import cache -from django.test.client import RequestFactory +from rest_framework.test import APIRequestFactory from rest_framework.views import APIView from rest_framework.throttling import UserRateThrottle, ScopedRateThrottle from rest_framework.response import Response @@ -41,7 +41,7 @@ class ThrottlingTests(TestCase): Reset the cache so that no throttles will be active """ cache.clear() - self.factory = RequestFactory() + self.factory = APIRequestFactory() def test_requests_are_throttled(self): """ @@ -173,7 +173,7 @@ class ScopedRateThrottleTests(TestCase): return Response('y') self.throttle_class = XYScopedRateThrottle - self.factory = RequestFactory() + self.factory = APIRequestFactory() self.x_view = XView.as_view() self.y_view = YView.as_view() self.unscoped_view = UnscopedView.as_view() diff --git a/rest_framework/tests/test_urlpatterns.py b/rest_framework/tests/test_urlpatterns.py index 29ed4a961..8132ec4c8 100644 --- a/rest_framework/tests/test_urlpatterns.py +++ b/rest_framework/tests/test_urlpatterns.py @@ -2,7 +2,7 @@ from __future__ import unicode_literals from collections import namedtuple from django.core import urlresolvers from django.test import TestCase -from django.test.client import RequestFactory +from rest_framework.test import APIRequestFactory from rest_framework.compat import patterns, url, include from rest_framework.urlpatterns import format_suffix_patterns @@ -20,7 +20,7 @@ class FormatSuffixTests(TestCase): Tests `format_suffix_patterns` against different URLPatterns to ensure the URLs still resolve properly, including any captured parameters. """ def _resolve_urlpatterns(self, urlpatterns, test_paths): - factory = RequestFactory() + factory = APIRequestFactory() try: urlpatterns = format_suffix_patterns(urlpatterns) except Exception: diff --git a/rest_framework/tests/test_validation.py b/rest_framework/tests/test_validation.py index a6ec0e993..ebfdff9cd 100644 --- a/rest_framework/tests/test_validation.py +++ b/rest_framework/tests/test_validation.py @@ -2,10 +2,9 @@ from __future__ import unicode_literals from django.db import models from django.test import TestCase from rest_framework import generics, serializers, status -from rest_framework.tests.utils import RequestFactory -import json +from rest_framework.test import APIRequestFactory -factory = RequestFactory() +factory = APIRequestFactory() # Regression for #666 @@ -33,8 +32,7 @@ class TestPreSaveValidationExclusions(TestCase): validation on read only fields. """ obj = ValidationModel.objects.create(blank_validated_field='') - request = factory.put('/', json.dumps({}), - content_type='application/json') + request = factory.put('/', {}, format='json') view = UpdateValidationModel().as_view() response = view(request, pk=obj.pk).render() self.assertEqual(response.status_code, status.HTTP_200_OK) diff --git a/rest_framework/tests/test_views.py b/rest_framework/tests/test_views.py index 2767d24c8..c0bec5aed 100644 --- a/rest_framework/tests/test_views.py +++ b/rest_framework/tests/test_views.py @@ -1,17 +1,15 @@ from __future__ import unicode_literals import copy - from django.test import TestCase -from django.test.client import RequestFactory - from rest_framework import status from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework.settings import api_settings +from rest_framework.test import APIRequestFactory from rest_framework.views import APIView -factory = RequestFactory() +factory = APIRequestFactory() class BasicView(APIView): From f585480ee10f4b5e61db4ac343b1d2af25d2de97 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 28 Jun 2013 17:50:30 +0100 Subject: [PATCH 02/12] Added APIClient --- rest_framework/test.py | 83 +++++++++++++++++---- rest_framework/tests/test_authentication.py | 44 +++++------ rest_framework/tests/test_request.py | 6 +- 3 files changed, 94 insertions(+), 39 deletions(-) diff --git a/rest_framework/test.py b/rest_framework/test.py index 92281cafc..9fce2c08b 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -1,39 +1,54 @@ -from rest_framework.compat import six, RequestFactory +# Note that we use `DjangoRequestFactory` and `DjangoClient` names in order +# to make it harder for the user to import the wrong thing without realizing. +from django.conf import settings +from django.test.client import Client as DjangoClient +from rest_framework.compat import RequestFactory as DjangoRequestFactory +from rest_framework.compat import force_bytes_or_smart_bytes, six from rest_framework.renderers import JSONRenderer, MultiPartRenderer -class APIRequestFactory(RequestFactory): +class APIRequestFactory(DjangoRequestFactory): renderer_classes = { 'json': JSONRenderer, 'form': MultiPartRenderer } default_format = 'form' - def __init__(self, format=None, **defaults): - self.format = format or self.default_format - super(APIRequestFactory, self).__init__(**defaults) + def _encode_data(self, data, format=None, content_type=None): + """ + Encode the data returning a two tuple of (bytes, content_type) + """ - def _encode_data(self, data, format, content_type): if not data: return ('', None) - format = format or self.format + assert format is None or content_type is None, ( + 'You may not set both `format` and `content_type`.' + ) - if content_type is None and data is not None: + if content_type: + # Content type specified explicitly, treat data as a raw bytestring + ret = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET) + + else: + # Use format and render the data into a bytestring + format = format or self.default_format renderer = self.renderer_classes[format]() - data = renderer.render(data) - # Determine the content-type header + ret = renderer.render(data) + + # Determine the content-type header from the renderer if ';' in renderer.media_type: content_type = renderer.media_type else: content_type = "{0}; charset={1}".format( renderer.media_type, renderer.charset ) - # Coerce text to bytes if required. - if isinstance(data, six.text_type): - data = bytes(data.encode(renderer.charset)) - return data, content_type + # Coerce text to bytes if required. + if isinstance(ret, six.text_type): + ret = bytes(ret.encode(renderer.charset)) + + return ret, content_type def post(self, path, data=None, format=None, content_type=None, **extra): data, content_type = self._encode_data(data, format, content_type) @@ -46,3 +61,43 @@ class APIRequestFactory(RequestFactory): def patch(self, path, data=None, format=None, content_type=None, **extra): data, content_type = self._encode_data(data, format, content_type) return self.generic('PATCH', path, data, content_type, **extra) + + def delete(self, path, data=None, format=None, content_type=None, **extra): + data, content_type = self._encode_data(data, format, content_type) + return self.generic('DELETE', path, data, content_type, **extra) + + def options(self, path, data=None, format=None, content_type=None, **extra): + data, content_type = self._encode_data(data, format, content_type) + return self.generic('OPTIONS', path, data, content_type, **extra) + + +class APIClient(APIRequestFactory, DjangoClient): + def post(self, path, data=None, format=None, content_type=None, follow=False, **extra): + response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + + def put(self, path, data=None, format=None, content_type=None, follow=False, **extra): + response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + + def patch(self, path, data=None, format=None, content_type=None, follow=False, **extra): + response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + + def delete(self, path, data=None, format=None, content_type=None, follow=False, **extra): + response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + + def options(self, path, data=None, format=None, content_type=None, follow=False, **extra): + response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response diff --git a/rest_framework/tests/test_authentication.py b/rest_framework/tests/test_authentication.py index f2c51c68f..a44813b69 100644 --- a/rest_framework/tests/test_authentication.py +++ b/rest_framework/tests/test_authentication.py @@ -1,7 +1,7 @@ from __future__ import unicode_literals from django.contrib.auth.models import User from django.http import HttpResponse -from django.test import Client, TestCase +from django.test import TestCase from django.utils import unittest from rest_framework import HTTP_HEADER_ENCODING from rest_framework import exceptions @@ -21,12 +21,11 @@ from rest_framework.authtoken.models import Token from rest_framework.compat import patterns, url, include from rest_framework.compat import oauth2_provider, oauth2_provider_models, oauth2_provider_scope from rest_framework.compat import oauth, oauth_provider -from rest_framework.test import APIRequestFactory +from rest_framework.test import APIRequestFactory, APIClient from rest_framework.views import APIView import base64 import time import datetime -import json factory = APIRequestFactory() @@ -68,7 +67,7 @@ class BasicAuthTests(TestCase): urls = 'rest_framework.tests.test_authentication' def setUp(self): - self.csrf_client = Client(enforce_csrf_checks=True) + self.csrf_client = APIClient(enforce_csrf_checks=True) self.username = 'john' self.email = 'lennon@thebeatles.com' self.password = 'password' @@ -87,7 +86,7 @@ class BasicAuthTests(TestCase): credentials = ('%s:%s' % (self.username, self.password)) base64_credentials = base64.b64encode(credentials.encode(HTTP_HEADER_ENCODING)).decode(HTTP_HEADER_ENCODING) auth = 'Basic %s' % base64_credentials - response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, status.HTTP_200_OK) def test_post_form_failing_basic_auth(self): @@ -97,7 +96,7 @@ class BasicAuthTests(TestCase): def test_post_json_failing_basic_auth(self): """Ensure POSTing json over basic auth without correct credentials fails""" - response = self.csrf_client.post('/basic/', json.dumps({'example': 'example'}), 'application/json') + response = self.csrf_client.post('/basic/', {'example': 'example'}, format='json') self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) self.assertEqual(response['WWW-Authenticate'], 'Basic realm="api"') @@ -107,8 +106,8 @@ class SessionAuthTests(TestCase): urls = 'rest_framework.tests.test_authentication' def setUp(self): - self.csrf_client = Client(enforce_csrf_checks=True) - self.non_csrf_client = Client(enforce_csrf_checks=False) + self.csrf_client = APIClient(enforce_csrf_checks=True) + self.non_csrf_client = APIClient(enforce_csrf_checks=False) self.username = 'john' self.email = 'lennon@thebeatles.com' self.password = 'password' @@ -154,7 +153,7 @@ class TokenAuthTests(TestCase): urls = 'rest_framework.tests.test_authentication' def setUp(self): - self.csrf_client = Client(enforce_csrf_checks=True) + self.csrf_client = APIClient(enforce_csrf_checks=True) self.username = 'john' self.email = 'lennon@thebeatles.com' self.password = 'password' @@ -172,7 +171,7 @@ class TokenAuthTests(TestCase): def test_post_json_passing_token_auth(self): """Ensure POSTing form over token auth with correct credentials passes and does not require CSRF""" auth = "Token " + self.key - response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json', HTTP_AUTHORIZATION=auth) + response = self.csrf_client.post('/token/', {'example': 'example'}, format='json', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, status.HTTP_200_OK) def test_post_form_failing_token_auth(self): @@ -182,7 +181,7 @@ class TokenAuthTests(TestCase): def test_post_json_failing_token_auth(self): """Ensure POSTing json over token auth without correct credentials fails""" - response = self.csrf_client.post('/token/', json.dumps({'example': 'example'}), 'application/json') + response = self.csrf_client.post('/token/', {'example': 'example'}, format='json') self.assertEqual(response.status_code, status.HTTP_401_UNAUTHORIZED) def test_token_has_auto_assigned_key_if_none_provided(self): @@ -193,33 +192,33 @@ class TokenAuthTests(TestCase): def test_token_login_json(self): """Ensure token login view using JSON POST works.""" - client = Client(enforce_csrf_checks=True) + client = APIClient(enforce_csrf_checks=True) response = client.post('/auth-token/', - json.dumps({'username': self.username, 'password': self.password}), 'application/json') + {'username': self.username, 'password': self.password}, format='json') self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(json.loads(response.content.decode('ascii'))['token'], self.key) + self.assertEqual(response.data['token'], self.key) def test_token_login_json_bad_creds(self): """Ensure token login view using JSON POST fails if bad credentials are used.""" - client = Client(enforce_csrf_checks=True) + client = APIClient(enforce_csrf_checks=True) response = client.post('/auth-token/', - json.dumps({'username': self.username, 'password': "badpass"}), 'application/json') + {'username': self.username, 'password': "badpass"}, format='json') self.assertEqual(response.status_code, 400) def test_token_login_json_missing_fields(self): """Ensure token login view using JSON POST fails if missing fields.""" - client = Client(enforce_csrf_checks=True) + client = APIClient(enforce_csrf_checks=True) response = client.post('/auth-token/', - json.dumps({'username': self.username}), 'application/json') + {'username': self.username}, format='json') self.assertEqual(response.status_code, 400) def test_token_login_form(self): """Ensure token login view using form POST works.""" - client = Client(enforce_csrf_checks=True) + client = APIClient(enforce_csrf_checks=True) response = client.post('/auth-token/', {'username': self.username, 'password': self.password}) self.assertEqual(response.status_code, status.HTTP_200_OK) - self.assertEqual(json.loads(response.content.decode('ascii'))['token'], self.key) + self.assertEqual(response.data['token'], self.key) class IncorrectCredentialsTests(TestCase): @@ -256,7 +255,7 @@ class OAuthTests(TestCase): self.consts = consts - self.csrf_client = Client(enforce_csrf_checks=True) + self.csrf_client = APIClient(enforce_csrf_checks=True) self.username = 'john' self.email = 'lennon@thebeatles.com' self.password = 'password' @@ -470,12 +469,13 @@ class OAuthTests(TestCase): response = self.csrf_client.post('/oauth/', HTTP_AUTHORIZATION=auth) self.assertEqual(response.status_code, 401) + class OAuth2Tests(TestCase): """OAuth 2.0 authentication""" urls = 'rest_framework.tests.test_authentication' def setUp(self): - self.csrf_client = Client(enforce_csrf_checks=True) + self.csrf_client = APIClient(enforce_csrf_checks=True) self.username = 'john' self.email = 'lennon@thebeatles.com' self.password = 'password' diff --git a/rest_framework/tests/test_request.py b/rest_framework/tests/test_request.py index 8d64d79f2..969d8024a 100644 --- a/rest_framework/tests/test_request.py +++ b/rest_framework/tests/test_request.py @@ -5,7 +5,7 @@ from __future__ import unicode_literals from django.contrib.auth.models import User from django.contrib.auth import authenticate, login, logout from django.contrib.sessions.middleware import SessionMiddleware -from django.test import TestCase, Client +from django.test import TestCase from rest_framework import status from rest_framework.authentication import SessionAuthentication from rest_framework.compat import patterns @@ -18,7 +18,7 @@ from rest_framework.parsers import ( from rest_framework.request import Request from rest_framework.response import Response from rest_framework.settings import api_settings -from rest_framework.test import APIRequestFactory +from rest_framework.test import APIRequestFactory, APIClient from rest_framework.views import APIView from rest_framework.compat import six import json @@ -248,7 +248,7 @@ class TestContentParsingWithAuthentication(TestCase): urls = 'rest_framework.tests.test_request' def setUp(self): - self.csrf_client = Client(enforce_csrf_checks=True) + self.csrf_client = APIClient(enforce_csrf_checks=True) self.username = 'john' self.email = 'lennon@thebeatles.com' self.password = 'password' From 90bc07f3f160485001ea329e5f69f7e521d14ec9 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sat, 29 Jun 2013 08:05:08 +0100 Subject: [PATCH 03/12] Addeded 'APITestClient.credentials()' --- rest_framework/test.py | 29 +++++++++++++++++++++++++ rest_framework/tests/test_testing.py | 32 ++++++++++++++++++++++++++++ 2 files changed, 61 insertions(+) create mode 100644 rest_framework/tests/test_testing.py diff --git a/rest_framework/test.py b/rest_framework/test.py index 9fce2c08b..8115fa0d2 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -1,5 +1,8 @@ +# -- coding: utf-8 -- + # Note that we use `DjangoRequestFactory` and `DjangoClient` names in order # to make it harder for the user to import the wrong thing without realizing. +from __future__ import unicode_literals from django.conf import settings from django.test.client import Client as DjangoClient from rest_framework.compat import RequestFactory as DjangoRequestFactory @@ -72,31 +75,57 @@ class APIRequestFactory(DjangoRequestFactory): class APIClient(APIRequestFactory, DjangoClient): + def __init__(self, *args, **kwargs): + self._credentials = {} + super(APIClient, self).__init__(*args, **kwargs) + + def credentials(self, **kwargs): + self._credentials = kwargs + + def get(self, path, data={}, follow=False, **extra): + extra.update(self._credentials) + response = super(APIClient, self).get(path, data=data, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + + def head(self, path, data={}, follow=False, **extra): + extra.update(self._credentials) + response = super(APIClient, self).head(path, data=data, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response + def post(self, path, data=None, format=None, content_type=None, follow=False, **extra): + extra.update(self._credentials) response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) if follow: response = self._handle_redirects(response, **extra) return response def put(self, path, data=None, format=None, content_type=None, follow=False, **extra): + extra.update(self._credentials) response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) if follow: response = self._handle_redirects(response, **extra) return response def patch(self, path, data=None, format=None, content_type=None, follow=False, **extra): + extra.update(self._credentials) response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) if follow: response = self._handle_redirects(response, **extra) return response def delete(self, path, data=None, format=None, content_type=None, follow=False, **extra): + extra.update(self._credentials) response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) if follow: response = self._handle_redirects(response, **extra) return response def options(self, path, data=None, format=None, content_type=None, follow=False, **extra): + extra.update(self._credentials) response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) if follow: response = self._handle_redirects(response, **extra) diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py new file mode 100644 index 000000000..71dacd38e --- /dev/null +++ b/rest_framework/tests/test_testing.py @@ -0,0 +1,32 @@ +# -- coding: utf-8 -- + +from __future__ import unicode_literals +from django.test import TestCase +from rest_framework.compat import patterns, url +from rest_framework.decorators import api_view +from rest_framework.response import Response +from rest_framework.test import APIClient + + +@api_view(['GET']) +def mirror(request): + return Response({ + 'auth': request.META.get('HTTP_AUTHORIZATION', b'') + }) + + +urlpatterns = patterns('', + url(r'^view/$', mirror), +) + + +class CheckTestClient(TestCase): + urls = 'rest_framework.tests.test_testing' + + def setUp(self): + self.client = APIClient() + + def test_credentials(self): + self.client.credentials(HTTP_AUTHORIZATION='example') + response = self.client.get('/view/') + self.assertEqual(response.data['auth'], 'example') From f7db06953bd8ad7f5e0211f49a04e8d5bb634380 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sat, 29 Jun 2013 08:06:11 +0100 Subject: [PATCH 04/12] Remove unneeded tests.utils, superseeded by APIRequestFactory, APIClient --- rest_framework/tests/utils.py | 40 ----------------------------------- 1 file changed, 40 deletions(-) delete mode 100644 rest_framework/tests/utils.py diff --git a/rest_framework/tests/utils.py b/rest_framework/tests/utils.py deleted file mode 100644 index 8c87917d9..000000000 --- a/rest_framework/tests/utils.py +++ /dev/null @@ -1,40 +0,0 @@ -from __future__ import unicode_literals -from django.test.client import FakePayload, Client as _Client, RequestFactory as _RequestFactory -from django.test.client import MULTIPART_CONTENT -from rest_framework.compat import urlparse - - -class RequestFactory(_RequestFactory): - - def __init__(self, **defaults): - super(RequestFactory, self).__init__(**defaults) - - def patch(self, path, data={}, content_type=MULTIPART_CONTENT, - **extra): - "Construct a PATCH request." - - patch_data = self._encode_data(data, content_type) - - parsed = urlparse.urlparse(path) - r = { - 'CONTENT_LENGTH': len(patch_data), - 'CONTENT_TYPE': content_type, - 'PATH_INFO': self._get_path(parsed), - 'QUERY_STRING': parsed[4], - 'REQUEST_METHOD': 'PATCH', - 'wsgi.input': FakePayload(patch_data), - } - r.update(extra) - return self.request(**r) - - -class Client(_Client, RequestFactory): - def patch(self, path, data={}, content_type=MULTIPART_CONTENT, - follow=False, **extra): - """ - Send a resource to the server using PATCH. - """ - response = super(Client, self).patch(path, data=data, content_type=content_type, **extra) - if follow: - response = self._handle_redirects(response, **extra) - return response From 35022ca9213939a2f40c82facffa908a818efe0b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sat, 29 Jun 2013 08:14:05 +0100 Subject: [PATCH 05/12] Refactor SessionAuthentication slightly --- rest_framework/authentication.py | 26 ++++++++++++++++---------- 1 file changed, 16 insertions(+), 10 deletions(-) diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index 102980271..b42162dd9 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -26,6 +26,12 @@ def get_authorization_header(request): return auth +class CSRFCheck(CsrfViewMiddleware): + def _reject(self, request, reason): + # Return the failure reason instead of an HttpResponse + return reason + + class BaseAuthentication(object): """ All authentication classes should extend BaseAuthentication. @@ -110,20 +116,20 @@ class SessionAuthentication(BaseAuthentication): if not user or not user.is_active: return None - # Enforce CSRF validation for session based authentication. - class CSRFCheck(CsrfViewMiddleware): - def _reject(self, request, reason): - # Return the failure reason instead of an HttpResponse - return reason - - reason = CSRFCheck().process_view(http_request, None, (), {}) - if reason: - # CSRF failed, bail with explicit error message - raise exceptions.AuthenticationFailed('CSRF Failed: %s' % reason) + self.enforce_csrf(http_request) # CSRF passed with authenticated user return (user, None) + def enforce_csrf(self, request): + """ + Enforce CSRF validation for session based authentication. + """ + reason = CSRFCheck().process_view(request, None, (), {}) + if reason: + # CSRF failed, bail with explicit error message + raise exceptions.AuthenticationFailed('CSRF Failed: %s' % reason) + class TokenAuthentication(BaseAuthentication): """ From 664f8c63655770cd90bdbd510b315bcd045b380a Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sat, 29 Jun 2013 21:02:58 +0100 Subject: [PATCH 06/12] Added APIClient.authenticate() --- rest_framework/renderers.py | 2 +- rest_framework/request.py | 20 +++++++++++++ rest_framework/test.py | 39 +++++++++++++++++++++++--- rest_framework/tests/test_testing.py | 42 ++++++++++++++++++++++++++-- 4 files changed, 95 insertions(+), 8 deletions(-) diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index d7a7ef297..3a03ca332 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -576,7 +576,7 @@ class BrowsableAPIRenderer(BaseRenderer): class MultiPartRenderer(BaseRenderer): media_type = 'multipart/form-data; boundary=BoUnDaRyStRiNg' - format = 'form' + format = 'multipart' charset = 'utf-8' BOUNDARY = 'BoUnDaRyStRiNg' diff --git a/rest_framework/request.py b/rest_framework/request.py index 0d88ebc7e..919716f49 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -64,6 +64,20 @@ def clone_request(request, method): return ret +class ForcedAuthentication(object): + """ + This authentication class is used if the test client or request factory + forcibly authenticated the request. + """ + + def __init__(self, force_user, force_token): + self.force_user = force_user + self.force_token = force_token + + def authenticate(self, request): + return (self.force_user, self.force_token) + + class Request(object): """ Wrapper allowing to enhance a standard `HttpRequest` instance. @@ -98,6 +112,12 @@ class Request(object): self.parser_context['request'] = self self.parser_context['encoding'] = request.encoding or settings.DEFAULT_CHARSET + force_user = getattr(request, '_force_auth_user', None) + force_token = getattr(request, '_force_auth_token', None) + if (force_user is not None or force_token is not None): + forced_auth = ForcedAuthentication(force_user, force_token) + self.authenticators = (forced_auth,) + def _default_negotiator(self): return api_settings.DEFAULT_CONTENT_NEGOTIATION_CLASS() diff --git a/rest_framework/test.py b/rest_framework/test.py index 8115fa0d2..08de2297b 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -5,6 +5,7 @@ from __future__ import unicode_literals from django.conf import settings from django.test.client import Client as DjangoClient +from django.test.client import ClientHandler from rest_framework.compat import RequestFactory as DjangoRequestFactory from rest_framework.compat import force_bytes_or_smart_bytes, six from rest_framework.renderers import JSONRenderer, MultiPartRenderer @@ -13,9 +14,9 @@ from rest_framework.renderers import JSONRenderer, MultiPartRenderer class APIRequestFactory(DjangoRequestFactory): renderer_classes = { 'json': JSONRenderer, - 'form': MultiPartRenderer + 'multipart': MultiPartRenderer } - default_format = 'form' + default_format = 'multipart' def _encode_data(self, data, format=None, content_type=None): """ @@ -74,14 +75,44 @@ class APIRequestFactory(DjangoRequestFactory): return self.generic('OPTIONS', path, data, content_type, **extra) -class APIClient(APIRequestFactory, DjangoClient): +class ForceAuthClientHandler(ClientHandler): + """ + A patched version of ClientHandler that can enforce authentication + on the outgoing requests. + """ + def __init__(self, *args, **kwargs): + self._force_auth_user = None + self._force_auth_token = None + super(ForceAuthClientHandler, self).__init__(*args, **kwargs) + + def force_authenticate(self, user=None, token=None): + self._force_auth_user = user + self._force_auth_token = token + + def get_response(self, request): + # This is the simplest place we can hook into to patch the + # request object. + request._force_auth_user = self._force_auth_user + request._force_auth_token = self._force_auth_token + return super(ForceAuthClientHandler, self).get_response(request) + + +class APIClient(APIRequestFactory, DjangoClient): + def __init__(self, enforce_csrf_checks=False, **defaults): + # Note that our super call skips Client.__init__ + # since we don't need to instantiate a regular ClientHandler + super(DjangoClient, self).__init__(**defaults) + self.handler = ForceAuthClientHandler(enforce_csrf_checks) + self.exc_info = None self._credentials = {} - super(APIClient, self).__init__(*args, **kwargs) def credentials(self, **kwargs): self._credentials = kwargs + def authenticate(self, user=None, token=None): + self.handler.force_authenticate(user, token) + def get(self, path, data={}, follow=False, **extra): extra.update(self._credentials) response = super(APIClient, self).get(path, data=data, **extra) diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py index 71dacd38e..a8398b9a4 100644 --- a/rest_framework/tests/test_testing.py +++ b/rest_framework/tests/test_testing.py @@ -1,6 +1,7 @@ # -- coding: utf-8 -- from __future__ import unicode_literals +from django.contrib.auth.models import User from django.test import TestCase from rest_framework.compat import patterns, url from rest_framework.decorators import api_view @@ -8,10 +9,11 @@ from rest_framework.response import Response from rest_framework.test import APIClient -@api_view(['GET']) +@api_view(['GET', 'POST']) def mirror(request): return Response({ - 'auth': request.META.get('HTTP_AUTHORIZATION', b'') + 'auth': request.META.get('HTTP_AUTHORIZATION', b''), + 'user': request.user.username }) @@ -27,6 +29,40 @@ class CheckTestClient(TestCase): self.client = APIClient() def test_credentials(self): + """ + Setting `.credentials()` adds the required headers to each request. + """ self.client.credentials(HTTP_AUTHORIZATION='example') + for _ in range(0, 3): + response = self.client.get('/view/') + self.assertEqual(response.data['auth'], 'example') + + def test_authenticate(self): + """ + Setting `.authenticate()` forcibly authenticates each request. + """ + user = User.objects.create_user('example', 'example@example.com') + self.client.authenticate(user) response = self.client.get('/view/') - self.assertEqual(response.data['auth'], 'example') + self.assertEqual(response.data['user'], 'example') + + def test_csrf_exempt_by_default(self): + """ + By default, the test client is CSRF exempt. + """ + User.objects.create_user('example', 'example@example.com', 'password') + self.client.login(username='example', password='password') + response = self.client.post('/view/') + self.assertEqual(response.status_code, 200) + + def test_explicitly_enforce_csrf_checks(self): + """ + The test client can enforce CSRF checks. + """ + client = APIClient(enforce_csrf_checks=True) + User.objects.create_user('example', 'example@example.com', 'password') + client.login(username='example', password='password') + response = client.post('/view/') + expected = {'detail': 'CSRF Failed: CSRF cookie not set.'} + self.assertEqual(response.status_code, 403) + self.assertEqual(response.data, expected) From ab799ccc3ee473de61ec35c6f745c6952752c522 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sat, 29 Jun 2013 21:34:47 +0100 Subject: [PATCH 07/12] Simplify APIClient implementation --- rest_framework/authentication.py | 6 +-- rest_framework/test.py | 66 +++++++------------------------- 2 files changed, 16 insertions(+), 56 deletions(-) diff --git a/rest_framework/authentication.py b/rest_framework/authentication.py index b42162dd9..cf001a24d 100644 --- a/rest_framework/authentication.py +++ b/rest_framework/authentication.py @@ -109,14 +109,14 @@ class SessionAuthentication(BaseAuthentication): """ # Get the underlying HttpRequest object - http_request = request._request - user = getattr(http_request, 'user', None) + request = request._request + user = getattr(request, 'user', None) # Unauthenticated, CSRF validation not required if not user or not user.is_active: return None - self.enforce_csrf(http_request) + self.enforce_csrf(request) # CSRF passed with authenticated user return (user, None) diff --git a/rest_framework/test.py b/rest_framework/test.py index 08de2297b..2e9cfe096 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -86,10 +86,6 @@ class ForceAuthClientHandler(ClientHandler): self._force_auth_token = None super(ForceAuthClientHandler, self).__init__(*args, **kwargs) - def force_authenticate(self, user=None, token=None): - self._force_auth_user = user - self._force_auth_token = token - def get_response(self, request): # This is the simplest place we can hook into to patch the # request object. @@ -108,56 +104,20 @@ class APIClient(APIRequestFactory, DjangoClient): self._credentials = {} def credentials(self, **kwargs): + """ + Sets headers that will be used on every outgoing request. + """ self._credentials = kwargs def authenticate(self, user=None, token=None): - self.handler.force_authenticate(user, token) + """ + Forcibly authenticates outgoing requests with the given + user and/or token. + """ + self.handler._force_auth_user = user + self.handler._force_auth_token = token - def get(self, path, data={}, follow=False, **extra): - extra.update(self._credentials) - response = super(APIClient, self).get(path, data=data, **extra) - if follow: - response = self._handle_redirects(response, **extra) - return response - - def head(self, path, data={}, follow=False, **extra): - extra.update(self._credentials) - response = super(APIClient, self).head(path, data=data, **extra) - if follow: - response = self._handle_redirects(response, **extra) - return response - - def post(self, path, data=None, format=None, content_type=None, follow=False, **extra): - extra.update(self._credentials) - response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) - if follow: - response = self._handle_redirects(response, **extra) - return response - - def put(self, path, data=None, format=None, content_type=None, follow=False, **extra): - extra.update(self._credentials) - response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) - if follow: - response = self._handle_redirects(response, **extra) - return response - - def patch(self, path, data=None, format=None, content_type=None, follow=False, **extra): - extra.update(self._credentials) - response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) - if follow: - response = self._handle_redirects(response, **extra) - return response - - def delete(self, path, data=None, format=None, content_type=None, follow=False, **extra): - extra.update(self._credentials) - response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) - if follow: - response = self._handle_redirects(response, **extra) - return response - - def options(self, path, data=None, format=None, content_type=None, follow=False, **extra): - extra.update(self._credentials) - response = super(APIClient, self).post(path, data=data, format=format, content_type=content_type, **extra) - if follow: - response = self._handle_redirects(response, **extra) - return response + def request(self, **request): + # Ensure that any credentials set get added to every request. + request.update(self._credentials) + return super(APIClient, self).request(**request) From c9485c783a555516e41068996258f4c5e383523b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sat, 29 Jun 2013 22:53:15 +0100 Subject: [PATCH 08/12] Rename to force_authenticate --- rest_framework/test.py | 2 +- rest_framework/tests/test_testing.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/rest_framework/test.py b/rest_framework/test.py index 2e9cfe096..2f658a56c 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -109,7 +109,7 @@ class APIClient(APIRequestFactory, DjangoClient): """ self._credentials = kwargs - def authenticate(self, user=None, token=None): + def force_authenticate(self, user=None, token=None): """ Forcibly authenticates outgoing requests with the given user and/or token. diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py index a8398b9a4..3706f38c2 100644 --- a/rest_framework/tests/test_testing.py +++ b/rest_framework/tests/test_testing.py @@ -37,12 +37,12 @@ class CheckTestClient(TestCase): response = self.client.get('/view/') self.assertEqual(response.data['auth'], 'example') - def test_authenticate(self): + def test_force_authenticate(self): """ - Setting `.authenticate()` forcibly authenticates each request. + Setting `.force_authenticate()` forcibly authenticates each request. """ user = User.objects.create_user('example', 'example@example.com') - self.client.authenticate(user) + self.client.force_authenticate(user) response = self.client.get('/view/') self.assertEqual(response.data['user'], 'example') From d31d7c18676b6292e8dc688b61913d572eccde91 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Sat, 29 Jun 2013 22:53:27 +0100 Subject: [PATCH 09/12] First pass at testing docs --- docs/api-guide/testing.md | 97 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 97 insertions(+) create mode 100644 docs/api-guide/testing.md diff --git a/docs/api-guide/testing.md b/docs/api-guide/testing.md new file mode 100644 index 000000000..f3a092a9d --- /dev/null +++ b/docs/api-guide/testing.md @@ -0,0 +1,97 @@ + + +# Testing + +> Code without tests is broken as designed +> +> — [Jacob Kaplan-Moss][cite] + +REST framework includes a few helper classes that extend Django's existing test framework, and improve support for making API requests. + +# APIRequestFactory + +Extends Django's existing `RequestFactory`. + +**TODO**: Document making requests. Note difference on form PUT requests. Document configuration. + +# APIClient + +Extends Django's existing `Client`. + +### .login(**kwargs) + +The `login` method functions exactly as it does with Django's regular `Client` class. This allows you to authenticate requests against any views which include `SessionAuthentication`. + + # Make all requests in the context of a logged in session. + >>> client = APIClient() + >>> client.login(username='lauren', password='secret') + +To logout, call the `logout` method as usual. + + # Log out + >>> client.logout() + +The `login` method is appropriate for testing APIs that use session authentication, for example web sites which include AJAX interaction with the API. + +### .credentials(**kwargs) + +The `credentials` method can be used to set headers that will then be included on all subsequent requests by the test client. + + # Include an appropriate `Authorization:` header on all requests. + >>> token = Token.objects.get(username='lauren') + >>> client = APIClient() + >>> client.credentials(HTTP_AUTHORIZATION='Token ' + token.key) + +Note that calling `credentials` a second time overwrites any existing credentials. You can unset any existing credentials by calling the method with no arguments. + + # Stop including any credentials + >>> client.credentials() + +The `credentials` method is appropriate for testing APIs that require authentication headers, such as basic authentication, OAuth1a and OAuth2 authentication, and simple token authentication schemes. + +### .force_authenticate(user=None, token=None) + +Sometimes you may want to bypass authentication, and simple force all requests by the test client to be automatically treated as authenticated. + +This can be a useful shortcut if you're testing the API but don't want to have to construct valid authentication credentials in order to make test requests. + + >>> user = User.objects.get(username='lauren') + >>> client = APIClient() + >>> client.force_authenticate(user=user) + +To unauthenticate subsequant requests, call `force_authenticate` setting the user and/or token to `None`. + + >>> client.force_authenticate(user=None) + +### Making requests + +**TODO**: Document requests similarly to `APIRequestFactory` + +# Testing responses + +### Using request.data + +When checking the validity of test responses it's often more convenient to inspect the data that the response was created with, rather than inspecting the fully rendered response. + +For example, it's easier to inspect `request.data`: + + response = self.client.get('/users/4/') + self.assertEqual(response.data, {'id': 4, 'username': 'lauren'}) + +Instead of inspecting the result of parsing `request.content`: + + response = self.client.get('/users/4/') + self.assertEqual(json.loads(response.content), {'id': 4, 'username': 'lauren'}) + +### Rendering responses + +If you're testing views directly using `APIRequestFactory`, the responses that are returned will not yet be rendered, as rendering of template responses is performed by Django's internal request-response cycle. In order to access `response.content`, you'll first need to render the response. + + view = UserDetail.as_view() + request = factory.get('/users/4') + response = view(request, pk='4') + response.render() # Cannot access `response.content` without this. + self.assertEqual(response.content, '{"username": "lauren", "id": 4}') + + +[cite]: http://jacobian.org/writing/django-apps-with-buildout/#s-create-a-test-wrapper \ No newline at end of file From 0a722de171b0e80ac26d8c77b8051a4170bdb4c6 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 1 Jul 2013 13:59:05 +0100 Subject: [PATCH 10/12] Complete testing docs --- docs/api-guide/renderers.md | 11 +++ docs/api-guide/settings.md | 27 +++++ docs/api-guide/testing.md | 143 +++++++++++++++++++++++++-- rest_framework/response.py | 2 +- rest_framework/settings.py | 8 ++ rest_framework/test.py | 70 ++++++++----- rest_framework/tests/test_testing.py | 55 ++++++++++- 7 files changed, 274 insertions(+), 42 deletions(-) diff --git a/docs/api-guide/renderers.md b/docs/api-guide/renderers.md index b627c9306..869bdc16a 100644 --- a/docs/api-guide/renderers.md +++ b/docs/api-guide/renderers.md @@ -217,6 +217,16 @@ Renders data into HTML for the Browsable API. This renderer will determine whic **.charset**: `utf-8` +## MultiPartRenderer + +This renderer is used for rendering HTML multipart form data. **It is not suitable as a response renderer**, but is instead used for creating test requests, using REST framework's [test client and test request factory][testing]. + +**.media_type**: `multipart/form-data; boundary=BoUnDaRyStRiNg` + +**.format**: `'.multipart'` + +**.charset**: `utf-8` + --- # Custom renderers @@ -373,6 +383,7 @@ Comma-separated values are a plain-text tabular data format, that can be easily [rfc4627]: http://www.ietf.org/rfc/rfc4627.txt [cors]: http://www.w3.org/TR/cors/ [cors-docs]: ../topics/ajax-csrf-cors.md +[testing]: testing.md [HATEOAS]: http://timelessrepo.com/haters-gonna-hateoas [quote]: http://roy.gbiv.com/untangled/2008/rest-apis-must-be-hypertext-driven [application/vnd.github+json]: http://developer.github.com/v3/media/ diff --git a/docs/api-guide/settings.md b/docs/api-guide/settings.md index 4a5164c9b..7b114983d 100644 --- a/docs/api-guide/settings.md +++ b/docs/api-guide/settings.md @@ -149,6 +149,33 @@ Default: `None` --- +## Test settings + +*The following settings control the behavior of APIRequestFactory and APIClient* + +#### TEST_REQUEST_DEFAULT_FORMAT + +The default format that should be used when making test requests. + +This should match up with the format of one of the renderer classes in the `TEST_REQUEST_RENDERER_CLASSES` setting. + +Default: `'multipart'` + +#### TEST_REQUEST_RENDERER_CLASSES + +The renderer classes that are supported when building test requests. + +The format of any of these renderer classes may be used when contructing a test request, for example: `client.post('/users', {'username': 'jamie'}, format='json')` + +Default: + + ( + 'rest_framework.renderers.MultiPartRenderer', + 'rest_framework.renderers.JSONRenderer' + ) + +--- + ## Browser overrides *The following settings provide URL or form-based overrides of the default browser behavior.* diff --git a/docs/api-guide/testing.md b/docs/api-guide/testing.md index f3a092a9d..293ee7019 100644 --- a/docs/api-guide/testing.md +++ b/docs/api-guide/testing.md @@ -10,13 +10,100 @@ REST framework includes a few helper classes that extend Django's existing test # APIRequestFactory -Extends Django's existing `RequestFactory`. +Extends [Django's existing `RequestFactory` class][requestfactory]. -**TODO**: Document making requests. Note difference on form PUT requests. Document configuration. +## Creating test requests + +The `APIRequestFactory` class supports an almost identical API to Django's standard `RequestFactory` class. This means the that standard `.get()`, `.post()`, `.put()`, `.patch()`, `.delete()`, `.head()` and `.options()` methods are all available. + +### Using the format arguments + +Methods which create a request body, such as `post`, `put` and `patch`, include a `format` argument, which make it easy to generate requests using a content type other than multipart form data. For example: + + factory = APIRequestFactory() + request = factory.post('/notes/', {'title': 'new idea'}, format='json') + +By default the available formats are `'multipart'` and `'json'`. For compatibility with Django's existing `RequestFactory` the default format is `'multipart'`. + +To support a wider set of request formats, or change the default format, [see the configuration section][configuration]. + +If you need to explictly encode the request body, you can do so by explicitly setting the `content_type` flag. For example: + + request = factory.post('/notes/', json.dumps({'title': 'new idea'}), content_type='application/json') + +### PUT and PATCH with form data + +One difference worth noting between Django's `RequestFactory` and REST framework's `APIRequestFactory` is that multipart form data will be encoded for methods other than just `.post()`. + +For example, using `APIRequestFactory`, you can make a form PUT request like so: + + factory = APIRequestFactory() + request = factory.put('/notes/547/', {'title': 'remember to email dave'}) + +Using Django's `Factory`, you'd need to explicitly encode the data yourself: + + factory = RequestFactory() + data = {'title': 'remember to email dave'} + content = encode_multipart('BoUnDaRyStRiNg', data) + content_type = 'multipart/form-data; boundary=BoUnDaRyStRiNg' + request = factory.put('/notes/547/', content, content_type=content_type) + +## Forcing authentication + +When testing views directly using a request factory, it's often convenient to be able to directly authenticate the request, rather than having to construct the correct authentication credentials. + +To forcibly authenticate a request, use the `force_authenticate()` method. + + factory = APIRequestFactory() + user = User.objects.get(username='olivia') + view = AccountDetail.as_view() + + # Make an authenticated request to the view... + request = factory.get('/accounts/django-superstars/') + force_authenticate(request, user=user) + response = view(request) + +The signature for the method is `force_authenticate(request, user=None, token=None)`. When making the call, either or both of the user and token may be set. + +--- + +**Note**: When using `APIRequestFactory`, the object that is returned is Django's standard `HttpRequest`, and not REST framework's `Request` object, which is only generated once the view is called. + +This means that setting attributes directly on the request object may not always have the effect you expect. For example, setting `.token` directly will have no effect, and setting `.user` directly will only work if session authentication is being used. + + # Request will only authenticate if `SessionAuthentication` is in use. + request = factory.get('/accounts/django-superstars/') + request.user = user + response = view(request) + +--- + +## Forcing CSRF validation + +By default, requests created with `APIRequestFactory` will not have CSRF validation applied when passed to a REST framework view. If you need to explicitly turn CSRF validation on, you can do so by setting the `enforce_csrf_checks` flag when instantiating the factory. + + factory = APIRequestFactory(enforce_csrf_checks=True) + +--- + +**Note**: It's worth noting that Django's standard `RequestFactory` doesn't need to include this option, because when using regular Django the CSRF validation takes place in middleware, which is not run when testing views directly. When using REST framework, CSRF validation takes place inside the view, so the request factory needs to disable view-level CSRF checks. + +--- # APIClient -Extends Django's existing `Client`. +Extends [Django's existing `Client` class][client]. + +## Making requests + +The `APIClient` class supports the same request interface as `APIRequestFactory`. This means the that standard `.get()`, `.post()`, `.put()`, `.patch()`, `.delete()`, `.head()` and `.options()` methods are all available. For example: + + client = APIClient() + client.post('/notes/', {'title': 'new idea'}, format='json') + +To support a wider set of request formats, or change the default format, [see the configuration section][configuration]. + +## Authenticating ### .login(**kwargs) @@ -59,17 +146,23 @@ This can be a useful shortcut if you're testing the API but don't want to have t >>> client = APIClient() >>> client.force_authenticate(user=user) -To unauthenticate subsequant requests, call `force_authenticate` setting the user and/or token to `None`. +To unauthenticate subsequent requests, call `force_authenticate` setting the user and/or token to `None`. >>> client.force_authenticate(user=None) -### Making requests +## CSRF validation -**TODO**: Document requests similarly to `APIRequestFactory` +By default CSRF validation is not applied when using `APIClient`. If you need to explicitly enable CSRF validation, you can do so by setting the `enforce_csrf_checks` flag when instantiating the client. + + client = APIClient(enforce_csrf_checks=True) + +As usual CSRF validation will only apply to any session authenticated views. This means CSRF validation will only occur if the client has been logged in by calling `login()`. + +--- # Testing responses -### Using request.data +## Checking the response data When checking the validity of test responses it's often more convenient to inspect the data that the response was created with, rather than inspecting the fully rendered response. @@ -83,7 +176,7 @@ Instead of inspecting the result of parsing `request.content`: response = self.client.get('/users/4/') self.assertEqual(json.loads(response.content), {'id': 4, 'username': 'lauren'}) -### Rendering responses +## Rendering responses If you're testing views directly using `APIRequestFactory`, the responses that are returned will not yet be rendered, as rendering of template responses is performed by Django's internal request-response cycle. In order to access `response.content`, you'll first need to render the response. @@ -92,6 +185,36 @@ If you're testing views directly using `APIRequestFactory`, the responses that a response = view(request, pk='4') response.render() # Cannot access `response.content` without this. self.assertEqual(response.content, '{"username": "lauren", "id": 4}') - -[cite]: http://jacobian.org/writing/django-apps-with-buildout/#s-create-a-test-wrapper \ No newline at end of file +--- + +# Configuration + +## Setting the default format + +The default format used to make test requests may be set using the `TEST_REQUEST_DEFAULT_FORMAT` setting key. For example, to always use JSON for test requests by default instead of standard multipart form requests, set the following in your `settings.py` file: + + REST_FRAMEWORK = { + ... + 'TEST_REQUEST_DEFAULT_FORMAT': 'json' + } + +## Setting the available formats + +If you need to test requests using something other than multipart or json requests, you can do so by setting the `TEST_REQUEST_RENDERER_CLASSES` setting. + +For example, to add support for using `format='yaml'` in test requests, you might have something like this in your `settings.py` file. + + REST_FRAMEWORK = { + ... + 'TEST_REQUEST_RENDERER_CLASSES': ( + 'rest_framework.renderers.MultiPartRenderer', + 'rest_framework.renderers.JSONRenderer', + 'rest_framework.renderers.YAMLRenderer' + ) + } + +[cite]: http://jacobian.org/writing/django-apps-with-buildout/#s-create-a-test-wrapper +[client]: https://docs.djangoproject.com/en/dev/topics/testing/overview/#module-django.test.client +[requestfactory]: https://docs.djangoproject.com/en/dev/topics/testing/advanced/#django.test.client.RequestFactory +[configuration]: #configuration diff --git a/rest_framework/response.py b/rest_framework/response.py index c4b2aaa66..5877c8a3e 100644 --- a/rest_framework/response.py +++ b/rest_framework/response.py @@ -50,7 +50,7 @@ class Response(SimpleTemplateResponse): charset = renderer.charset content_type = self.content_type - if content_type is None and charset is not None and ';' not in media_type: + if content_type is None and charset is not None: content_type = "{0}; charset={1}".format(media_type, charset) elif content_type is None: content_type = media_type diff --git a/rest_framework/settings.py b/rest_framework/settings.py index beb511aca..8fd177d58 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -73,6 +73,13 @@ DEFAULTS = { 'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser', 'UNAUTHENTICATED_TOKEN': None, + # Testing + 'TEST_REQUEST_RENDERER_CLASSES': ( + 'rest_framework.renderers.MultiPartRenderer', + 'rest_framework.renderers.JSONRenderer' + ), + 'TEST_REQUEST_DEFAULT_FORMAT': 'multipart', + # Browser enhancements 'FORM_METHOD_OVERRIDE': '_method', 'FORM_CONTENT_OVERRIDE': '_content', @@ -115,6 +122,7 @@ IMPORT_STRINGS = ( 'DEFAULT_PAGINATION_SERIALIZER_CLASS', 'DEFAULT_FILTER_BACKENDS', 'FILTER_BACKEND', + 'TEST_REQUEST_RENDERER_CLASSES', 'UNAUTHENTICATED_USER', 'UNAUTHENTICATED_TOKEN', ) diff --git a/rest_framework/test.py b/rest_framework/test.py index 2f658a56c..29d017ee4 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -1,22 +1,31 @@ # -- coding: utf-8 -- -# Note that we use `DjangoRequestFactory` and `DjangoClient` names in order +# Note that we import as `DjangoRequestFactory` and `DjangoClient` in order # to make it harder for the user to import the wrong thing without realizing. from __future__ import unicode_literals from django.conf import settings from django.test.client import Client as DjangoClient from django.test.client import ClientHandler +from rest_framework.settings import api_settings from rest_framework.compat import RequestFactory as DjangoRequestFactory from rest_framework.compat import force_bytes_or_smart_bytes, six -from rest_framework.renderers import JSONRenderer, MultiPartRenderer + + +def force_authenticate(request, user=None, token=None): + request._force_auth_user = user + request._force_auth_token = token class APIRequestFactory(DjangoRequestFactory): - renderer_classes = { - 'json': JSONRenderer, - 'multipart': MultiPartRenderer - } - default_format = 'multipart' + renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES + default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT + + def __init__(self, enforce_csrf_checks=False, **defaults): + self.enforce_csrf_checks = enforce_csrf_checks + self.renderer_classes = {} + for cls in self.renderer_classes_list: + self.renderer_classes[cls.format] = cls + super(APIRequestFactory, self).__init__(**defaults) def _encode_data(self, data, format=None, content_type=None): """ @@ -35,18 +44,24 @@ class APIRequestFactory(DjangoRequestFactory): ret = force_bytes_or_smart_bytes(data, settings.DEFAULT_CHARSET) else: - # Use format and render the data into a bytestring format = format or self.default_format + + assert format in self.renderer_classes, ("Invalid format '{0}'. " + "Available formats are {1}. Set TEST_REQUEST_RENDERER_CLASSES " + "to enable extra request formats.".format( + format, + ', '.join(["'" + fmt + "'" for fmt in self.renderer_classes.keys()]) + ) + ) + + # Use format and render the data into a bytestring renderer = self.renderer_classes[format]() ret = renderer.render(data) # Determine the content-type header from the renderer - if ';' in renderer.media_type: - content_type = renderer.media_type - else: - content_type = "{0}; charset={1}".format( - renderer.media_type, renderer.charset - ) + content_type = "{0}; charset={1}".format( + renderer.media_type, renderer.charset + ) # Coerce text to bytes if required. if isinstance(ret, six.text_type): @@ -74,6 +89,11 @@ class APIRequestFactory(DjangoRequestFactory): data, content_type = self._encode_data(data, format, content_type) return self.generic('OPTIONS', path, data, content_type, **extra) + def request(self, **kwargs): + request = super(APIRequestFactory, self).request(**kwargs) + request._dont_enforce_csrf_checks = not self.enforce_csrf_checks + return request + class ForceAuthClientHandler(ClientHandler): """ @@ -82,25 +102,21 @@ class ForceAuthClientHandler(ClientHandler): """ def __init__(self, *args, **kwargs): - self._force_auth_user = None - self._force_auth_token = None + self._force_user = None + self._force_token = None super(ForceAuthClientHandler, self).__init__(*args, **kwargs) def get_response(self, request): # This is the simplest place we can hook into to patch the # request object. - request._force_auth_user = self._force_auth_user - request._force_auth_token = self._force_auth_token + force_authenticate(request, self._force_user, self._force_token) return super(ForceAuthClientHandler, self).get_response(request) class APIClient(APIRequestFactory, DjangoClient): def __init__(self, enforce_csrf_checks=False, **defaults): - # Note that our super call skips Client.__init__ - # since we don't need to instantiate a regular ClientHandler - super(DjangoClient, self).__init__(**defaults) + super(APIClient, self).__init__(**defaults) self.handler = ForceAuthClientHandler(enforce_csrf_checks) - self.exc_info = None self._credentials = {} def credentials(self, **kwargs): @@ -114,10 +130,10 @@ class APIClient(APIRequestFactory, DjangoClient): Forcibly authenticates outgoing requests with the given user and/or token. """ - self.handler._force_auth_user = user - self.handler._force_auth_token = token + self.handler._force_user = user + self.handler._force_token = token - def request(self, **request): + def request(self, **kwargs): # Ensure that any credentials set get added to every request. - request.update(self._credentials) - return super(APIClient, self).request(**request) + kwargs.update(self._credentials) + return super(APIClient, self).request(**kwargs) diff --git a/rest_framework/tests/test_testing.py b/rest_framework/tests/test_testing.py index 3706f38c2..49d45fc29 100644 --- a/rest_framework/tests/test_testing.py +++ b/rest_framework/tests/test_testing.py @@ -6,11 +6,11 @@ from django.test import TestCase from rest_framework.compat import patterns, url from rest_framework.decorators import api_view from rest_framework.response import Response -from rest_framework.test import APIClient +from rest_framework.test import APIClient, APIRequestFactory, force_authenticate @api_view(['GET', 'POST']) -def mirror(request): +def view(request): return Response({ 'auth': request.META.get('HTTP_AUTHORIZATION', b''), 'user': request.user.username @@ -18,11 +18,11 @@ def mirror(request): urlpatterns = patterns('', - url(r'^view/$', mirror), + url(r'^view/$', view), ) -class CheckTestClient(TestCase): +class TestAPITestClient(TestCase): urls = 'rest_framework.tests.test_testing' def setUp(self): @@ -66,3 +66,50 @@ class CheckTestClient(TestCase): expected = {'detail': 'CSRF Failed: CSRF cookie not set.'} self.assertEqual(response.status_code, 403) self.assertEqual(response.data, expected) + + +class TestAPIRequestFactory(TestCase): + def test_csrf_exempt_by_default(self): + """ + By default, the test client is CSRF exempt. + """ + user = User.objects.create_user('example', 'example@example.com', 'password') + factory = APIRequestFactory() + request = factory.post('/view/') + request.user = user + response = view(request) + self.assertEqual(response.status_code, 200) + + def test_explicitly_enforce_csrf_checks(self): + """ + The test client can enforce CSRF checks. + """ + user = User.objects.create_user('example', 'example@example.com', 'password') + factory = APIRequestFactory(enforce_csrf_checks=True) + request = factory.post('/view/') + request.user = user + response = view(request) + expected = {'detail': 'CSRF Failed: CSRF cookie not set.'} + self.assertEqual(response.status_code, 403) + self.assertEqual(response.data, expected) + + def test_invalid_format(self): + """ + Attempting to use a format that is not configured will raise an + assertion error. + """ + factory = APIRequestFactory() + self.assertRaises(AssertionError, factory.post, + path='/view/', data={'example': 1}, format='xml' + ) + + def test_force_authenticate(self): + """ + Setting `force_authenticate()` forcibly authenticates the request. + """ + user = User.objects.create_user('example', 'example@example.com') + factory = APIRequestFactory() + request = factory.get('/view') + force_authenticate(request, user=user) + response = view(request) + self.assertEqual(response.data['user'], 'example') From 5427d90fa48398684948067530cd8083f785c248 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 1 Jul 2013 17:22:11 +0100 Subject: [PATCH 11/12] Remove console style from code blocks --- docs/api-guide/testing.md | 32 ++++++++++++++++---------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/docs/api-guide/testing.md b/docs/api-guide/testing.md index 293ee7019..a48aff00e 100644 --- a/docs/api-guide/testing.md +++ b/docs/api-guide/testing.md @@ -16,7 +16,7 @@ Extends [Django's existing `RequestFactory` class][requestfactory]. The `APIRequestFactory` class supports an almost identical API to Django's standard `RequestFactory` class. This means the that standard `.get()`, `.post()`, `.put()`, `.patch()`, `.delete()`, `.head()` and `.options()` methods are all available. -### Using the format arguments +#### Using the format arguments Methods which create a request body, such as `post`, `put` and `patch`, include a `format` argument, which make it easy to generate requests using a content type other than multipart form data. For example: @@ -31,7 +31,7 @@ If you need to explictly encode the request body, you can do so by explicitly se request = factory.post('/notes/', json.dumps({'title': 'new idea'}), content_type='application/json') -### PUT and PATCH with form data +#### PUT and PATCH with form data One difference worth noting between Django's `RequestFactory` and REST framework's `APIRequestFactory` is that multipart form data will be encoded for methods other than just `.post()`. @@ -105,50 +105,50 @@ To support a wider set of request formats, or change the default format, [see th ## Authenticating -### .login(**kwargs) +#### .login(**kwargs) The `login` method functions exactly as it does with Django's regular `Client` class. This allows you to authenticate requests against any views which include `SessionAuthentication`. # Make all requests in the context of a logged in session. - >>> client = APIClient() - >>> client.login(username='lauren', password='secret') + client = APIClient() + client.login(username='lauren', password='secret') To logout, call the `logout` method as usual. # Log out - >>> client.logout() + client.logout() The `login` method is appropriate for testing APIs that use session authentication, for example web sites which include AJAX interaction with the API. -### .credentials(**kwargs) +#### .credentials(**kwargs) The `credentials` method can be used to set headers that will then be included on all subsequent requests by the test client. # Include an appropriate `Authorization:` header on all requests. - >>> token = Token.objects.get(username='lauren') - >>> client = APIClient() - >>> client.credentials(HTTP_AUTHORIZATION='Token ' + token.key) + token = Token.objects.get(username='lauren') + client = APIClient() + client.credentials(HTTP_AUTHORIZATION='Token ' + token.key) Note that calling `credentials` a second time overwrites any existing credentials. You can unset any existing credentials by calling the method with no arguments. # Stop including any credentials - >>> client.credentials() + client.credentials() The `credentials` method is appropriate for testing APIs that require authentication headers, such as basic authentication, OAuth1a and OAuth2 authentication, and simple token authentication schemes. -### .force_authenticate(user=None, token=None) +#### .force_authenticate(user=None, token=None) Sometimes you may want to bypass authentication, and simple force all requests by the test client to be automatically treated as authenticated. This can be a useful shortcut if you're testing the API but don't want to have to construct valid authentication credentials in order to make test requests. - >>> user = User.objects.get(username='lauren') - >>> client = APIClient() - >>> client.force_authenticate(user=user) + user = User.objects.get(username='lauren') + client = APIClient() + client.force_authenticate(user=user) To unauthenticate subsequent requests, call `force_authenticate` setting the user and/or token to `None`. - >>> client.force_authenticate(user=None) + client.force_authenticate(user=None) ## CSRF validation From 7398464b397d37dbcfda13eb6142039fed3e9a19 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 3 Jul 2013 13:08:43 +0100 Subject: [PATCH 12/12] Tweak docs --- docs/api-guide/testing.md | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/docs/api-guide/testing.md b/docs/api-guide/testing.md index a48aff00e..aba9283e0 100644 --- a/docs/api-guide/testing.md +++ b/docs/api-guide/testing.md @@ -16,10 +16,15 @@ Extends [Django's existing `RequestFactory` class][requestfactory]. The `APIRequestFactory` class supports an almost identical API to Django's standard `RequestFactory` class. This means the that standard `.get()`, `.post()`, `.put()`, `.patch()`, `.delete()`, `.head()` and `.options()` methods are all available. -#### Using the format arguments + # Using the standard RequestFactory API to create a form POST request + factory = APIRequestFactory() + request = factory.post('/notes/', {'title': 'new idea'}) + +#### Using the `format` argument Methods which create a request body, such as `post`, `put` and `patch`, include a `format` argument, which make it easy to generate requests using a content type other than multipart form data. For example: + # Create a JSON POST request factory = APIRequestFactory() request = factory.post('/notes/', {'title': 'new idea'}, format='json') @@ -27,7 +32,9 @@ By default the available formats are `'multipart'` and `'json'`. For compatibil To support a wider set of request formats, or change the default format, [see the configuration section][configuration]. -If you need to explictly encode the request body, you can do so by explicitly setting the `content_type` flag. For example: +#### Explicitly encoding the request body + +If you need to explictly encode the request body, you can do so by setting the `content_type` flag. For example: request = factory.post('/notes/', json.dumps({'title': 'new idea'}), content_type='application/json') @@ -40,7 +47,7 @@ For example, using `APIRequestFactory`, you can make a form PUT request like so: factory = APIRequestFactory() request = factory.put('/notes/547/', {'title': 'remember to email dave'}) -Using Django's `Factory`, you'd need to explicitly encode the data yourself: +Using Django's `RequestFactory`, you'd need to explicitly encode the data yourself: factory = RequestFactory() data = {'title': 'remember to email dave'}