diff --git a/rest_framework/tests/authentication.py b/rest_framework/tests/authentication.py index e86041bc4..3713e690d 100644 --- a/rest_framework/tests/authentication.py +++ b/rest_framework/tests/authentication.py @@ -1,12 +1,13 @@ from django.contrib.auth.models import User from django.http import HttpResponse -from django.test import Client, TestCase +from django.test import TestCase from rest_framework import permissions from rest_framework.authtoken.models import Token from rest_framework.authentication import TokenAuthentication from rest_framework.compat import patterns from rest_framework.views import APIView +from rest_framework.tests.utils import Client import json import base64 @@ -18,6 +19,9 @@ class MockView(APIView): def post(self, request): return HttpResponse({'a': 1, 'b': 2, 'c': 3}) + def patch(self, request): + return HttpResponse({'a': 1, 'b': 2, 'c': 3}) + def put(self, request): return HttpResponse({'a': 1, 'b': 2, 'c': 3}) @@ -102,6 +106,14 @@ class SessionAuthTests(TestCase): response = self.non_csrf_client.put('/', {'example': 'example'}) self.assertEqual(response.status_code, 200) + def test_patch_form_session_auth_passing(self): + """ + Ensure PATCHting form over session authentication with logged in user and CSRF token passes. + """ + self.non_csrf_client.login(username=self.username, password=self.password) + response = self.non_csrf_client.patch('/', {'example': 'example'}) + self.assertEqual(response.status_code, 200) + def test_post_form_session_auth_failing(self): """ Ensure POSTing form over session authentication without logged in user fails. diff --git a/rest_framework/tests/renderers.py b/rest_framework/tests/renderers.py index c1b4e624b..02ef0a3ac 100644 --- a/rest_framework/tests/renderers.py +++ b/rest_framework/tests/renderers.py @@ -111,6 +111,9 @@ class POSTDeniedView(APIView): def put(self, request): return Response() + def patch(self, request): + return Response() + class DocumentingRendererTests(TestCase): def test_only_permitted_forms_are_displayed(self): @@ -119,6 +122,7 @@ class DocumentingRendererTests(TestCase): response = view(request).render() self.assertNotContains(response, '>POST<') self.assertContains(response, '>PUT<') + self.assertContains(response, '>PATCH<') class RendererEndToEndTests(TestCase): diff --git a/rest_framework/tests/utils.py b/rest_framework/tests/utils.py index 3906adb9a..0cff55234 100644 --- a/rest_framework/tests/utils.py +++ b/rest_framework/tests/utils.py @@ -1,9 +1,9 @@ -from django.test.client import RequestFactory, FakePayload +from django.test.client import Client as _Client, RequestFactory as _RequestFactory, FakePayload from django.test.client import MULTIPART_CONTENT from urlparse import urlparse -class RequestFactory(RequestFactory): +class RequestFactory(_RequestFactory): def __init__(self, **defaults): super(RequestFactory, self).__init__(**defaults) @@ -25,3 +25,15 @@ class RequestFactory(RequestFactory): } r.update(extra) return self.request(**r) + + +class Client(_Client, RequestFactory): + def patch(self, path, data={}, content_type=MULTIPART_CONTENT, + follow=False, **extra): + """ + Send a resource to the server using PATCH. + """ + response = super(Client, self).patch(path, data=data, content_type=content_type, **extra) + if follow: + response = self._handle_redirects(response, **extra) + return response