From 08c7853655252cb9de0ade24eddfc9762a01cee7 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 1 Aug 2016 14:15:35 +0100 Subject: [PATCH 01/27] Start test case --- rest_framework/test.py | 5 +++++ tests/test_requests_client.py | 24 ++++++++++++++++++++++++ 2 files changed, 29 insertions(+) create mode 100644 tests/test_requests_client.py diff --git a/rest_framework/test.py b/rest_framework/test.py index 3ba4059a9..fb08c4a73 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -12,6 +12,7 @@ from django.test.client import ClientHandler from django.utils import six from django.utils.encoding import force_bytes from django.utils.http import urlencode +from requests import Session from rest_framework.settings import api_settings @@ -221,6 +222,10 @@ class APITransactionTestCase(testcases.TransactionTestCase): class APITestCase(testcases.TestCase): client_class = APIClient + def _pre_setup(self): + super(APITestCase, self)._pre_setup() + self.requests = Session() + class APISimpleTestCase(testcases.SimpleTestCase): client_class = APIClient diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py new file mode 100644 index 000000000..a36349a3f --- /dev/null +++ b/tests/test_requests_client.py @@ -0,0 +1,24 @@ +from __future__ import unicode_literals + +from django.conf.urls import url +from django.test import override_settings + +from rest_framework.response import Response +from rest_framework.test import APITestCase +from rest_framework.views import APIView + + +class Root(APIView): + def get(self, request): + return Response({'hello': 'world'}) + + +urlpatterns = [ + url(r'^$', Root.as_view()), +] + + +@override_settings(ROOT_URLCONF='tests.test_requests_client') +class RequestsClientTests(APITestCase): + def test_get_root(self): + print self.requests.get('http://example.com') From 3d1fff3f26835612be17b6624d766a56520880ec Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 15 Aug 2016 15:42:35 +0100 Subject: [PATCH 02/27] Added 'requests' test client --- rest_framework/request.py | 2 +- rest_framework/test.py | 86 +++++++++++++++++++++++- tests/test_requests_client.py | 119 +++++++++++++++++++++++++++++++++- 3 files changed, 202 insertions(+), 5 deletions(-) diff --git a/rest_framework/request.py b/rest_framework/request.py index aafafcb32..f5738bfd5 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -373,7 +373,7 @@ class Request(object): if not _hasattr(self, '_data'): self._load_data_and_files() if is_form_media_type(self.content_type): - return self.data + return self._data return QueryDict('', encoding=self._request._encoding) @property diff --git a/rest_framework/test.py b/rest_framework/test.py index fb08c4a73..1fd530a0c 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -4,7 +4,10 @@ # to make it harder for the user to import the wrong thing without realizing. from __future__ import unicode_literals +import io + from django.conf import settings +from django.core.handlers.wsgi import WSGIHandler from django.test import testcases from django.test.client import Client as DjangoClient from django.test.client import RequestFactory as DjangoRequestFactory @@ -13,6 +16,10 @@ from django.utils import six from django.utils.encoding import force_bytes from django.utils.http import urlencode from requests import Session +from requests.adapters import BaseAdapter +from requests.models import Response +from requests.structures import CaseInsensitiveDict +from requests.utils import get_encoding_from_headers from rest_framework.settings import api_settings @@ -22,6 +29,83 @@ def force_authenticate(request, user=None, token=None): request._force_auth_token = token +class DjangoTestAdapter(BaseAdapter): + """ + A transport adaptor for `requests`, that makes requests via the + Django WSGI app, rather than making actual HTTP requests ovet the network. + """ + def __init__(self): + self.app = WSGIHandler() + self.factory = DjangoRequestFactory() + + def get_environ(self, request): + """ + Given a `requests.PreparedRequest` instance, return a WSGI environ dict. + """ + method = request.method + url = request.url + kwargs = {} + + # Set request content, if any exists. + if request.body is not None: + kwargs['data'] = request.body + if 'content-type' in request.headers: + kwargs['content_type'] = request.headers['content-type'] + + # Set request headers. + for key, value in request.headers.items(): + key = key.upper() + if key in ('CONNECTION', 'CONTENT_LENGTH', 'CONTENT-TYPE'): + continue + kwargs['HTTP_%s' % key] = value + + return self.factory.generic(method, url, **kwargs).environ + + def send(self, request, *args, **kwargs): + """ + Make an outgoing request to the Django WSGI application. + """ + response = Response() + + def start_response(status, headers): + status_code, _, reason_phrase = status.partition(' ') + response.status_code = int(status_code) + response.reason = reason_phrase + response.headers = CaseInsensitiveDict(headers) + response.encoding = get_encoding_from_headers(response.headers) + + environ = self.get_environ(request) + raw_bytes = self.app(environ, start_response) + + response.request = request + response.url = request.url + response.raw = io.BytesIO(b''.join(raw_bytes)) + + return response + + def close(self): + pass + + +class DjangoTestSession(Session): + def __init__(self, *args, **kwargs): + super(DjangoTestSession, self).__init__(*args, **kwargs) + + adapter = DjangoTestAdapter() + hostnames = list(settings.ALLOWED_HOSTS) + ['testserver'] + + for hostname in hostnames: + if hostname == '*': + hostname = '' + self.mount('http://%s' % hostname, adapter) + self.mount('https://%s' % hostname, adapter) + + def request(self, method, url, *args, **kwargs): + if ':' not in url: + url = 'http://testserver/' + url.lstrip('/') + return super(DjangoTestSession, self).request(method, url, *args, **kwargs) + + class APIRequestFactory(DjangoRequestFactory): renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT @@ -224,7 +308,7 @@ class APITestCase(testcases.TestCase): def _pre_setup(self): super(APITestCase, self)._pre_setup() - self.requests = Session() + self.requests = DjangoTestSession() class APISimpleTestCase(testcases.SimpleTestCase): diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py index a36349a3f..0687dd92e 100644 --- a/tests/test_requests_client.py +++ b/tests/test_requests_client.py @@ -10,15 +10,128 @@ from rest_framework.views import APIView class Root(APIView): def get(self, request): - return Response({'hello': 'world'}) + return Response({ + 'method': request.method, + 'query_params': request.query_params, + }) + + def post(self, request): + files = { + key: (value.name, value.read()) + for key, value in request.FILES.items() + } + post = request.POST + json = None + if request.META.get('CONTENT_TYPE') == 'application/json': + json = request.data + + return Response({ + 'method': request.method, + 'query_params': request.query_params, + 'POST': post, + 'FILES': files, + 'JSON': json + }) + + +class Headers(APIView): + def get(self, request): + headers = { + key[5:]: value + for key, value in request.META.items() + if key.startswith('HTTP_') + } + return Response({ + 'method': request.method, + 'headers': headers + }) urlpatterns = [ url(r'^$', Root.as_view()), + url(r'^headers/$', Headers.as_view()), ] @override_settings(ROOT_URLCONF='tests.test_requests_client') class RequestsClientTests(APITestCase): - def test_get_root(self): - print self.requests.get('http://example.com') + def test_get_request(self): + response = self.requests.get('/') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'GET', + 'query_params': {} + } + assert response.json() == expected + + def test_get_request_query_params_in_url(self): + response = self.requests.get('/?key=value') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'GET', + 'query_params': {'key': 'value'} + } + assert response.json() == expected + + def test_get_request_query_params_by_kwarg(self): + response = self.requests.get('/', params={'key': 'value'}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'GET', + 'query_params': {'key': 'value'} + } + assert response.json() == expected + + def test_get_with_headers(self): + response = self.requests.get('/headers/', headers={'User-Agent': 'example'}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + headers = response.json()['headers'] + assert headers['USER-AGENT'] == 'example' + + def test_post_form_request(self): + response = self.requests.post('/', data={'key': 'value'}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'POST', + 'query_params': {}, + 'POST': {'key': 'value'}, + 'FILES': {}, + 'JSON': None + } + assert response.json() == expected + + def test_post_json_request(self): + response = self.requests.post('/', json={'key': 'value'}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'POST', + 'query_params': {}, + 'POST': {}, + 'FILES': {}, + 'JSON': {'key': 'value'} + } + assert response.json() == expected + + def test_post_multipart_request(self): + files = { + 'file': ('report.csv', 'some,data,to,send\nanother,row,to,send\n') + } + response = self.requests.post('/', files=files) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'method': 'POST', + 'query_params': {}, + 'FILES': {'file': ['report.csv', 'some,data,to,send\nanother,row,to,send\n']}, + 'POST': {}, + 'JSON': None + } + assert response.json() == expected + + # cookies/session auth From e76ca6eb8838148ccb1c25a2a8a735a42f644d99 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Mon, 15 Aug 2016 16:06:04 +0100 Subject: [PATCH 03/27] Address typos --- rest_framework/test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/rest_framework/test.py b/rest_framework/test.py index 1fd530a0c..e1d8eff82 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -31,8 +31,8 @@ def force_authenticate(request, user=None, token=None): class DjangoTestAdapter(BaseAdapter): """ - A transport adaptor for `requests`, that makes requests via the - Django WSGI app, rather than making actual HTTP requests ovet the network. + A transport adapter for `requests`, that makes requests via the + Django WSGI app, rather than making actual HTTP requests over the network. """ def __init__(self): self.app = WSGIHandler() @@ -55,9 +55,9 @@ class DjangoTestAdapter(BaseAdapter): # Set request headers. for key, value in request.headers.items(): key = key.upper() - if key in ('CONNECTION', 'CONTENT_LENGTH', 'CONTENT-TYPE'): + if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'): continue - kwargs['HTTP_%s' % key] = value + kwargs['HTTP_%s' % key.replace('-', '_')] = value return self.factory.generic(method, url, **kwargs).environ From 6ede654315e415362f4c7c8e38a3f641039bfc46 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 17 Aug 2016 12:11:01 +0100 Subject: [PATCH 04/27] Graceful fallback if requests is not installed. --- rest_framework/compat.py | 7 ++ rest_framework/test.py | 133 +++++++++++++++++----------------- tests/test_requests_client.py | 6 +- 3 files changed, 78 insertions(+), 68 deletions(-) diff --git a/rest_framework/compat.py b/rest_framework/compat.py index cee430a84..bda346fa8 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -178,6 +178,13 @@ except (ImportError, SyntaxError): uritemplate = None +# requests is optional +try: + import requests +except ImportError: + requests = None + + # Django-guardian is optional. Import only if guardian is in INSTALLED_APPS # Fixes (#1712). We keep the try/except for the test suite. guardian = None diff --git a/rest_framework/test.py b/rest_framework/test.py index e1d8eff82..eba4b96cf 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -15,12 +15,8 @@ from django.test.client import ClientHandler from django.utils import six from django.utils.encoding import force_bytes from django.utils.http import urlencode -from requests import Session -from requests.adapters import BaseAdapter -from requests.models import Response -from requests.structures import CaseInsensitiveDict -from requests.utils import get_encoding_from_headers +from rest_framework.compat import requests from rest_framework.settings import api_settings @@ -29,81 +25,81 @@ def force_authenticate(request, user=None, token=None): request._force_auth_token = token -class DjangoTestAdapter(BaseAdapter): - """ - A transport adapter for `requests`, that makes requests via the - Django WSGI app, rather than making actual HTTP requests over the network. - """ - def __init__(self): - self.app = WSGIHandler() - self.factory = DjangoRequestFactory() - - def get_environ(self, request): +if requests is not None: + class DjangoTestAdapter(requests.adapters.BaseAdapter): """ - Given a `requests.PreparedRequest` instance, return a WSGI environ dict. + A transport adapter for `requests`, that makes requests via the + Django WSGI app, rather than making actual HTTP requests over the network. """ - method = request.method - url = request.url - kwargs = {} + def __init__(self): + self.app = WSGIHandler() + self.factory = DjangoRequestFactory() - # Set request content, if any exists. - if request.body is not None: - kwargs['data'] = request.body - if 'content-type' in request.headers: - kwargs['content_type'] = request.headers['content-type'] + def get_environ(self, request): + """ + Given a `requests.PreparedRequest` instance, return a WSGI environ dict. + """ + method = request.method + url = request.url + kwargs = {} - # Set request headers. - for key, value in request.headers.items(): - key = key.upper() - if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'): - continue - kwargs['HTTP_%s' % key.replace('-', '_')] = value + # Set request content, if any exists. + if request.body is not None: + kwargs['data'] = request.body + if 'content-type' in request.headers: + kwargs['content_type'] = request.headers['content-type'] - return self.factory.generic(method, url, **kwargs).environ + # Set request headers. + for key, value in request.headers.items(): + key = key.upper() + if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'): + continue + kwargs['HTTP_%s' % key.replace('-', '_')] = value - def send(self, request, *args, **kwargs): - """ - Make an outgoing request to the Django WSGI application. - """ - response = Response() + return self.factory.generic(method, url, **kwargs).environ - def start_response(status, headers): - status_code, _, reason_phrase = status.partition(' ') - response.status_code = int(status_code) - response.reason = reason_phrase - response.headers = CaseInsensitiveDict(headers) - response.encoding = get_encoding_from_headers(response.headers) + def send(self, request, *args, **kwargs): + """ + Make an outgoing request to the Django WSGI application. + """ + response = requests.models.Response() - environ = self.get_environ(request) - raw_bytes = self.app(environ, start_response) + def start_response(status, headers): + status_code, _, reason_phrase = status.partition(' ') + response.status_code = int(status_code) + response.reason = reason_phrase + response.headers = requests.structures.CaseInsensitiveDict(headers) + response.encoding = requests.utils.get_encoding_from_headers(response.headers) - response.request = request - response.url = request.url - response.raw = io.BytesIO(b''.join(raw_bytes)) + environ = self.get_environ(request) + raw_bytes = self.app(environ, start_response) - return response + response.request = request + response.url = request.url + response.raw = io.BytesIO(b''.join(raw_bytes)) - def close(self): - pass + return response + def close(self): + pass -class DjangoTestSession(Session): - def __init__(self, *args, **kwargs): - super(DjangoTestSession, self).__init__(*args, **kwargs) + class DjangoTestSession(requests.Session): + def __init__(self, *args, **kwargs): + super(DjangoTestSession, self).__init__(*args, **kwargs) - adapter = DjangoTestAdapter() - hostnames = list(settings.ALLOWED_HOSTS) + ['testserver'] + adapter = DjangoTestAdapter() + hostnames = list(settings.ALLOWED_HOSTS) + ['testserver'] - for hostname in hostnames: - if hostname == '*': - hostname = '' - self.mount('http://%s' % hostname, adapter) - self.mount('https://%s' % hostname, adapter) + for hostname in hostnames: + if hostname == '*': + hostname = '' + self.mount('http://%s' % hostname, adapter) + self.mount('https://%s' % hostname, adapter) - def request(self, method, url, *args, **kwargs): - if ':' not in url: - url = 'http://testserver/' + url.lstrip('/') - return super(DjangoTestSession, self).request(method, url, *args, **kwargs) + def request(self, method, url, *args, **kwargs): + if ':' not in url: + url = 'http://testserver/' + url.lstrip('/') + return super(DjangoTestSession, self).request(method, url, *args, **kwargs) class APIRequestFactory(DjangoRequestFactory): @@ -306,9 +302,12 @@ class APITransactionTestCase(testcases.TransactionTestCase): class APITestCase(testcases.TestCase): client_class = APIClient - def _pre_setup(self): - super(APITestCase, self)._pre_setup() - self.requests = DjangoTestSession() + @property + def requests(self): + if not hasattr(self, '_requests'): + assert requests is not None, 'requests must be installed' + self._requests = DjangoTestSession() + return self._requests class APISimpleTestCase(testcases.SimpleTestCase): diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py index 0687dd92e..24e29d3b8 100644 --- a/tests/test_requests_client.py +++ b/tests/test_requests_client.py @@ -1,8 +1,11 @@ from __future__ import unicode_literals +import unittest + from django.conf.urls import url from django.test import override_settings +from rest_framework.compat import requests from rest_framework.response import Response from rest_framework.test import APITestCase from rest_framework.views import APIView @@ -37,7 +40,7 @@ class Root(APIView): class Headers(APIView): def get(self, request): headers = { - key[5:]: value + key[5:].replace('_', '-'): value for key, value in request.META.items() if key.startswith('HTTP_') } @@ -53,6 +56,7 @@ urlpatterns = [ ] +@unittest.skipUnless(requests, 'requests not installed') @override_settings(ROOT_URLCONF='tests.test_requests_client') class RequestsClientTests(APITestCase): def test_get_request(self): From 049a39e060ab8bbd028ffaa0fb2f5104a71f400d Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 17 Aug 2016 15:43:12 +0100 Subject: [PATCH 05/27] Add cookie support --- rest_framework/test.py | 41 +++++++++++++------ tests/test_requests_client.py | 76 ++++++++++++++++++++++++++++++++++- 2 files changed, 102 insertions(+), 15 deletions(-) diff --git a/rest_framework/test.py b/rest_framework/test.py index eba4b96cf..bc8ecc5db 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -26,7 +26,7 @@ def force_authenticate(request, user=None, token=None): if requests is not None: - class DjangoTestAdapter(requests.adapters.BaseAdapter): + class DjangoTestAdapter(requests.adapters.HTTPAdapter): """ A transport adapter for `requests`, that makes requests via the Django WSGI app, rather than making actual HTTP requests over the network. @@ -62,23 +62,38 @@ if requests is not None: """ Make an outgoing request to the Django WSGI application. """ - response = requests.models.Response() + raw_kwargs = {} - def start_response(status, headers): - status_code, _, reason_phrase = status.partition(' ') - response.status_code = int(status_code) - response.reason = reason_phrase - response.headers = requests.structures.CaseInsensitiveDict(headers) - response.encoding = requests.utils.get_encoding_from_headers(response.headers) + def start_response(wsgi_status, wsgi_headers): + class MockOriginalResponse(object): + def __init__(self, headers): + self.msg = requests.packages.urllib3._collections.HTTPHeaderDict(headers) + self.closed = False + def isclosed(self): + return self.closed + + def close(self): + self.closed = True + + status, _, reason = wsgi_status.partition(' ') + raw_kwargs['status'] = int(status) + raw_kwargs['reason'] = reason + raw_kwargs['headers'] = wsgi_headers + raw_kwargs['version'] = 11 + raw_kwargs['preload_content'] = False + raw_kwargs['original_response'] = MockOriginalResponse(wsgi_headers) + + # Make the outgoing request via WSGI. environ = self.get_environ(request) - raw_bytes = self.app(environ, start_response) + wsgi_response = self.app(environ, start_response) - response.request = request - response.url = request.url - response.raw = io.BytesIO(b''.join(raw_bytes)) + # Build the underlying urllib3.HTTPResponse + raw_kwargs['body'] = io.BytesIO(b''.join(wsgi_response)) + raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs) - return response + # Build the requests.Response + return self.build_response(request, raw) def close(self): pass diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py index 24e29d3b8..10158efa7 100644 --- a/tests/test_requests_client.py +++ b/tests/test_requests_client.py @@ -37,7 +37,7 @@ class Root(APIView): }) -class Headers(APIView): +class HeadersView(APIView): def get(self, request): headers = { key[5:].replace('_', '-'): value @@ -50,9 +50,32 @@ class Headers(APIView): }) +class SessionView(APIView): + def get(self, request): + return Response({ + key: value for key, value in request.session.items() + }) + + def post(self, request): + for key, value in request.data.items(): + request.session[key] = value + return Response({ + key: value for key, value in request.session.items() + }) + + +class CookiesView(APIView): + def get(self, request): + return Response({ + key: value for key, value in request.COOKIES.items() + }) + + urlpatterns = [ url(r'^$', Root.as_view()), - url(r'^headers/$', Headers.as_view()), + url(r'^headers/$', HeadersView.as_view()), + url(r'^session/$', SessionView.as_view()), + url(r'^cookies/$', CookiesView.as_view()), ] @@ -138,4 +161,53 @@ class RequestsClientTests(APITestCase): } assert response.json() == expected + def test_session(self): + response = self.requests.get('/session/') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = {} + assert response.json() == expected + + response = self.requests.post('/session/', json={'example': 'abc'}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = {'example': 'abc'} + assert response.json() == expected + + response = self.requests.get('/session/') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = {'example': 'abc'} + assert response.json() == expected + + def test_cookies(self): + """ + Test for explicitly setting a cookie. + """ + my_cookie = { + "version": 0, + "name": 'COOKIE_NAME', + "value": 'COOKIE_VALUE', + "port": None, + # "port_specified":False, + "domain": 'testserver.local', + # "domain_specified":False, + # "domain_initial_dot":False, + "path": '/', + # "path_specified":True, + "secure": False, + "expires": None, + "discard": True, + "comment": None, + "comment_url": None, + "rest": {}, + "rfc2109": False + } + self.requests.cookies.set(**my_cookie) + response = self.requests.get('/cookies/') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = {'COOKIE_NAME': 'COOKIE_VALUE'} + assert response.json() == expected + # cookies/session auth From 64e19c738fce463df6fafd8f377ecc702f068813 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 17 Aug 2016 17:54:03 +0100 Subject: [PATCH 06/27] Tests for auth and CSRF --- tests/test_requests_client.py | 87 ++++++++++++++++++++++------------- 1 file changed, 56 insertions(+), 31 deletions(-) diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py index 10158efa7..aa99a71da 100644 --- a/tests/test_requests_client.py +++ b/tests/test_requests_client.py @@ -3,9 +3,14 @@ from __future__ import unicode_literals import unittest from django.conf.urls import url +from django.contrib.auth import authenticate, login +from django.contrib.auth.models import User +from django.shortcuts import redirect from django.test import override_settings +from django.utils.decorators import method_decorator +from django.views.decorators.csrf import csrf_protect, ensure_csrf_cookie -from rest_framework.compat import requests +from rest_framework.compat import is_authenticated, requests from rest_framework.response import Response from rest_framework.test import APITestCase from rest_framework.views import APIView @@ -64,18 +69,33 @@ class SessionView(APIView): }) -class CookiesView(APIView): +class AuthView(APIView): + @method_decorator(ensure_csrf_cookie) def get(self, request): + if is_authenticated(request.user): + username = request.user.username + else: + username = None return Response({ - key: value for key, value in request.COOKIES.items() + 'username': username }) + @method_decorator(csrf_protect) + def post(self, request): + username = request.data['username'] + password = request.data['password'] + user = authenticate(username=username, password=password) + if user is None: + return Response({'error': 'incorrect credentials'}) + login(request, user) + return redirect('/auth/') + urlpatterns = [ url(r'^$', Root.as_view()), url(r'^headers/$', HeadersView.as_view()), url(r'^session/$', SessionView.as_view()), - url(r'^cookies/$', CookiesView.as_view()), + url(r'^auth/$', AuthView.as_view()), ] @@ -180,34 +200,39 @@ class RequestsClientTests(APITestCase): expected = {'example': 'abc'} assert response.json() == expected - def test_cookies(self): - """ - Test for explicitly setting a cookie. - """ - my_cookie = { - "version": 0, - "name": 'COOKIE_NAME', - "value": 'COOKIE_VALUE', - "port": None, - # "port_specified":False, - "domain": 'testserver.local', - # "domain_specified":False, - # "domain_initial_dot":False, - "path": '/', - # "path_specified":True, - "secure": False, - "expires": None, - "discard": True, - "comment": None, - "comment_url": None, - "rest": {}, - "rfc2109": False - } - self.requests.cookies.set(**my_cookie) - response = self.requests.get('/cookies/') + def test_auth(self): + # Confirm session is not authenticated + response = self.requests.get('/auth/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' - expected = {'COOKIE_NAME': 'COOKIE_VALUE'} + expected = { + 'username': None + } + assert response.json() == expected + assert 'csrftoken' in response.cookies + csrftoken = response.cookies['csrftoken'] + + user = User.objects.create(username='tom') + user.set_password('password') + user.save() + + # Perform a login + response = self.requests.post('/auth/', json={ + 'username': 'tom', + 'password': 'password' + }, headers={'X-CSRFToken': csrftoken}) + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'username': 'tom' + } assert response.json() == expected - # cookies/session auth + # Confirm session is authenticated + response = self.requests.get('/auth/') + assert response.status_code == 200 + assert response.headers['Content-Type'] == 'application/json' + expected = { + 'username': 'tom' + } + assert response.json() == expected From da47c345c09eaf4fffca9cfcd18bba7acd844e8b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 17 Aug 2016 18:09:19 +0100 Subject: [PATCH 07/27] Py3 compat --- rest_framework/test.py | 26 +++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/rest_framework/test.py b/rest_framework/test.py index bc8ecc5db..a95d18537 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -26,6 +26,21 @@ def force_authenticate(request, user=None, token=None): if requests is not None: + class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict): + def get_all(self, key): + return self.getheaders(self, key) + + class MockOriginalResponse(object): + def __init__(self, headers): + self.msg = HeaderDict(headers) + self.closed = False + + def isclosed(self): + return self.closed + + def close(self): + self.closed = True + class DjangoTestAdapter(requests.adapters.HTTPAdapter): """ A transport adapter for `requests`, that makes requests via the @@ -65,17 +80,6 @@ if requests is not None: raw_kwargs = {} def start_response(wsgi_status, wsgi_headers): - class MockOriginalResponse(object): - def __init__(self, headers): - self.msg = requests.packages.urllib3._collections.HTTPHeaderDict(headers) - self.closed = False - - def isclosed(self): - return self.closed - - def close(self): - self.closed = True - status, _, reason = wsgi_status.partition(' ') raw_kwargs['status'] = int(status) raw_kwargs['reason'] = reason From 53117698e042691a7045688d41edae9cc118643b Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 17 Aug 2016 18:47:01 +0100 Subject: [PATCH 08/27] py3 compat --- rest_framework/test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rest_framework/test.py b/rest_framework/test.py index a95d18537..bf22ff08d 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -27,7 +27,7 @@ def force_authenticate(request, user=None, token=None): if requests is not None: class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict): - def get_all(self, key): + def get_all(self, key, default): return self.getheaders(self, key) class MockOriginalResponse(object): From 0b3db028a2a7a0a91c3111fc6febbdbcd9cbd6b5 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 17 Aug 2016 18:50:02 +0100 Subject: [PATCH 09/27] py3 compat --- rest_framework/test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rest_framework/test.py b/rest_framework/test.py index bf22ff08d..ded9d5fe9 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -28,7 +28,7 @@ def force_authenticate(request, user=None, token=None): if requests is not None: class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict): def get_all(self, key, default): - return self.getheaders(self, key) + return self.getheaders(key) class MockOriginalResponse(object): def __init__(self, headers): From 0cc3f5008fcd41d1597776c43dc57502ed2a7542 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 18 Aug 2016 15:34:19 +0100 Subject: [PATCH 10/27] Add get_requests_client --- rest_framework/test.py | 12 +++++------- tests/test_requests_client.py | 37 ++++++++++++++++++++++------------- 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/rest_framework/test.py b/rest_framework/test.py index ded9d5fe9..e17c19a43 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -121,6 +121,11 @@ if requests is not None: return super(DjangoTestSession, self).request(method, url, *args, **kwargs) +def get_requests_client(): + assert requests is not None, 'requests must be installed' + return DjangoTestSession() + + class APIRequestFactory(DjangoRequestFactory): renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT @@ -321,13 +326,6 @@ class APITransactionTestCase(testcases.TransactionTestCase): class APITestCase(testcases.TestCase): client_class = APIClient - @property - def requests(self): - if not hasattr(self, '_requests'): - assert requests is not None, 'requests must be installed' - self._requests = DjangoTestSession() - return self._requests - class APISimpleTestCase(testcases.SimpleTestCase): client_class = APIClient diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py index aa99a71da..37bde1092 100644 --- a/tests/test_requests_client.py +++ b/tests/test_requests_client.py @@ -12,7 +12,7 @@ from django.views.decorators.csrf import csrf_protect, ensure_csrf_cookie from rest_framework.compat import is_authenticated, requests from rest_framework.response import Response -from rest_framework.test import APITestCase +from rest_framework.test import APITestCase, get_requests_client from rest_framework.views import APIView @@ -103,7 +103,8 @@ urlpatterns = [ @override_settings(ROOT_URLCONF='tests.test_requests_client') class RequestsClientTests(APITestCase): def test_get_request(self): - response = self.requests.get('/') + client = get_requests_client() + response = client.get('/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -113,7 +114,8 @@ class RequestsClientTests(APITestCase): assert response.json() == expected def test_get_request_query_params_in_url(self): - response = self.requests.get('/?key=value') + client = get_requests_client() + response = client.get('/?key=value') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -123,7 +125,8 @@ class RequestsClientTests(APITestCase): assert response.json() == expected def test_get_request_query_params_by_kwarg(self): - response = self.requests.get('/', params={'key': 'value'}) + client = get_requests_client() + response = client.get('/', params={'key': 'value'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -133,14 +136,16 @@ class RequestsClientTests(APITestCase): assert response.json() == expected def test_get_with_headers(self): - response = self.requests.get('/headers/', headers={'User-Agent': 'example'}) + client = get_requests_client() + response = client.get('/headers/', headers={'User-Agent': 'example'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' headers = response.json()['headers'] assert headers['USER-AGENT'] == 'example' def test_post_form_request(self): - response = self.requests.post('/', data={'key': 'value'}) + client = get_requests_client() + response = client.post('/', data={'key': 'value'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -153,7 +158,8 @@ class RequestsClientTests(APITestCase): assert response.json() == expected def test_post_json_request(self): - response = self.requests.post('/', json={'key': 'value'}) + client = get_requests_client() + response = client.post('/', json={'key': 'value'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -166,10 +172,11 @@ class RequestsClientTests(APITestCase): assert response.json() == expected def test_post_multipart_request(self): + client = get_requests_client() files = { 'file': ('report.csv', 'some,data,to,send\nanother,row,to,send\n') } - response = self.requests.post('/', files=files) + response = client.post('/', files=files) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -182,19 +189,20 @@ class RequestsClientTests(APITestCase): assert response.json() == expected def test_session(self): - response = self.requests.get('/session/') + client = get_requests_client() + response = client.get('/session/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = {} assert response.json() == expected - response = self.requests.post('/session/', json={'example': 'abc'}) + response = client.post('/session/', json={'example': 'abc'}) assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = {'example': 'abc'} assert response.json() == expected - response = self.requests.get('/session/') + response = client.get('/session/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = {'example': 'abc'} @@ -202,7 +210,8 @@ class RequestsClientTests(APITestCase): def test_auth(self): # Confirm session is not authenticated - response = self.requests.get('/auth/') + client = get_requests_client() + response = client.get('/auth/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { @@ -217,7 +226,7 @@ class RequestsClientTests(APITestCase): user.save() # Perform a login - response = self.requests.post('/auth/', json={ + response = client.post('/auth/', json={ 'username': 'tom', 'password': 'password' }, headers={'X-CSRFToken': csrftoken}) @@ -229,7 +238,7 @@ class RequestsClientTests(APITestCase): assert response.json() == expected # Confirm session is authenticated - response = self.requests.get('/auth/') + response = client.get('/auth/') assert response.status_code == 200 assert response.headers['Content-Type'] == 'application/json' expected = { From e4f692831e850a710c60a7343e76bca19e8e6e83 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Fri, 2 Sep 2016 18:04:19 +0100 Subject: [PATCH 11/27] Added SchemaGenerator.should_include_link --- rest_framework/schemas.py | 65 ++++++++++++++++++++++++--------------- 1 file changed, 40 insertions(+), 25 deletions(-) diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index 1b899450f..bf1e6dd4a 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -72,31 +72,10 @@ class SchemaGenerator(object): links = [] for path, method, category, action, callback in self.endpoints: - view = callback.cls() - for attr, val in getattr(callback, 'initkwargs', {}).items(): - setattr(view, attr, val) - view.args = () - view.kwargs = {} - view.format_kwarg = None - - actions = getattr(callback, 'actions', None) - if actions is not None: - if method == 'OPTIONS': - view.action = 'metadata' - else: - view.action = actions.get(method.lower()) - - if request is not None: - view.request = clone_request(request, method) - try: - view.check_permissions(view.request) - except exceptions.APIException: - continue - else: - view.request = None - - link = self.get_link(path, method, callback, view) - links.append((category, action, link)) + view = self.setup_view(callback, method, request) + if self.should_include_link(path, method, callback, view): + link = self.get_link(path, method, callback, view) + links.append((category, action, link)) if not links: return None @@ -215,8 +194,44 @@ class SchemaGenerator(object): except IndexError: return None + def setup_view(self, callback, method, request): + """ + Setup a view instance. + """ + view = callback.cls() + for attr, val in getattr(callback, 'initkwargs', {}).items(): + setattr(view, attr, val) + view.args = () + view.kwargs = {} + view.format_kwarg = None + + actions = getattr(callback, 'actions', None) + if actions is not None: + if method == 'OPTIONS': + view.action = 'metadata' + else: + view.action = actions.get(method.lower()) + + if request is not None: + view.request = clone_request(request, method) + else: + view.request = None + + return view + # Methods for generating each individual `Link` instance... + def should_include_link(self, path, method, callback, view): + if view.request is None: + return True + + try: + view.check_permissions(view.request) + except exceptions.APIException: + return False + + return True + def get_link(self, path, method, callback, view): """ Return a `coreapi.Link` instance for the given endpoint. From 46b9e4edf9ba7c5b10deef4515d287898cdebe38 Mon Sep 17 00:00:00 2001 From: Andy Schriner Date: Mon, 22 Aug 2016 11:38:33 -0700 Subject: [PATCH 12/27] add settings for html cutoff on related fields --- docs/api-guide/relations.md | 2 ++ docs/api-guide/settings.md | 16 ++++++++++++++++ rest_framework/relations.py | 31 +++++++++++++++++++++---------- rest_framework/settings.py | 4 ++++ 4 files changed, 43 insertions(+), 10 deletions(-) diff --git a/docs/api-guide/relations.md b/docs/api-guide/relations.md index 8695b2c1e..75643a83e 100644 --- a/docs/api-guide/relations.md +++ b/docs/api-guide/relations.md @@ -457,6 +457,8 @@ There are two keyword arguments you can use to control this behavior: - `html_cutoff` - If set this will be the maximum number of choices that will be displayed by a HTML select drop down. Set to `None` to disable any limiting. Defaults to `1000`. - `html_cutoff_text` - If set this will display a textual indicator if the maximum number of items have been cutoff in an HTML select drop down. Defaults to `"More than {count} items…"` +You can also control these globally using the settings `HTML_SELECT_CUTOFF` and `HTML_SELECT_CUTOFF_TEXT`. + In cases where the cutoff is being enforced you may want to instead use a plain input field in the HTML form. You can do so using the `style` keyword argument. For example: assigned_to = serializers.SlugRelatedField( diff --git a/docs/api-guide/settings.md b/docs/api-guide/settings.md index ea018053f..67d317839 100644 --- a/docs/api-guide/settings.md +++ b/docs/api-guide/settings.md @@ -382,6 +382,22 @@ This should be a function with the following signature: Default: `'rest_framework.views.get_view_description'` +## HTML Select Field cutoffs + +Global settings for [select field cutoffs for rendering relational fields](relations.md#select-field-cutoffs) in the browsable API. + +#### HTML_SELECT_CUTOFF + +Global setting for the `html_cutoff` value. Must be an integer. + +Default: 1000 + +#### HTML_SELECT_CUTOFF_TEXT + +A string representing a global setting for `html_cutoff_text`. + +Default: `"More than {count} items..."` + --- ## Miscellaneous settings diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 65c4c0318..7fe22ec5b 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -18,6 +18,7 @@ from rest_framework.fields import ( Field, empty, get_attribute, is_simple_callable, iter_options ) from rest_framework.reverse import reverse +from rest_framework.settings import api_settings from rest_framework.utils import html @@ -71,14 +72,19 @@ MANY_RELATION_KWARGS = ( class RelatedField(Field): queryset = None - html_cutoff = 1000 - html_cutoff_text = _('More than {count} items...') + html_cutoff = None + html_cutoff_text = None def __init__(self, **kwargs): self.queryset = kwargs.pop('queryset', self.queryset) - self.html_cutoff = kwargs.pop('html_cutoff', self.html_cutoff) - self.html_cutoff_text = kwargs.pop('html_cutoff_text', self.html_cutoff_text) - + self.html_cutoff = kwargs.pop( + 'html_cutoff', + self.html_cutoff or int(api_settings.HTML_SELECT_CUTOFF) + ) + self.html_cutoff_text = kwargs.pop( + 'html_cutoff_text', + self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT) + ) if not method_overridden('get_queryset', RelatedField, self): assert self.queryset is not None or kwargs.get('read_only', None), ( 'Relational field must provide a `queryset` argument, ' @@ -447,15 +453,20 @@ class ManyRelatedField(Field): 'not_a_list': _('Expected a list of items but got type "{input_type}".'), 'empty': _('This list may not be empty.') } - html_cutoff = 1000 - html_cutoff_text = _('More than {count} items...') + html_cutoff = None + html_cutoff_text = None def __init__(self, child_relation=None, *args, **kwargs): self.child_relation = child_relation self.allow_empty = kwargs.pop('allow_empty', True) - self.html_cutoff = kwargs.pop('html_cutoff', self.html_cutoff) - self.html_cutoff_text = kwargs.pop('html_cutoff_text', self.html_cutoff_text) - + self.html_cutoff = kwargs.pop( + 'html_cutoff', + self.html_cutoff or int(api_settings.HTML_SELECT_CUTOFF) + ) + self.html_cutoff_text = kwargs.pop( + 'html_cutoff_text', + self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT) + ) assert child_relation is not None, '`child_relation` is a required argument.' super(ManyRelatedField, self).__init__(*args, **kwargs) self.child_relation.bind(field_name='', parent=self) diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 68c7709e8..89e27e743 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -111,6 +111,10 @@ DEFAULTS = { 'COMPACT_JSON': True, 'COERCE_DECIMAL_TO_STRING': True, 'UPLOADED_FILES_USE_URL': True, + + # Browseable API + 'HTML_SELECT_CUTOFF': 1000, + 'HTML_SELECT_CUTOFF_TEXT': "More than {count} items...", } From a556b9cb426f2893cb6807299a94044d585b7c2c Mon Sep 17 00:00:00 2001 From: Christian Sauer Date: Wed, 14 Sep 2016 18:01:30 -0400 Subject: [PATCH 13/27] Router doesn't work if prefix is blank, though project urls.py handles prefix --- tests/test_routers.py | 49 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 49 insertions(+) diff --git a/tests/test_routers.py b/tests/test_routers.py index f45039f80..d28e301a0 100644 --- a/tests/test_routers.py +++ b/tests/test_routers.py @@ -1,5 +1,6 @@ from __future__ import unicode_literals +import json from collections import namedtuple from django.conf.urls import include, url @@ -47,6 +48,21 @@ class MockViewSet(viewsets.ModelViewSet): serializer_class = None +class EmptyPrefixSerializer(serializers.HyperlinkedModelSerializer): + class Meta: + model = RouterTestModel + fields = ('uuid', 'text') + + +class EmptyPrefixViewSet(viewsets.ModelViewSet): + queryset = [RouterTestModel(id=1, uuid='111', text='First'), RouterTestModel(id=2, uuid='222', text='Second')] + serializer_class = EmptyPrefixSerializer + + def get_object(self, *args, **kwargs): + index = int(self.kwargs['pk']) - 1 + return self.queryset[index] + + notes_router = SimpleRouter() notes_router.register(r'notes', NoteViewSet) @@ -56,11 +72,19 @@ kwarged_notes_router.register(r'notes', KWargedNoteViewSet) namespaced_router = DefaultRouter() namespaced_router.register(r'example', MockViewSet, base_name='example') +empty_prefix_router = SimpleRouter() +empty_prefix_router.register(r'', EmptyPrefixViewSet, base_name='empty_prefix') +empty_prefix_urls = [ + url(r'^', include(empty_prefix_router.urls)), +] + urlpatterns = [ url(r'^non-namespaced/', include(namespaced_router.urls)), url(r'^namespaced/', include(namespaced_router.urls, namespace='example')), url(r'^example/', include(notes_router.urls)), url(r'^example2/', include(kwarged_notes_router.urls)), + + url(r'^empty-prefix/', include(empty_prefix_urls)), ] @@ -384,3 +408,28 @@ class TestDynamicListAndDetailRouter(TestCase): def test_inherited_list_and_detail_route_decorators(self): self._test_list_and_detail_route_decorators(SubDynamicListAndDetailViewSet) + + +@override_settings(ROOT_URLCONF='tests.test_routers') +class TestEmptyPrefix(TestCase): + def test_empty_prefix_list(self): + response = self.client.get('/empty-prefix/') + self.assertEqual(200, response.status_code) + self.assertEqual( + json.loads(response.content.decode('utf-8')), + [ + {'uuid': '111', 'text': 'First'}, + {'uuid': '222', 'text': 'Second'} + ] + ) + + def test_empty_prefix_detail(self): + response = self.client.get('/empty-prefix/1/') + self.assertEqual(200, response.status_code) + self.assertEqual( + json.loads(response.content.decode('utf-8')), + { + 'uuid': '111', + 'text': 'First' + } + ) From 8609c9ca8c0d9c747d594d2b4f5cc7cc8be98f0a Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Wed, 14 Sep 2016 23:44:48 -0400 Subject: [PATCH 14/27] Fix Django 1.10 to-many deprecation --- rest_framework/compat.py | 8 ++++++++ rest_framework/serializers.py | 10 +++++++--- tests/test_model_serializer.py | 4 ++-- tests/test_permissions.py | 10 +++++----- 4 files changed, 22 insertions(+), 10 deletions(-) diff --git a/rest_framework/compat.py b/rest_framework/compat.py index cee430a84..9ee40b257 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -277,3 +277,11 @@ def template_render(template, context=None, request=None): # backends template, e.g. django.template.backends.django.Template else: return template.render(context, request=request) + + +def set_many(instance, field, value): + if django.VERSION < (1, 10): + setattr(instance, field, value) + else: + field = getattr(instance, field) + field.set(value) diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 4d1ed63ae..28f70bd40 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -23,7 +23,7 @@ from django.utils.functional import cached_property from django.utils.translation import ugettext_lazy as _ from rest_framework.compat import JSONField as ModelJSONField -from rest_framework.compat import postgres_fields, unicode_to_repr +from rest_framework.compat import postgres_fields, set_many, unicode_to_repr from rest_framework.utils import model_meta from rest_framework.utils.field_mapping import ( ClassLookupDict, get_field_kwargs, get_nested_relation_kwargs, @@ -892,19 +892,23 @@ class ModelSerializer(Serializer): # Save many-to-many relationships after the instance is created. if many_to_many: for field_name, value in many_to_many.items(): - setattr(instance, field_name, value) + set_many(instance, field_name, value) return instance def update(self, instance, validated_data): raise_errors_on_nested_writes('update', self, validated_data) + info = model_meta.get_field_info(instance) # Simply set each attribute on the instance, and then save it. # Note that unlike `.create()` we don't need to treat many-to-many # relationships as being a special case. During updates we already # have an instance pk for the relationships to be associated with. for attr, value in validated_data.items(): - setattr(instance, attr, value) + if attr in info.relations and info.relations[attr].to_many: + set_many(instance, attr, value) + else: + setattr(instance, attr, value) instance.save() return instance diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index 01243ff6e..5dac16f2b 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -20,7 +20,7 @@ from django.test import TestCase from django.utils import six from rest_framework import serializers -from rest_framework.compat import unicode_repr +from rest_framework.compat import set_many, unicode_repr def dedent(blocktext): @@ -651,7 +651,7 @@ class TestIntegration(TestCase): foreign_key=self.foreign_key_target, one_to_one=self.one_to_one_target, ) - self.instance.many_to_many = self.many_to_many_targets + set_many(self.instance, 'many_to_many', self.many_to_many_targets) self.instance.save() def test_pk_retrival(self): diff --git a/tests/test_permissions.py b/tests/test_permissions.py index 5cef22628..0445f27ca 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -12,7 +12,7 @@ from rest_framework import ( HTTP_HEADER_ENCODING, authentication, generics, permissions, serializers, status ) -from rest_framework.compat import guardian +from rest_framework.compat import guardian, set_many from rest_framework.filters import DjangoObjectPermissionsFilter from rest_framework.routers import DefaultRouter from rest_framework.test import APIRequestFactory @@ -74,15 +74,15 @@ class ModelPermissionsIntegrationTests(TestCase): def setUp(self): User.objects.create_user('disallowed', 'disallowed@example.com', 'password') user = User.objects.create_user('permitted', 'permitted@example.com', 'password') - user.user_permissions = [ + set_many(user, 'user_permissions', [ Permission.objects.get(codename='add_basicmodel'), Permission.objects.get(codename='change_basicmodel'), Permission.objects.get(codename='delete_basicmodel') - ] + ]) user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password') - user.user_permissions = [ + set_many(user, 'user_permissions', [ Permission.objects.get(codename='change_basicmodel'), - ] + ]) self.permitted_credentials = basic_auth_header('permitted', 'password') self.disallowed_credentials = basic_auth_header('disallowed', 'password') From 197b63ab85edb0bb1f02a5bc05ce87b1e4d541a5 Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Wed, 14 Sep 2016 23:51:51 -0400 Subject: [PATCH 15/27] Add django.core.urlresolvers compatibility --- docs/api-guide/reverse.md | 4 ++-- docs/api-guide/testing.md | 2 +- rest_framework/compat.py | 10 ++++++++++ rest_framework/relations.py | 6 +++--- rest_framework/reverse.py | 6 +++--- rest_framework/routers.py | 2 +- rest_framework/schemas.py | 5 +++-- rest_framework/templatetags/rest_framework.py | 3 +-- rest_framework/urlpatterns.py | 2 +- rest_framework/utils/breadcrumbs.py | 2 +- tests/test_filters.py | 3 +-- tests/test_permissions.py | 3 +-- tests/test_reverse.py | 2 +- tests/test_urlpatterns.py | 8 ++++---- tests/utils.py | 2 +- 15 files changed, 34 insertions(+), 26 deletions(-) diff --git a/docs/api-guide/reverse.md b/docs/api-guide/reverse.md index 71fb83f9e..35d88e2db 100644 --- a/docs/api-guide/reverse.md +++ b/docs/api-guide/reverse.md @@ -23,7 +23,7 @@ There's no requirement for you to use them, but if you do then the self-describi **Signature:** `reverse(viewname, *args, **kwargs)` -Has the same behavior as [`django.core.urlresolvers.reverse`][reverse], except that it returns a fully qualified URL, using the request to determine the host and port. +Has the same behavior as [`django.urls.reverse`][reverse], except that it returns a fully qualified URL, using the request to determine the host and port. You should **include the request as a keyword argument** to the function, for example: @@ -44,7 +44,7 @@ You should **include the request as a keyword argument** to the function, for ex **Signature:** `reverse_lazy(viewname, *args, **kwargs)` -Has the same behavior as [`django.core.urlresolvers.reverse_lazy`][reverse-lazy], except that it returns a fully qualified URL, using the request to determine the host and port. +Has the same behavior as [`django.urls.reverse_lazy`][reverse-lazy], except that it returns a fully qualified URL, using the request to determine the host and port. As with the `reverse` function, you should **include the request as a keyword argument** to the function, for example: diff --git a/docs/api-guide/testing.md b/docs/api-guide/testing.md index 69da7d105..18f9e19e9 100644 --- a/docs/api-guide/testing.md +++ b/docs/api-guide/testing.md @@ -197,7 +197,7 @@ REST framework includes the following test case classes, that mirror the existin You can use any of REST framework's test case classes as you would for the regular Django test case classes. The `self.client` attribute will be an `APIClient` instance. - from django.core.urlresolvers import reverse + from django.urls import reverse from rest_framework import status from rest_framework.test import APITestCase from myproject.apps.core.models import Account diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 9ee40b257..958c481cd 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -23,6 +23,16 @@ except ImportError: from django.utils import importlib # Will be removed in Django 1.9 +try: + from django.urls import ( + NoReverseMatch, RegexURLPattern, RegexURLResolver, ResolverMatch, Resolver404, get_script_prefix, reverse, reverse_lazy, resolve + ) +except ImportError: + from django.core.urlresolvers import ( # Will be removed in Django 2.0 + NoReverseMatch, RegexURLPattern, RegexURLResolver, ResolverMatch, Resolver404, get_script_prefix, reverse, reverse_lazy, resolve + ) + + try: import urlparse # Python 2.x except ImportError: diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 65c4c0318..4317d1199 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -4,9 +4,6 @@ from __future__ import unicode_literals from collections import OrderedDict from django.core.exceptions import ImproperlyConfigured, ObjectDoesNotExist -from django.core.urlresolvers import ( - NoReverseMatch, Resolver404, get_script_prefix, resolve -) from django.db.models import Manager from django.db.models.query import QuerySet from django.utils import six @@ -14,6 +11,9 @@ from django.utils.encoding import python_2_unicode_compatible, smart_text from django.utils.six.moves.urllib import parse as urlparse from django.utils.translation import ugettext_lazy as _ +from rest_framework.compat import ( + NoReverseMatch, Resolver404, get_script_prefix, resolve +) from rest_framework.fields import ( Field, empty, get_attribute, is_simple_callable, iter_options ) diff --git a/rest_framework/reverse.py b/rest_framework/reverse.py index 5a7ba09a8..fd418dcca 100644 --- a/rest_framework/reverse.py +++ b/rest_framework/reverse.py @@ -3,11 +3,11 @@ Provide urlresolver functions that return fully qualified URLs or view names """ from __future__ import unicode_literals -from django.core.urlresolvers import reverse as django_reverse -from django.core.urlresolvers import NoReverseMatch from django.utils import six from django.utils.functional import lazy +from rest_framework.compat import reverse as django_reverse +from rest_framework.compat import NoReverseMatch from rest_framework.settings import api_settings from rest_framework.utils.urls import replace_query_param @@ -54,7 +54,7 @@ def reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra def _reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra): """ - Same as `django.core.urlresolvers.reverse`, but optionally takes a request + Same as `django.urls.reverse`, but optionally takes a request and returns a fully qualified URL, using the request to get the base URL. """ if format is not None: diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 4eec70bda..c6516b06b 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -20,9 +20,9 @@ from collections import OrderedDict, namedtuple from django.conf.urls import url from django.core.exceptions import ImproperlyConfigured -from django.core.urlresolvers import NoReverseMatch from rest_framework import exceptions, renderers, views +from rest_framework.compat import NoReverseMatch from rest_framework.response import Response from rest_framework.reverse import reverse from rest_framework.schemas import SchemaGenerator diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index 1b899450f..72cd8017a 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -2,12 +2,13 @@ from importlib import import_module from django.conf import settings from django.contrib.admindocs.views import simplify_regex -from django.core.urlresolvers import RegexURLPattern, RegexURLResolver from django.utils import six from django.utils.encoding import force_text from rest_framework import exceptions, serializers -from rest_framework.compat import coreapi, uritemplate, urlparse +from rest_framework.compat import ( + RegexURLPattern, RegexURLResolver, coreapi, uritemplate, urlparse +) from rest_framework.request import clone_request from rest_framework.views import APIView diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index 3bb85e472..c1c8a5396 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -3,14 +3,13 @@ from __future__ import absolute_import, unicode_literals import re from django import template -from django.core.urlresolvers import NoReverseMatch, reverse from django.template import loader from django.utils import six from django.utils.encoding import force_text, iri_to_uri from django.utils.html import escape, format_html, smart_urlquote from django.utils.safestring import SafeData, mark_safe -from rest_framework.compat import template_render +from rest_framework.compat import NoReverseMatch, reverse, template_render from rest_framework.renderers import HTMLFormRenderer from rest_framework.utils.urls import replace_query_param diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py index 7a02bb0f0..4ea55300e 100644 --- a/rest_framework/urlpatterns.py +++ b/rest_framework/urlpatterns.py @@ -1,8 +1,8 @@ from __future__ import unicode_literals from django.conf.urls import include, url -from django.core.urlresolvers import RegexURLResolver +from rest_framework.compat import RegexURLResolver from rest_framework.settings import api_settings diff --git a/rest_framework/utils/breadcrumbs.py b/rest_framework/utils/breadcrumbs.py index 2e3ab9084..74f4f7840 100644 --- a/rest_framework/utils/breadcrumbs.py +++ b/rest_framework/utils/breadcrumbs.py @@ -1,6 +1,6 @@ from __future__ import unicode_literals -from django.core.urlresolvers import get_script_prefix, resolve +from rest_framework.compat import get_script_prefix, resolve def get_breadcrumbs(url, request=None): diff --git a/tests/test_filters.py b/tests/test_filters.py index 03d61fc37..fdb3c1c0b 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -6,7 +6,6 @@ from decimal import Decimal from django.conf.urls import url from django.core.exceptions import ImproperlyConfigured -from django.core.urlresolvers import reverse from django.db import models from django.test import TestCase from django.test.utils import override_settings @@ -14,7 +13,7 @@ from django.utils.dateparse import parse_date from django.utils.six.moves import reload_module from rest_framework import filters, generics, serializers, status -from rest_framework.compat import django_filters +from rest_framework.compat import django_filters, reverse from rest_framework.test import APIRequestFactory from .models import BaseFilterableItem, BasicModel, FilterableItem diff --git a/tests/test_permissions.py b/tests/test_permissions.py index 0445f27ca..f8561e61d 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -4,7 +4,6 @@ import base64 import unittest from django.contrib.auth.models import Group, Permission, User -from django.core.urlresolvers import ResolverMatch from django.db import models from django.test import TestCase @@ -12,7 +11,7 @@ from rest_framework import ( HTTP_HEADER_ENCODING, authentication, generics, permissions, serializers, status ) -from rest_framework.compat import guardian, set_many +from rest_framework.compat import ResolverMatch, guardian, set_many from rest_framework.filters import DjangoObjectPermissionsFilter from rest_framework.routers import DefaultRouter from rest_framework.test import APIRequestFactory diff --git a/tests/test_reverse.py b/tests/test_reverse.py index 03d31f1f9..f30a8bf9a 100644 --- a/tests/test_reverse.py +++ b/tests/test_reverse.py @@ -1,9 +1,9 @@ from __future__ import unicode_literals from django.conf.urls import url -from django.core.urlresolvers import NoReverseMatch from django.test import TestCase, override_settings +from rest_framework.compat import NoReverseMatch from rest_framework.reverse import reverse from rest_framework.test import APIRequestFactory diff --git a/tests/test_urlpatterns.py b/tests/test_urlpatterns.py index 78d37c1a8..33d367e1d 100644 --- a/tests/test_urlpatterns.py +++ b/tests/test_urlpatterns.py @@ -3,9 +3,9 @@ from __future__ import unicode_literals from collections import namedtuple from django.conf.urls import include, url -from django.core import urlresolvers from django.test import TestCase +from rest_framework.compat import RegexURLResolver, Resolver404 from rest_framework.test import APIRequestFactory from rest_framework.urlpatterns import format_suffix_patterns @@ -28,7 +28,7 @@ class FormatSuffixTests(TestCase): urlpatterns = format_suffix_patterns(urlpatterns) except Exception: self.fail("Failed to apply `format_suffix_patterns` on the supplied urlpatterns") - resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns) + resolver = RegexURLResolver(r'^/', urlpatterns) for test_path in test_paths: request = factory.get(test_path.path) try: @@ -43,7 +43,7 @@ class FormatSuffixTests(TestCase): urlpatterns = format_suffix_patterns([ url(r'^test/$', dummy_view), ]) - resolver = urlresolvers.RegexURLResolver(r'^/', urlpatterns) + resolver = RegexURLResolver(r'^/', urlpatterns) test_paths = [ (URLTestPath('/test.api', (), {'format': 'api'}), True), @@ -55,7 +55,7 @@ class FormatSuffixTests(TestCase): request = factory.get(test_path.path) try: callback, callback_args, callback_kwargs = resolver.resolve(request.path_info) - except urlresolvers.Resolver404: + except Resolver404: callback, callback_args, callback_kwargs = (None, None, None) if not expected_resolved: assert callback is None diff --git a/tests/utils.py b/tests/utils.py index 5b2d75864..52582f093 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,5 +1,5 @@ from django.core.exceptions import ObjectDoesNotExist -from django.core.urlresolvers import NoReverseMatch +from rest_framework.compat import NoReverseMatch class MockObject(object): From bb37cb79929e145714c0018d90aed93e1d58d5a4 Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Wed, 14 Sep 2016 23:53:14 -0400 Subject: [PATCH 16/27] Update django-filter & django-guardian --- requirements/requirements-optionals.txt | 4 ++-- tests/test_filters.py | 1 + 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/requirements/requirements-optionals.txt b/requirements/requirements-optionals.txt index 20436e6b4..87e5034cd 100644 --- a/requirements/requirements-optionals.txt +++ b/requirements/requirements-optionals.txt @@ -1,5 +1,5 @@ # Optional packages which may be used with REST framework. markdown==2.6.4 -django-guardian==1.4.3 -django-filter==0.13.0 +django-guardian==1.4.6 +django-filter==0.14.0 coreapi==1.32.0 diff --git a/tests/test_filters.py b/tests/test_filters.py index fdb3c1c0b..d696309c1 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -76,6 +76,7 @@ if django_filters: class Meta: model = BaseFilterableItem + fields = '__all__' class BaseFilterableItemFilterRootView(generics.ListCreateAPIView): queryset = FilterableItem.objects.all() From a372a8edea70260ebfb559f3d0bd692132cdfb40 Mon Sep 17 00:00:00 2001 From: Christian Sauer Date: Thu, 15 Sep 2016 12:54:35 -0400 Subject: [PATCH 17/27] Check for empty router prefix; adjust URL accordingly It's easiest to fix this issue after we have made the regex. To try to fix it before would require doing something different for List vs Detail, which means we'd have to know which type of url we're constructing before acting accordingly. --- rest_framework/routers.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 4eec70bda..e2fbfa77b 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -83,6 +83,7 @@ class BaseRouter(object): class SimpleRouter(BaseRouter): + routes = [ # List route. Route( @@ -258,6 +259,13 @@ class SimpleRouter(BaseRouter): trailing_slash=self.trailing_slash ) + # If there is no prefix, the first part of the url is probably + # controlled by project's urls.py and the router is in an app, + # so a slash in the beginning will (A) cause Django to give + # warnings and (B) generate URLS that will require using '//'. + if not prefix and regex[:2] == '^/': + regex = '^' + regex[2:] + view = viewset.as_view(mapping, **route.initkwargs) name = route.name.format(basename=basename) ret.append(url(regex, view, name=name)) From a084924ced2ee4cb18d0f8c48b603f71ce6d2429 Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Wed, 14 Sep 2016 23:54:47 -0400 Subject: [PATCH 18/27] Fix misc django deprecations --- rest_framework/compat.py | 6 ++++ tests/browsable_api/test_form_rendering.py | 1 + tests/conftest.py | 15 ++++++---- tests/test_atomic_requests.py | 32 +++++++++++----------- tests/test_authentication.py | 3 +- tests/test_filters.py | 2 +- tests/test_model_serializer.py | 3 +- tests/test_request.py | 5 ++-- 8 files changed, 40 insertions(+), 27 deletions(-) diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 958c481cd..1a94f22b6 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -138,6 +138,12 @@ def is_authenticated(user): return user.is_authenticated +def is_anonymous(user): + if django.VERSION < (1, 10): + return user.is_anonymous() + return user.is_anonymous + + def get_related_model(field): if django.VERSION < (1, 9): return _resolve_model(field.rel.to) diff --git a/tests/browsable_api/test_form_rendering.py b/tests/browsable_api/test_form_rendering.py index 5a31ae0dd..8b79ab6ff 100644 --- a/tests/browsable_api/test_form_rendering.py +++ b/tests/browsable_api/test_form_rendering.py @@ -11,6 +11,7 @@ factory = APIRequestFactory() class BasicSerializer(serializers.ModelSerializer): class Meta: model = BasicModel + fields = '__all__' class ManyPostView(generics.GenericAPIView): diff --git a/tests/conftest.py b/tests/conftest.py index a5123b9d8..256678226 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,13 @@ def pytest_configure(): from django.conf import settings + MIDDLEWARE = ( + 'django.middleware.common.CommonMiddleware', + 'django.contrib.sessions.middleware.SessionMiddleware', + 'django.contrib.auth.middleware.AuthenticationMiddleware', + 'django.contrib.messages.middleware.MessageMiddleware', + ) + settings.configure( DEBUG_PROPAGATE_EXCEPTIONS=True, DATABASES={ @@ -21,12 +28,8 @@ def pytest_configure(): 'APP_DIRS': True, }, ], - MIDDLEWARE_CLASSES=( - 'django.middleware.common.CommonMiddleware', - 'django.contrib.sessions.middleware.SessionMiddleware', - 'django.contrib.auth.middleware.AuthenticationMiddleware', - 'django.contrib.messages.middleware.MessageMiddleware', - ), + MIDDLEWARE=MIDDLEWARE, + MIDDLEWARE_CLASSES=MIDDLEWARE, INSTALLED_APPS=( 'django.contrib.auth', 'django.contrib.contenttypes', diff --git a/tests/test_atomic_requests.py b/tests/test_atomic_requests.py index 8342ad3af..09d7f2fb1 100644 --- a/tests/test_atomic_requests.py +++ b/tests/test_atomic_requests.py @@ -5,7 +5,7 @@ import unittest from django.conf.urls import url from django.db import connection, connections, transaction from django.http import Http404 -from django.test import TestCase, TransactionTestCase +from django.test import TestCase, TransactionTestCase, override_settings from django.utils.decorators import method_decorator from rest_framework import status @@ -36,6 +36,20 @@ class APIExceptionView(APIView): raise APIException +class NonAtomicAPIExceptionView(APIView): + @method_decorator(transaction.non_atomic_requests) + def dispatch(self, *args, **kwargs): + return super(NonAtomicAPIExceptionView, self).dispatch(*args, **kwargs) + + def get(self, request, *args, **kwargs): + BasicModel.objects.all() + raise Http404 + +urlpatterns = ( + url(r'^$', NonAtomicAPIExceptionView.as_view()), +) + + @unittest.skipUnless( connection.features.uses_savepoints, "'atomic' requires transactions and savepoints." @@ -124,22 +138,8 @@ class DBTransactionAPIExceptionTests(TestCase): connection.features.uses_savepoints, "'atomic' requires transactions and savepoints." ) +@override_settings(ROOT_URLCONF='tests.test_atomic_requests') class NonAtomicDBTransactionAPIExceptionTests(TransactionTestCase): - @property - def urls(self): - class NonAtomicAPIExceptionView(APIView): - @method_decorator(transaction.non_atomic_requests) - def dispatch(self, *args, **kwargs): - return super(NonAtomicAPIExceptionView, self).dispatch(*args, **kwargs) - - def get(self, request, *args, **kwargs): - BasicModel.objects.all() - raise Http404 - - return ( - url(r'^$', NonAtomicAPIExceptionView.as_view()), - ) - def setUp(self): connections.databases['default']['ATOMIC_REQUESTS'] = True diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 5ef620abe..6f17ea14f 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -20,6 +20,7 @@ from rest_framework.authentication import ( ) from rest_framework.authtoken.models import Token from rest_framework.authtoken.views import obtain_auth_token +from rest_framework.compat import is_authenticated from rest_framework.response import Response from rest_framework.test import APIClient, APIRequestFactory from rest_framework.views import APIView @@ -408,7 +409,7 @@ class FailingAuthAccessedInRenderer(TestCase): def render(self, data, media_type=None, renderer_context=None): request = renderer_context['request'] - if request.user.is_authenticated(): + if is_authenticated(request.user): return b'authenticated' return b'not authenticated' diff --git a/tests/test_filters.py b/tests/test_filters.py index d696309c1..c67412dd7 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -456,7 +456,7 @@ class AttributeModel(models.Model): class SearchFilterModelFk(models.Model): title = models.CharField(max_length=20) - attribute = models.ForeignKey(AttributeModel) + attribute = models.ForeignKey(AttributeModel, on_delete=models.CASCADE) class SearchFilterFkSerializer(serializers.ModelSerializer): diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index 5dac16f2b..cd9b2dfc3 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -962,7 +962,7 @@ class OneToOneTargetTestModel(models.Model): class OneToOneSourceTestModel(models.Model): - target = models.OneToOneField(OneToOneTargetTestModel, primary_key=True) + target = models.OneToOneField(OneToOneTargetTestModel, primary_key=True, on_delete=models.CASCADE) class TestModelFieldValues(TestCase): @@ -990,6 +990,7 @@ class TestUniquenessOverride(TestCase): class TestSerializer(serializers.ModelSerializer): class Meta: model = TestModel + fields = '__all__' extra_kwargs = {'field_1': {'required': False}} fields = TestSerializer().fields diff --git a/tests/test_request.py b/tests/test_request.py index dbfa695fd..32fbbc50b 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -13,6 +13,7 @@ from django.utils import six from rest_framework import status from rest_framework.authentication import SessionAuthentication +from rest_framework.compat import is_anonymous from rest_framework.parsers import BaseParser, FormParser, MultiPartParser from rest_framework.request import Request from rest_framework.response import Response @@ -169,9 +170,9 @@ class TestUserSetter(TestCase): def test_user_can_logout(self): self.request.user = self.user - self.assertFalse(self.request.user.is_anonymous()) + self.assertFalse(is_anonymous(self.request.user)) logout(self.request) - self.assertTrue(self.request.user.is_anonymous()) + self.assertTrue(is_anonymous(self.request.user)) def test_logged_in_user_is_set_on_wrapped_request(self): login(self.request, self.user) From 3bfb0b716874559044e8c5bee3e575a549e057ab Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Wed, 14 Sep 2016 23:57:52 -0400 Subject: [PATCH 19/27] Use TOC extension instead of header --- rest_framework/compat.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 1a94f22b6..8afe52f54 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -216,8 +216,13 @@ try: if markdown.version <= '2.2': HEADERID_EXT_PATH = 'headerid' - else: + LEVEL_PARAM = 'level' + elif markdown.version < '2.6': HEADERID_EXT_PATH = 'markdown.extensions.headerid' + LEVEL_PARAM = 'level' + else: + HEADERID_EXT_PATH = 'markdown.extensions.toc' + LEVEL_PARAM = 'baselevel' def apply_markdown(text): """ @@ -227,7 +232,7 @@ try: extensions = [HEADERID_EXT_PATH] extension_configs = { HEADERID_EXT_PATH: { - 'level': '2' + LEVEL_PARAM: '2' } } md = markdown.Markdown( From 3bdb0e9dd8b8f098111fad4a874ee99d392e944a Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Fri, 16 Sep 2016 13:08:26 -0400 Subject: [PATCH 20/27] Fix deprecations for py3k --- rest_framework/fields.py | 12 ++++++++++++ tests/test_schemas.py | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index f76e4e801..7d677a408 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -53,6 +53,18 @@ def is_simple_callable(obj): """ True if the object is a callable that takes no arguments. """ + if not hasattr(inspect, 'signature'): + return py2k_is_simple_callable(obj) + + if not callable(obj): + return False + + sig = inspect.signature(obj) + params = sig.parameters.values() + return all(param.default != param.empty for param in params) + + +def py2k_is_simple_callable(obj): function = inspect.isfunction(obj) method = inspect.ismethod(obj) diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 197e62eb0..dc01d8cd8 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -215,4 +215,4 @@ class TestSchemaGenerator(TestCase): } } ) - self.assertEquals(schema, expected) + self.assertEqual(schema, expected) From a0a8b9890a821bb56aace86ed18e57b2a1137e07 Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Thu, 22 Sep 2016 13:56:27 -0400 Subject: [PATCH 21/27] Add py3k compatibility to is_simple_callable --- rest_framework/fields.py | 38 ++++++++++++++++++++++++++------------ 1 file changed, 26 insertions(+), 12 deletions(-) diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 917a151e5..7f8391b8a 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -49,20 +49,34 @@ class empty: pass -def is_simple_callable(obj): - """ - True if the object is a callable that takes no arguments. - """ - function = inspect.isfunction(obj) - method = inspect.ismethod(obj) +if six.PY3: + def is_simple_callable(obj): + """ + True if the object is a callable that takes no arguments. + """ + if not callable(obj): + return False - if not (function or method): - return False + sig = inspect.signature(obj) + params = sig.parameters.values() + return all(param.default != param.empty for param in params) - args, _, _, defaults = inspect.getargspec(obj) - len_args = len(args) if function else len(args) - 1 - len_defaults = len(defaults) if defaults else 0 - return len_args <= len_defaults +else: + def is_simple_callable(obj): + function = inspect.isfunction(obj) + method = inspect.ismethod(obj) + + if not (function or method): + return False + + if method: + is_unbound = obj.im_self is None + + args, _, _, defaults = inspect.getargspec(obj) + + len_args = len(args) if function or is_unbound else len(args) - 1 + len_defaults = len(defaults) if defaults else 0 + return len_args <= len_defaults def get_attribute(instance, attrs): From adcf6536e75272649da1d8e5a8b5ba7926e33147 Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Thu, 22 Sep 2016 13:56:37 -0400 Subject: [PATCH 22/27] Add is_simple_callable tests --- tests/test_fields.py | 62 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/tests/test_fields.py b/tests/test_fields.py index 4a4b741c5..c271afa9e 100644 --- a/tests/test_fields.py +++ b/tests/test_fields.py @@ -1,6 +1,7 @@ import datetime import os import re +import unittest import uuid from decimal import Decimal @@ -11,6 +12,67 @@ from django.utils import six, timezone import rest_framework from rest_framework import serializers +from rest_framework.fields import is_simple_callable + +try: + import typings +except ImportError: + typings = False + + +# Tests for helper functions. +# --------------------------- + +class TestIsSimpleCallable: + + def test_method(self): + class Foo: + @classmethod + def classmethod(cls): + pass + + def valid(self): + pass + + def valid_kwargs(self, param='value'): + pass + + def invalid(self, param): + pass + + assert is_simple_callable(Foo.classmethod) + + # unbound methods + assert not is_simple_callable(Foo.valid) + assert not is_simple_callable(Foo.valid_kwargs) + assert not is_simple_callable(Foo.invalid) + + # bound methods + assert is_simple_callable(Foo().valid) + assert is_simple_callable(Foo().valid_kwargs) + assert not is_simple_callable(Foo().invalid) + + def test_function(self): + def simple(): + pass + + def valid(param='value', param2='value'): + pass + + def invalid(param, param2='value'): + pass + + assert is_simple_callable(simple) + assert is_simple_callable(valid) + assert not is_simple_callable(invalid) + + @unittest.skipUnless(typings, 'requires python 3.5') + def test_type_annotation(self): + # The annotation will otherwise raise a syntax error in python < 3.5 + exec("def valid(param: str='value'): pass", locals()) + valid = locals()['valid'] + + assert is_simple_callable(valid) # Tests for field keyword arguments and core functionality. From b3afcb25d9f0621091231212fcc8c92631e035e4 Mon Sep 17 00:00:00 2001 From: Ryan P Kilby Date: Thu, 22 Sep 2016 15:19:48 -0400 Subject: [PATCH 23/27] Drop python 3.2 support (EOL, Dropped by Django) --- .travis.yml | 1 - tox.ini | 3 +-- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/.travis.yml b/.travis.yml index c9d9a1648..100a7cd8b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -14,7 +14,6 @@ env: - TOX_ENV=py35-django18 - TOX_ENV=py34-django18 - TOX_ENV=py33-django18 - - TOX_ENV=py32-django18 - TOX_ENV=py27-django18 - TOX_ENV=py27-django110 - TOX_ENV=py35-django110 diff --git a/tox.ini b/tox.ini index 1e8a7e5c4..c20021e4b 100644 --- a/tox.ini +++ b/tox.ini @@ -4,7 +4,7 @@ addopts=--tb=short [tox] envlist = py27-{lint,docs}, - {py27,py32,py33,py34,py35}-django18, + {py27,py33,py34,py35}-django18, {py27,py34,py35}-django19, {py27,py34,py35}-django110, {py27,py34,py35}-django{master} @@ -25,7 +25,6 @@ basepython = py35: python3.5 py34: python3.4 py33: python3.3 - py32: python3.2 py27: python2.7 [testenv:py27-lint] From b5167120764c997562de8fb9caf71249ae4a8c9e Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Wed, 28 Sep 2016 12:15:46 +0100 Subject: [PATCH 24/27] schema_renderers= should *set* the renderers, not append to them. --- rest_framework/routers.py | 57 ++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 27 deletions(-) diff --git a/rest_framework/routers.py b/rest_framework/routers.py index 4eec70bda..64f110cd9 100644 --- a/rest_framework/routers.py +++ b/rest_framework/routers.py @@ -289,42 +289,42 @@ class DefaultRouter(SimpleRouter): self.root_renderers = list(api_settings.DEFAULT_RENDERER_CLASSES) super(DefaultRouter, self).__init__(*args, **kwargs) + def get_schema_root_view(self, api_urls=None): + """ + Return a schema root view. + """ + schema_renderers = self.schema_renderers + schema_generator = SchemaGenerator( + title=self.schema_title, + url=self.schema_url, + patterns=api_urls + ) + + class APISchemaView(views.APIView): + _ignore_model_permissions = True + renderer_classes = schema_renderers + + def get(self, request, *args, **kwargs): + schema = schema_generator.get_schema(request) + if schema is None: + raise exceptions.PermissionDenied() + return Response(schema) + + return APISchemaView.as_view() + def get_api_root_view(self, api_urls=None): """ - Return a view to use as the API root. + Return a basic root view. """ api_root_dict = OrderedDict() list_name = self.routes[0].name for prefix, viewset, basename in self.registry: api_root_dict[prefix] = list_name.format(basename=basename) - view_renderers = list(self.root_renderers) - schema_media_types = [] - - if api_urls and self.schema_title: - view_renderers += list(self.schema_renderers) - schema_generator = SchemaGenerator( - title=self.schema_title, - url=self.schema_url, - patterns=api_urls - ) - schema_media_types = [ - renderer.media_type - for renderer in self.schema_renderers - ] - - class APIRoot(views.APIView): + class APIRootView(views.APIView): _ignore_model_permissions = True - renderer_classes = view_renderers def get(self, request, *args, **kwargs): - if request.accepted_renderer.media_type in schema_media_types: - # Return a schema response. - schema = schema_generator.get_schema(request) - if schema is None: - raise exceptions.PermissionDenied() - return Response(schema) - # Return a plain {"name": "hyperlink"} response. ret = OrderedDict() namespace = request.resolver_match.namespace @@ -345,7 +345,7 @@ class DefaultRouter(SimpleRouter): return Response(ret) - return APIRoot.as_view() + return APIRootView.as_view() def get_urls(self): """ @@ -355,7 +355,10 @@ class DefaultRouter(SimpleRouter): urls = super(DefaultRouter, self).get_urls() if self.include_root_view: - view = self.get_api_root_view(api_urls=urls) + if self.schema_title: + view = self.get_schema_root_view(api_urls=urls) + else: + view = self.get_api_root_view(api_urls=urls) root_url = url(r'^$', view, name=self.root_view_name) urls.append(root_url) From 37b3475e5d048f7f7e751dee6bdb23695fe484bd Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 29 Sep 2016 09:46:28 +0100 Subject: [PATCH 25/27] API client (#4424) --- README.md | 2 +- docs/api-guide/permissions.md | 7 +- docs/topics/release-notes.md | 25 + requirements/requirements-optionals.txt | 2 +- rest_framework/__init__.py | 2 +- rest_framework/fields.py | 2 +- rest_framework/relations.py | 6 +- rest_framework/renderers.py | 9 +- rest_framework/request.py | 5 + rest_framework/schemas.py | 7 + rest_framework/serializers.py | 8 +- .../static/rest_framework/js/csrf.js | 2 +- .../templates/rest_framework/admin.html | 1 + .../templates/rest_framework/base.html | 1 + .../vertical/checkbox_multiple.html | 4 +- rest_framework/test.py | 15 +- rest_framework/views.py | 20 +- tests/test_api_client.py | 452 ++++++++++++++++++ tests/test_request.py | 11 + tests/test_schemas.py | 1 + 20 files changed, 557 insertions(+), 25 deletions(-) create mode 100644 tests/test_api_client.py diff --git a/README.md b/README.md index 179f2891a..e1e252609 100644 --- a/README.md +++ b/README.md @@ -169,7 +169,7 @@ You may also want to [follow the author on Twitter][twitter]. # Security -If you believe you’ve found something in Django REST framework which has security implications, please **do not raise the issue in a public forum**. +If you believe you've found something in Django REST framework which has security implications, please **do not raise the issue in a public forum**. Send a description of the issue via email to [rest-framework-security@googlegroups.com][security-mail]. The project maintainers will then work with you to resolve any issues where required, prior to any public disclosure. diff --git a/docs/api-guide/permissions.md b/docs/api-guide/permissions.md index e0838e94a..7cdb59531 100644 --- a/docs/api-guide/permissions.md +++ b/docs/api-guide/permissions.md @@ -92,7 +92,7 @@ Or, if you're using the `@api_view` decorator with function based views. from rest_framework.permissions import IsAuthenticated from rest_framework.response import Response - @api_view('GET') + @api_view(['GET']) @permission_classes((IsAuthenticated, )) def example_view(request, format=None): content = { @@ -261,6 +261,10 @@ The [REST Condition][rest-condition] package is another extension for building c The [DRY Rest Permissions][dry-rest-permissions] package provides the ability to define different permissions for individual default and custom actions. This package is made for apps with permissions that are derived from relationships defined in the app's data model. It also supports permission checks being returned to a client app through the API's serializer. Additionally it supports adding permissions to the default and custom list actions to restrict the data they retrive per user. +## Django Rest Framework Roles + +The [Django Rest Framework Roles][django-rest-framework-roles] package makes it easier to parameterize your API over multiple types of users. + [cite]: https://developer.apple.com/library/mac/#documentation/security/Conceptual/AuthenticationAndAuthorizationGuide/Authorization/Authorization.html [authentication]: authentication.md [throttling]: throttling.md @@ -275,3 +279,4 @@ The [DRY Rest Permissions][dry-rest-permissions] package provides the ability to [composed-permissions]: https://github.com/niwibe/djangorestframework-composed-permissions [rest-condition]: https://github.com/caxap/rest_condition [dry-rest-permissions]: https://github.com/Helioscene/dry-rest-permissions +[django-rest-framework-roles]: https://github.com/computer-lab/django-rest-framework-roles diff --git a/docs/topics/release-notes.md b/docs/topics/release-notes.md index 78a5a8ba9..24728a252 100644 --- a/docs/topics/release-notes.md +++ b/docs/topics/release-notes.md @@ -40,6 +40,18 @@ You can determine your currently installed version using `pip freeze`: ## 3.4.x series +### 3.4.5 + +**Date**: [19th August 2016][3.4.5-milestone] + +* Improve debug error handling. ([#4416][gh4416], [#4409][gh4409]) +* Allow custom CSRF_HEADER_NAME setting. ([#4415][gh4415], [#4410][gh4410]) +* Include .action attribute on viewsets when generating schemas. ([#4408][gh4408], [#4398][gh4398]) +* Do not include request.FILES items in request.POST. ([#4407][gh4407]) +* Fix rendering of checkbox multiple. ([#4403][gh4403]) +* Fix docstring of Field.get_default. ([#4404][gh4404]) +* Replace utf8 character with its ascii counterpart in README. ([#4412][gh4412]) + ### 3.4.4 **Date**: [12th August 2016][3.4.4-milestone] @@ -560,6 +572,7 @@ For older release notes, [please see the version 2.x documentation][old-release- [3.4.2-milestone]: https://github.com/tomchristie/django-rest-framework/issues?q=milestone%3A%223.4.2+Release%22 [3.4.3-milestone]: https://github.com/tomchristie/django-rest-framework/issues?q=milestone%3A%223.4.3+Release%22 [3.4.4-milestone]: https://github.com/tomchristie/django-rest-framework/issues?q=milestone%3A%223.4.4+Release%22 +[3.4.5-milestone]: https://github.com/tomchristie/django-rest-framework/issues?q=milestone%3A%223.4.5+Release%22 [gh2013]: https://github.com/tomchristie/django-rest-framework/issues/2013 @@ -1065,3 +1078,15 @@ For older release notes, [please see the version 2.x documentation][old-release- [gh4392]: https://github.com/tomchristie/django-rest-framework/issues/4392 [gh4393]: https://github.com/tomchristie/django-rest-framework/issues/4393 [gh4394]: https://github.com/tomchristie/django-rest-framework/issues/4394 + + +[gh4416]: https://github.com/tomchristie/django-rest-framework/issues/4416 +[gh4409]: https://github.com/tomchristie/django-rest-framework/issues/4409 +[gh4415]: https://github.com/tomchristie/django-rest-framework/issues/4415 +[gh4410]: https://github.com/tomchristie/django-rest-framework/issues/4410 +[gh4408]: https://github.com/tomchristie/django-rest-framework/issues/4408 +[gh4398]: https://github.com/tomchristie/django-rest-framework/issues/4398 +[gh4407]: https://github.com/tomchristie/django-rest-framework/issues/4407 +[gh4403]: https://github.com/tomchristie/django-rest-framework/issues/4403 +[gh4404]: https://github.com/tomchristie/django-rest-framework/issues/4404 +[gh4412]: https://github.com/tomchristie/django-rest-framework/issues/4412 diff --git a/requirements/requirements-optionals.txt b/requirements/requirements-optionals.txt index 20436e6b4..afade0aa0 100644 --- a/requirements/requirements-optionals.txt +++ b/requirements/requirements-optionals.txt @@ -2,4 +2,4 @@ markdown==2.6.4 django-guardian==1.4.3 django-filter==0.13.0 -coreapi==1.32.0 +coreapi==2.0.0 diff --git a/rest_framework/__init__.py b/rest_framework/__init__.py index 999c5de31..3f8736c25 100644 --- a/rest_framework/__init__.py +++ b/rest_framework/__init__.py @@ -8,7 +8,7 @@ ______ _____ _____ _____ __ """ __title__ = 'Django REST framework' -__version__ = '3.4.4' +__version__ = '3.4.5' __author__ = 'Tom Christie' __license__ = 'BSD 2-Clause' __copyright__ = 'Copyright 2011-2016 Tom Christie' diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 8f12b2df4..f76e4e801 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -432,7 +432,7 @@ class Field(object): is provided for this field. If a default has not been set for this field then this will simply - return `empty`, indicating that no value should be set in the + raise `SkipField`, indicating that no value should be set in the validated data for this field. """ if self.default is empty or getattr(self.root, 'partial', False): diff --git a/rest_framework/relations.py b/rest_framework/relations.py index 4b6b3bea4..65c4c0318 100644 --- a/rest_framework/relations.py +++ b/rest_framework/relations.py @@ -10,7 +10,7 @@ from django.core.urlresolvers import ( from django.db.models import Manager from django.db.models.query import QuerySet from django.utils import six -from django.utils.encoding import smart_text +from django.utils.encoding import python_2_unicode_compatible, smart_text from django.utils.six.moves.urllib import parse as urlparse from django.utils.translation import ugettext_lazy as _ @@ -47,6 +47,7 @@ class Hyperlink(six.text_type): is_hyperlink = True +@python_2_unicode_compatible class PKOnlyObject(object): """ This is a mock object, used for when we only need the pk of the object @@ -56,6 +57,9 @@ class PKOnlyObject(object): def __init__(self, pk): self.pk = pk + def __str__(self): + return "%s" % self.pk + # We assume that 'validators' are intended for the child serializer, # rather than the parent serializer. diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 371cd6ec7..11e9fb960 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -645,6 +645,12 @@ class BrowsableAPIRenderer(BaseRenderer): else: paginator = None + csrf_cookie_name = settings.CSRF_COOKIE_NAME + csrf_header_name = getattr(settings, 'CSRF_HEADER_NAME', 'HTTP_X_CSRFToken') # Fallback for Django 1.8 + if csrf_header_name.startswith('HTTP_'): + csrf_header_name = csrf_header_name[5:] + csrf_header_name = csrf_header_name.replace('_', '-') + context = { 'content': self.get_content(renderer, data, accepted_media_type, renderer_context), 'view': view, @@ -675,7 +681,8 @@ class BrowsableAPIRenderer(BaseRenderer): 'display_edit_forms': bool(response.status_code != 403), 'api_settings': api_settings, - 'csrf_cookie_name': settings.CSRF_COOKIE_NAME, + 'csrf_cookie_name': csrf_cookie_name, + 'csrf_header_name': csrf_header_name } return context diff --git a/rest_framework/request.py b/rest_framework/request.py index f5738bfd5..355cccad7 100644 --- a/rest_framework/request.py +++ b/rest_framework/request.py @@ -391,3 +391,8 @@ class Request(object): '`request.QUERY_PARAMS` has been deprecated in favor of `request.query_params` ' 'since version 3.0, and has been fully removed as of version 3.2.' ) + + def force_plaintext_errors(self, value): + # Hack to allow our exception handler to force choice of + # plaintext or html error responses. + self._request.is_ajax = lambda: value diff --git a/rest_framework/schemas.py b/rest_framework/schemas.py index 0618e94fd..c9834c64d 100644 --- a/rest_framework/schemas.py +++ b/rest_framework/schemas.py @@ -79,6 +79,13 @@ class SchemaGenerator(object): view.kwargs = {} view.format_kwarg = None + actions = getattr(callback, 'actions', None) + if actions is not None: + if method == 'OPTIONS': + view.action = 'metadata' + else: + view.action = actions.get(method.lower()) + if request is not None: view.request = clone_request(request, method) try: diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index 41412af8a..4d1ed63ae 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -12,6 +12,7 @@ response content is handled by parsers and renderers. """ from __future__ import unicode_literals +import traceback import warnings from django.db import models @@ -870,19 +871,20 @@ class ModelSerializer(Serializer): try: instance = ModelClass.objects.create(**validated_data) - except TypeError as exc: + except TypeError: + tb = traceback.format_exc() msg = ( 'Got a `TypeError` when calling `%s.objects.create()`. ' 'This may be because you have a writable field on the ' 'serializer class that is not a valid argument to ' '`%s.objects.create()`. You may need to make the field ' 'read-only, or override the %s.create() method to handle ' - 'this correctly.\nOriginal exception text was: %s.' % + 'this correctly.\nOriginal exception was:\n %s' % ( ModelClass.__name__, ModelClass.__name__, self.__class__.__name__, - exc + tb ) ) raise TypeError(msg) diff --git a/rest_framework/static/rest_framework/js/csrf.js b/rest_framework/static/rest_framework/js/csrf.js index f8ab4428c..97c8d0124 100644 --- a/rest_framework/static/rest_framework/js/csrf.js +++ b/rest_framework/static/rest_framework/js/csrf.js @@ -46,7 +46,7 @@ $.ajaxSetup({ // Send the token to same-origin, relative URLs only. // Send the token only if the method warrants CSRF protection // Using the CSRFToken value acquired earlier - xhr.setRequestHeader("X-CSRFToken", csrftoken); + xhr.setRequestHeader(window.drf.csrfHeaderName, csrftoken); } } }); diff --git a/rest_framework/templates/rest_framework/admin.html b/rest_framework/templates/rest_framework/admin.html index 89af81ef7..eb2b8f1c7 100644 --- a/rest_framework/templates/rest_framework/admin.html +++ b/rest_framework/templates/rest_framework/admin.html @@ -232,6 +232,7 @@ {% block script %} diff --git a/rest_framework/templates/rest_framework/base.html b/rest_framework/templates/rest_framework/base.html index 4c1136087..989a086ea 100644 --- a/rest_framework/templates/rest_framework/base.html +++ b/rest_framework/templates/rest_framework/base.html @@ -263,6 +263,7 @@ {% block script %} diff --git a/rest_framework/templates/rest_framework/vertical/checkbox_multiple.html b/rest_framework/templates/rest_framework/vertical/checkbox_multiple.html index b933f4ff5..7a43b3f58 100644 --- a/rest_framework/templates/rest_framework/vertical/checkbox_multiple.html +++ b/rest_framework/templates/rest_framework/vertical/checkbox_multiple.html @@ -9,7 +9,7 @@
{% for key, text in field.choices.items %} {% endfor %} @@ -18,7 +18,7 @@ {% for key, text in field.choices.items %}
diff --git a/rest_framework/test.py b/rest_framework/test.py index e17c19a43..b8e486b21 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -16,7 +16,7 @@ from django.utils import six from django.utils.encoding import force_bytes from django.utils.http import urlencode -from rest_framework.compat import requests +from rest_framework.compat import coreapi, requests from rest_framework.settings import api_settings @@ -60,7 +60,10 @@ if requests is not None: # Set request content, if any exists. if request.body is not None: - kwargs['data'] = request.body + if hasattr(request.body, 'read'): + kwargs['data'] = request.body.read() + else: + kwargs['data'] = request.body if 'content-type' in request.headers: kwargs['content_type'] = request.headers['content-type'] @@ -126,6 +129,14 @@ def get_requests_client(): return DjangoTestSession() +def get_api_client(): + assert coreapi is not None, 'coreapi must be installed' + session = get_requests_client() + return coreapi.Client(transports=[ + coreapi.transports.HTTPTransport(session=session) + ]) + + class APIRequestFactory(DjangoRequestFactory): renderer_classes_list = api_settings.TEST_REQUEST_RENDERER_CLASSES default_format = api_settings.TEST_REQUEST_DEFAULT_FORMAT diff --git a/rest_framework/views.py b/rest_framework/views.py index b86bb7eaa..15d8c6cde 100644 --- a/rest_framework/views.py +++ b/rest_framework/views.py @@ -3,17 +3,14 @@ Provides an APIView class that is the base of all views in REST framework. """ from __future__ import unicode_literals -import sys - from django.conf import settings from django.core.exceptions import PermissionDenied from django.db import models from django.http import Http404 -from django.http.response import HttpResponse, HttpResponseBase +from django.http.response import HttpResponseBase from django.utils import six from django.utils.encoding import smart_text from django.utils.translation import ugettext_lazy as _ -from django.views import debug from django.views.decorators.csrf import csrf_exempt from django.views.generic import View @@ -95,11 +92,6 @@ def exception_handler(exc, context): set_rollback() return Response(data, status=status.HTTP_403_FORBIDDEN) - # throw django's error page if debug is True - if settings.DEBUG: - exception_reporter = debug.ExceptionReporter(context.get('request'), *sys.exc_info()) - return HttpResponse(exception_reporter.get_traceback_html(), status=500) - return None @@ -439,11 +431,19 @@ class APIView(View): response = exception_handler(exc, context) if response is None: - raise + self.raise_uncaught_exception(exc) response.exception = True return response + def raise_uncaught_exception(self, exc): + if settings.DEBUG: + request = self.request + renderer_format = getattr(request.accepted_renderer, 'format') + use_plaintext_traceback = renderer_format not in ('html', 'api', 'admin') + request.force_plaintext_errors(use_plaintext_traceback) + raise + # Note: Views are made CSRF exempt from within `as_view` as to prevent # accidental removal of this exemption in cases where `dispatch` needs to # be overridden. diff --git a/tests/test_api_client.py b/tests/test_api_client.py new file mode 100644 index 000000000..9daf3f3fe --- /dev/null +++ b/tests/test_api_client.py @@ -0,0 +1,452 @@ +from __future__ import unicode_literals + +import os +import tempfile +import unittest + +from django.conf.urls import url +from django.http import HttpResponse +from django.test import override_settings + +from rest_framework.compat import coreapi +from rest_framework.parsers import FileUploadParser +from rest_framework.renderers import CoreJSONRenderer +from rest_framework.response import Response +from rest_framework.test import APITestCase, get_api_client +from rest_framework.views import APIView + + +def get_schema(): + return coreapi.Document( + url='https://api.example.com/', + title='Example API', + content={ + 'simple_link': coreapi.Link('/example/', description='example link'), + 'location': { + 'query': coreapi.Link('/example/', fields=[ + coreapi.Field(name='example', description='example field') + ]), + 'form': coreapi.Link('/example/', action='post', fields=[ + coreapi.Field(name='example'), + ]), + 'body': coreapi.Link('/example/', action='post', fields=[ + coreapi.Field(name='example', location='body') + ]), + 'path': coreapi.Link('/example/{id}', fields=[ + coreapi.Field(name='id', location='path') + ]) + }, + 'encoding': { + 'multipart': coreapi.Link('/example/', action='post', encoding='multipart/form-data', fields=[ + coreapi.Field(name='example') + ]), + 'multipart-body': coreapi.Link('/example/', action='post', encoding='multipart/form-data', fields=[ + coreapi.Field(name='example', location='body') + ]), + 'urlencoded': coreapi.Link('/example/', action='post', encoding='application/x-www-form-urlencoded', fields=[ + coreapi.Field(name='example') + ]), + 'urlencoded-body': coreapi.Link('/example/', action='post', encoding='application/x-www-form-urlencoded', fields=[ + coreapi.Field(name='example', location='body') + ]), + 'raw_upload': coreapi.Link('/upload/', action='post', encoding='application/octet-stream', fields=[ + coreapi.Field(name='example', location='body') + ]), + }, + 'response': { + 'download': coreapi.Link('/download/'), + 'text': coreapi.Link('/text/') + } + } + ) + + +def _iterlists(querydict): + if hasattr(querydict, 'iterlists'): + return querydict.iterlists() + return querydict.lists() + + +def _get_query_params(request): + # Return query params in a plain dict, using a list value if more + # than one item is present for a given key. + return { + key: (value[0] if len(value) == 1 else value) + for key, value in + _iterlists(request.query_params) + } + + +def _get_data(request): + if not isinstance(request.data, dict): + return request.data + # Coerce multidict into regular dict, and remove files to + # make assertions simpler. + if hasattr(request.data, 'iterlists') or hasattr(request.data, 'lists'): + # Use a list value if a QueryDict contains multiple items for a key. + return { + key: value[0] if len(value) == 1 else value + for key, value in _iterlists(request.data) + if key not in request.FILES + } + return { + key: value + for key, value in request.data.items() + if key not in request.FILES + } + + +def _get_files(request): + if not request.FILES: + return {} + return { + key: {'name': value.name, 'content': value.read()} + for key, value in request.FILES.items() + } + + +class SchemaView(APIView): + renderer_classes = [CoreJSONRenderer] + + def get(self, request): + schema = get_schema() + return Response(schema) + + +class ListView(APIView): + def get(self, request): + return Response({ + 'method': request.method, + 'query_params': _get_query_params(request) + }) + + def post(self, request): + if request.content_type: + content_type = request.content_type.split(';')[0] + else: + content_type = None + + return Response({ + 'method': request.method, + 'query_params': _get_query_params(request), + 'data': _get_data(request), + 'files': _get_files(request), + 'content_type': content_type + }) + + +class DetailView(APIView): + def get(self, request, id): + return Response({ + 'id': id, + 'method': request.method, + 'query_params': _get_query_params(request) + }) + + +class UploadView(APIView): + parser_classes = [FileUploadParser] + + def post(self, request): + return Response({ + 'method': request.method, + 'files': _get_files(request), + 'content_type': request.content_type + }) + + +class DownloadView(APIView): + def get(self, request): + return HttpResponse('some file content', content_type='image/png') + + +class TextView(APIView): + def get(self, request): + return HttpResponse('123', content_type='text/plain') + + +urlpatterns = [ + url(r'^$', SchemaView.as_view()), + url(r'^example/$', ListView.as_view()), + url(r'^example/(?P[0-9]+)/$', DetailView.as_view()), + url(r'^upload/$', UploadView.as_view()), + url(r'^download/$', DownloadView.as_view()), + url(r'^text/$', TextView.as_view()), +] + + +@unittest.skipUnless(coreapi, 'coreapi not installed') +@override_settings(ROOT_URLCONF='tests.test_api_client') +class APIClientTests(APITestCase): + def test_api_client(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + assert schema.title == 'Example API' + assert schema.url == 'https://api.example.com/' + assert schema['simple_link'].description == 'example link' + assert schema['location']['query'].fields[0].description == 'example field' + data = client.action(schema, ['simple_link']) + expected = { + 'method': 'GET', + 'query_params': {} + } + assert data == expected + + def test_query_params(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['location', 'query'], params={'example': 123}) + expected = { + 'method': 'GET', + 'query_params': {'example': '123'} + } + assert data == expected + + def test_query_params_with_multiple_values(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['location', 'query'], params={'example': [1, 2, 3]}) + expected = { + 'method': 'GET', + 'query_params': {'example': ['1', '2', '3']} + } + assert data == expected + + def test_form_params(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['location', 'form'], params={'example': 123}) + expected = { + 'method': 'POST', + 'content_type': 'application/json', + 'query_params': {}, + 'data': {'example': 123}, + 'files': {} + } + assert data == expected + + def test_body_params(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['location', 'body'], params={'example': 123}) + expected = { + 'method': 'POST', + 'content_type': 'application/json', + 'query_params': {}, + 'data': 123, + 'files': {} + } + assert data == expected + + def test_path_params(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['location', 'path'], params={'id': 123}) + expected = { + 'method': 'GET', + 'query_params': {}, + 'id': '123' + } + assert data == expected + + def test_multipart_encoding(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + + temp = tempfile.NamedTemporaryFile() + temp.write(b'example file content') + temp.flush() + + with open(temp.name, 'rb') as upload: + name = os.path.basename(upload.name) + data = client.action(schema, ['encoding', 'multipart'], params={'example': upload}) + + expected = { + 'method': 'POST', + 'content_type': 'multipart/form-data', + 'query_params': {}, + 'data': {}, + 'files': {'example': {'name': name, 'content': 'example file content'}} + } + assert data == expected + + def test_multipart_encoding_no_file(self): + # When no file is included, multipart encoding should still be used. + client = get_api_client() + schema = client.get('http://api.example.com/') + + data = client.action(schema, ['encoding', 'multipart'], params={'example': 123}) + + expected = { + 'method': 'POST', + 'content_type': 'multipart/form-data', + 'query_params': {}, + 'data': {'example': '123'}, + 'files': {} + } + assert data == expected + + def test_multipart_encoding_multiple_values(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + + data = client.action(schema, ['encoding', 'multipart'], params={'example': [1, 2, 3]}) + + expected = { + 'method': 'POST', + 'content_type': 'multipart/form-data', + 'query_params': {}, + 'data': {'example': ['1', '2', '3']}, + 'files': {} + } + assert data == expected + + def test_multipart_encoding_string_file_content(self): + # Test for `coreapi.utils.File` support. + from coreapi.utils import File + + client = get_api_client() + schema = client.get('http://api.example.com/') + + example = File(name='example.txt', content='123') + data = client.action(schema, ['encoding', 'multipart'], params={'example': example}) + + expected = { + 'method': 'POST', + 'content_type': 'multipart/form-data', + 'query_params': {}, + 'data': {}, + 'files': {'example': {'name': 'example.txt', 'content': '123'}} + } + assert data == expected + + def test_multipart_encoding_in_body(self): + from coreapi.utils import File + + client = get_api_client() + schema = client.get('http://api.example.com/') + + example = {'foo': File(name='example.txt', content='123'), 'bar': 'abc'} + data = client.action(schema, ['encoding', 'multipart-body'], params={'example': example}) + + expected = { + 'method': 'POST', + 'content_type': 'multipart/form-data', + 'query_params': {}, + 'data': {'bar': 'abc'}, + 'files': {'foo': {'name': 'example.txt', 'content': '123'}} + } + assert data == expected + + # URLencoded + + def test_urlencoded_encoding(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['encoding', 'urlencoded'], params={'example': 123}) + expected = { + 'method': 'POST', + 'content_type': 'application/x-www-form-urlencoded', + 'query_params': {}, + 'data': {'example': '123'}, + 'files': {} + } + assert data == expected + + def test_urlencoded_encoding_multiple_values(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['encoding', 'urlencoded'], params={'example': [1, 2, 3]}) + expected = { + 'method': 'POST', + 'content_type': 'application/x-www-form-urlencoded', + 'query_params': {}, + 'data': {'example': ['1', '2', '3']}, + 'files': {} + } + assert data == expected + + def test_urlencoded_encoding_in_body(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + data = client.action(schema, ['encoding', 'urlencoded-body'], params={'example': {'foo': 123, 'bar': True}}) + expected = { + 'method': 'POST', + 'content_type': 'application/x-www-form-urlencoded', + 'query_params': {}, + 'data': {'foo': '123', 'bar': 'true'}, + 'files': {} + } + assert data == expected + + # Raw uploads + + def test_raw_upload(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + + temp = tempfile.NamedTemporaryFile() + temp.write(b'example file content') + temp.flush() + + with open(temp.name, 'rb') as upload: + name = os.path.basename(upload.name) + data = client.action(schema, ['encoding', 'raw_upload'], params={'example': upload}) + + expected = { + 'method': 'POST', + 'files': {'file': {'name': name, 'content': 'example file content'}}, + 'content_type': 'application/octet-stream' + } + assert data == expected + + def test_raw_upload_string_file_content(self): + from coreapi.utils import File + + client = get_api_client() + schema = client.get('http://api.example.com/') + + example = File('example.txt', '123') + data = client.action(schema, ['encoding', 'raw_upload'], params={'example': example}) + + expected = { + 'method': 'POST', + 'files': {'file': {'name': 'example.txt', 'content': '123'}}, + 'content_type': 'text/plain' + } + assert data == expected + + def test_raw_upload_explicit_content_type(self): + from coreapi.utils import File + + client = get_api_client() + schema = client.get('http://api.example.com/') + + example = File('example.txt', '123', 'text/html') + data = client.action(schema, ['encoding', 'raw_upload'], params={'example': example}) + + expected = { + 'method': 'POST', + 'files': {'file': {'name': 'example.txt', 'content': '123'}}, + 'content_type': 'text/html' + } + assert data == expected + + # Responses + + def test_text_response(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + + data = client.action(schema, ['response', 'text']) + + expected = '123' + assert data == expected + + def test_download_response(self): + client = get_api_client() + schema = client.get('http://api.example.com/') + + data = client.action(schema, ['response', 'download']) + assert data.basename == 'download.png' + assert data.read() == b'some file content' diff --git a/tests/test_request.py b/tests/test_request.py index dee636d76..dbfa695fd 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -7,6 +7,7 @@ from django.conf.urls import url from django.contrib.auth import authenticate, login, logout from django.contrib.auth.models import User from django.contrib.sessions.middleware import SessionMiddleware +from django.core.files.uploadedfile import SimpleUploadedFile from django.test import TestCase, override_settings from django.utils import six @@ -78,6 +79,16 @@ class TestContentParsing(TestCase): request.parsers = (FormParser(), MultiPartParser()) self.assertEqual(list(request.POST.items()), list(data.items())) + def test_request_POST_with_files(self): + """ + Ensure request.POST returns no content for POST request with file content. + """ + upload = SimpleUploadedFile("file.txt", b"file_content") + request = Request(factory.post('/', {'upload': upload})) + request.parsers = (FormParser(), MultiPartParser()) + self.assertEqual(list(request.POST.keys()), []) + self.assertEqual(list(request.FILES.keys()), ['upload']) + def test_standard_behaviour_determines_form_content_PUT(self): """ Ensure request.data returns content for PUT request with form content. diff --git a/tests/test_schemas.py b/tests/test_schemas.py index 81b796c35..c866e09be 100644 --- a/tests/test_schemas.py +++ b/tests/test_schemas.py @@ -49,6 +49,7 @@ class ExampleViewSet(ModelViewSet): def get_serializer(self, *args, **kwargs): assert self.request + assert self.action return super(ExampleViewSet, self).get_serializer(*args, **kwargs) From 61b11890495d5bb014d97cd4a8ce3b1cf951454e Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 29 Sep 2016 10:24:30 +0100 Subject: [PATCH 26/27] Fix release notes --- docs/topics/release-notes.md | 3 --- 1 file changed, 3 deletions(-) diff --git a/docs/topics/release-notes.md b/docs/topics/release-notes.md index 6ef6cb83a..446abdd14 100644 --- a/docs/topics/release-notes.md +++ b/docs/topics/release-notes.md @@ -594,11 +594,8 @@ For older release notes, [please see the version 2.x documentation][old-release- [3.4.3-milestone]: https://github.com/tomchristie/django-rest-framework/issues?q=milestone%3A%223.4.3+Release%22 [3.4.4-milestone]: https://github.com/tomchristie/django-rest-framework/issues?q=milestone%3A%223.4.4+Release%22 [3.4.5-milestone]: https://github.com/tomchristie/django-rest-framework/issues?q=milestone%3A%223.4.5+Release%22 -<<<<<<< HEAD -======= [3.4.6-milestone]: https://github.com/tomchristie/django-rest-framework/issues?q=milestone%3A%223.4.6+Release%22 [3.4.7-milestone]: https://github.com/tomchristie/django-rest-framework/issues?q=milestone%3A%223.4.7+Release%22 ->>>>>>> master [gh2013]: https://github.com/tomchristie/django-rest-framework/issues/2013 From b689a3bdaa521b14d13266c866180d19585babe4 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 29 Sep 2016 12:03:14 +0100 Subject: [PATCH 27/27] Add note about 'User account is disabled.' vs 'Unable to log in' --- rest_framework/authtoken/serializers.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/rest_framework/authtoken/serializers.py b/rest_framework/authtoken/serializers.py index df0c48b86..90d3bd96e 100644 --- a/rest_framework/authtoken/serializers.py +++ b/rest_framework/authtoken/serializers.py @@ -16,6 +16,9 @@ class AuthTokenSerializer(serializers.Serializer): user = authenticate(username=username, password=password) if user: + # From Django 1.10 onwards the `authenticate` call simply + # returns `None` for is_active=False users. + # (Assuming the default `ModelBackend` authentication backend.) if not user.is_active: msg = _('User account is disabled.') raise serializers.ValidationError(msg)