mirror of
				https://github.com/encode/django-rest-framework.git
				synced 2025-11-01 00:17:40 +03:00 
			
		
		
		
	Merge branch 'master' of https://github.com/tomchristie/django-rest-framework
This commit is contained in:
		
						commit
						399ac70b83
					
				|  | @ -300,7 +300,7 @@ The only thing needed to make the `OAuth2Authentication` class work is to insert | |||
| 
 | ||||
| The command line to test the authentication looks like: | ||||
| 
 | ||||
|     curl -H "Authorization: Bearer <your-access-token>" http://localhost:8000/api/?client_id=YOUR_CLIENT_ID\&client_secret=YOUR_CLIENT_SECRET | ||||
|     curl -H "Authorization: Bearer <your-access-token>" http://localhost:8000/api/ | ||||
| 
 | ||||
| --- | ||||
| 
 | ||||
|  |  | |||
|  | @ -40,6 +40,12 @@ You can determine your currently installed version using `pip freeze`: | |||
| 
 | ||||
| ## 2.2.x series | ||||
| 
 | ||||
| ### Master | ||||
| 
 | ||||
| * OAuth2 authentication no longer requires unneccessary URL parameters in addition to the token. | ||||
| * URL hyperlinking in browseable API now handles more cases correctly. | ||||
| * Bugfix: Fix regression with DjangoFilterBackend not worthing correctly with single object views. | ||||
| 
 | ||||
| ### 2.2.5 | ||||
| 
 | ||||
| **Date**: 26th March 2013 | ||||
|  |  | |||
|  | @ -2,14 +2,16 @@ | |||
| Provides a set of pluggable authentication policies. | ||||
| """ | ||||
| from __future__ import unicode_literals | ||||
| import base64 | ||||
| from datetime import datetime | ||||
| 
 | ||||
| from django.contrib.auth import authenticate | ||||
| from django.core.exceptions import ImproperlyConfigured | ||||
| from rest_framework import exceptions, HTTP_HEADER_ENCODING | ||||
| from rest_framework.compat import CsrfViewMiddleware | ||||
| from rest_framework.compat import oauth, oauth_provider, oauth_provider_store | ||||
| from rest_framework.compat import oauth2_provider, oauth2_provider_forms, oauth2_provider_backends | ||||
| from rest_framework.compat import oauth2_provider, oauth2_provider_forms | ||||
| from rest_framework.authtoken.models import Token | ||||
| import base64 | ||||
| 
 | ||||
| 
 | ||||
| def get_authorization_header(request): | ||||
|  | @ -315,21 +317,15 @@ class OAuth2Authentication(BaseAuthentication): | |||
|         Authenticate the request, given the access token. | ||||
|         """ | ||||
| 
 | ||||
|         # Authenticate the client | ||||
|         oauth2_client_form = oauth2_provider_forms.ClientAuthForm(request.REQUEST) | ||||
|         if not oauth2_client_form.is_valid(): | ||||
|             raise exceptions.AuthenticationFailed('Client could not be validated') | ||||
|         client = oauth2_client_form.cleaned_data.get('client') | ||||
| 
 | ||||
|         # Retrieve the `OAuth2AccessToken` instance from the access_token | ||||
|         auth_backend = oauth2_provider_backends.AccessTokenBackend() | ||||
|         token = auth_backend.authenticate(access_token, client) | ||||
|         if token is None: | ||||
|         try: | ||||
|             token = oauth2_provider.models.AccessToken.objects.select_related('user') | ||||
|             # TODO: Change to timezone aware datetime when oauth2_provider add | ||||
|             # support to it. | ||||
|             token = token.get(token=access_token, expires__gt=datetime.now()) | ||||
|         except oauth2_provider.models.AccessToken.DoesNotExist: | ||||
|             raise exceptions.AuthenticationFailed('Invalid token') | ||||
| 
 | ||||
|         user = token.user | ||||
| 
 | ||||
|         if not user.is_active: | ||||
|         if not token.user.is_active: | ||||
|             msg = 'User inactive or deleted: %s' % user.username | ||||
|             raise exceptions.AuthenticationFailed(msg) | ||||
| 
 | ||||
|  |  | |||
|  | @ -395,6 +395,37 @@ except ImportError: | |||
|             kw = dict((k, int(v)) for k, v in kw.iteritems() if v is not None) | ||||
|             return datetime.datetime(**kw) | ||||
| 
 | ||||
| 
 | ||||
| # smart_urlquote is new on Django 1.4 | ||||
| try: | ||||
|     from django.utils.html import smart_urlquote | ||||
| except ImportError: | ||||
|     try: | ||||
|         from urllib.parse import quote, urlsplit, urlunsplit | ||||
|     except ImportError:     # Python 2 | ||||
|         from urllib import quote | ||||
|         from urlparse import urlsplit, urlunsplit | ||||
| 
 | ||||
|     def smart_urlquote(url): | ||||
|         "Quotes a URL if it isn't already quoted." | ||||
|         # Handle IDN before quoting. | ||||
|         scheme, netloc, path, query, fragment = urlsplit(url) | ||||
|         try: | ||||
|             netloc = netloc.encode('idna').decode('ascii') # IDN -> ACE | ||||
|         except UnicodeError: # invalid domain part | ||||
|             pass | ||||
|         else: | ||||
|             url = urlunsplit((scheme, netloc, path, query, fragment)) | ||||
| 
 | ||||
|         # An URL is considered unquoted if it contains no % characters or | ||||
|         # contains a % not followed by two hexadecimal digits. See #9655. | ||||
|         if '%' not in url or unquoted_percents_re.search(url): | ||||
|             # See http://bugs.python.org/issue2637 | ||||
|             url = quote(force_bytes(url), safe=b'!*\'();:@&=+$,/?#[]~') | ||||
| 
 | ||||
|         return force_text(url) | ||||
| 
 | ||||
| 
 | ||||
| # Markdown is optional | ||||
| try: | ||||
|     import markdown | ||||
|  | @ -445,14 +476,12 @@ except ImportError: | |||
| # OAuth 2 support is optional | ||||
| try: | ||||
|     import provider.oauth2 as oauth2_provider | ||||
|     from provider.oauth2 import backends as oauth2_provider_backends | ||||
|     from provider.oauth2 import models as oauth2_provider_models | ||||
|     from provider.oauth2 import forms as oauth2_provider_forms | ||||
|     from provider import scope as oauth2_provider_scope | ||||
|     from provider import constants as oauth2_constants | ||||
| except ImportError: | ||||
|     oauth2_provider = None | ||||
|     oauth2_provider_backends = None | ||||
|     oauth2_provider_models = None | ||||
|     oauth2_provider_forms = None | ||||
|     oauth2_provider_scope = None | ||||
|  |  | |||
|  | @ -55,6 +55,6 @@ class DjangoFilterBackend(BaseFilterBackend): | |||
|         filter_class = self.get_filter_class(view) | ||||
| 
 | ||||
|         if filter_class: | ||||
|             return filter_class(request.QUERY_PARAMS, queryset=queryset) | ||||
|             return filter_class(request.QUERY_PARAMS, queryset=queryset).qs | ||||
| 
 | ||||
|         return queryset | ||||
|  |  | |||
|  | @ -4,11 +4,8 @@ from django.core.urlresolvers import reverse, NoReverseMatch | |||
| from django.http import QueryDict | ||||
| from django.utils.html import escape | ||||
| from django.utils.safestring import SafeData, mark_safe | ||||
| from rest_framework.compat import urlparse | ||||
| from rest_framework.compat import force_text | ||||
| from rest_framework.compat import six | ||||
| import re | ||||
| import string | ||||
| from rest_framework.compat import urlparse, force_text, six, smart_urlquote | ||||
| import re, string | ||||
| 
 | ||||
| register = template.Library() | ||||
| 
 | ||||
|  | @ -112,22 +109,6 @@ def replace_query_param(url, key, val): | |||
| class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])') | ||||
| 
 | ||||
| 
 | ||||
| # Bunch of stuff cloned from urlize | ||||
| LEADING_PUNCTUATION = ['(', '<', '<', '"', "'"] | ||||
| TRAILING_PUNCTUATION = ['.', ',', ')', '>', '\n', '>', '"', "'"] | ||||
| DOTS = ['·', '*', '\xe2\x80\xa2', '•', '•', '•'] | ||||
| unencoded_ampersands_re = re.compile(r'&(?!(\w+|#\d+);)') | ||||
| word_split_re = re.compile(r'(\s+)') | ||||
| punctuation_re = re.compile('^(?P<lead>(?:%s)*)(?P<middle>.*?)(?P<trail>(?:%s)*)$' % \ | ||||
|     ('|'.join([re.escape(x) for x in LEADING_PUNCTUATION]), | ||||
|     '|'.join([re.escape(x) for x in TRAILING_PUNCTUATION]))) | ||||
| simple_email_re = re.compile(r'^\S+@[a-zA-Z0-9._-]+\.[a-zA-Z0-9._-]+$') | ||||
| link_target_attribute_re = re.compile(r'(<a [^>]*?)target=[^\s>]+') | ||||
| html_gunk_re = re.compile(r'(?:<br clear="all">|<i><\/i>|<b><\/b>|<em><\/em>|<strong><\/strong>|<\/?smallcaps>|<\/?uppercase>)', re.IGNORECASE) | ||||
| hard_coded_bullets_re = re.compile(r'((?:<p>(?:%s).*?[a-zA-Z].*?</p>\s*)+)' % '|'.join([re.escape(x) for x in DOTS]), re.DOTALL) | ||||
| trailing_empty_content_re = re.compile(r'(?:<p>(?: |\s|<br \/>)*?</p>\s*)+\Z') | ||||
| 
 | ||||
| 
 | ||||
| # And the template tags themselves... | ||||
| 
 | ||||
| @register.simple_tag | ||||
|  | @ -195,15 +176,25 @@ def add_class(value, css_class): | |||
|     return value | ||||
| 
 | ||||
| 
 | ||||
| # Bunch of stuff cloned from urlize | ||||
| TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "'"] | ||||
| WRAPPING_PUNCTUATION = [('(', ')'), ('<', '>'), ('[', ']'), ('<', '>'), | ||||
|                         ('"', '"'), ("'", "'")] | ||||
| word_split_re = re.compile(r'(\s+)') | ||||
| simple_url_re = re.compile(r'^https?://\w', re.IGNORECASE) | ||||
| simple_url_2_re = re.compile(r'^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net|org)$', re.IGNORECASE) | ||||
| simple_email_re = re.compile(r'^\S+@\S+\.\S+$') | ||||
| 
 | ||||
| 
 | ||||
| @register.filter | ||||
| def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=True): | ||||
|     """ | ||||
|     Converts any URLs in text into clickable links. | ||||
| 
 | ||||
|     Works on http://, https://, www. links and links ending in .org, .net or | ||||
|     .com. Links can have trailing punctuation (periods, commas, close-parens) | ||||
|     and leading punctuation (opening parens) and it'll still do the right | ||||
|     thing. | ||||
|     Works on http://, https://, www. links, and also on links ending in one of | ||||
|     the original seven gTLDs (.com, .edu, .gov, .int, .mil, .net, and .org). | ||||
|     Links can have trailing punctuation (periods, commas, close-parens) and | ||||
|     leading punctuation (opening parens) and it'll still do the right thing. | ||||
| 
 | ||||
|     If trim_url_limit is not None, the URLs in link text longer than this limit | ||||
|     will truncated to trim_url_limit-3 characters and appended with an elipsis. | ||||
|  | @ -216,24 +207,41 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru | |||
|     trim_url = lambda x, limit=trim_url_limit: limit is not None and (len(x) > limit and ('%s...' % x[:max(0, limit - 3)])) or x | ||||
|     safe_input = isinstance(text, SafeData) | ||||
|     words = word_split_re.split(force_text(text)) | ||||
|     nofollow_attr = nofollow and ' rel="nofollow"' or '' | ||||
|     for i, word in enumerate(words): | ||||
|         match = None | ||||
|         if '.' in word or '@' in word or ':' in word: | ||||
|             match = punctuation_re.match(word) | ||||
|         if match: | ||||
|             lead, middle, trail = match.groups() | ||||
|             # Deal with punctuation. | ||||
|             lead, middle, trail = '', word, '' | ||||
|             for punctuation in TRAILING_PUNCTUATION: | ||||
|                 if middle.endswith(punctuation): | ||||
|                     middle = middle[:-len(punctuation)] | ||||
|                     trail = punctuation + trail | ||||
|             for opening, closing in WRAPPING_PUNCTUATION: | ||||
|                 if middle.startswith(opening): | ||||
|                     middle = middle[len(opening):] | ||||
|                     lead = lead + opening | ||||
|                 # Keep parentheses at the end only if they're balanced. | ||||
|                 if (middle.endswith(closing) | ||||
|                     and middle.count(closing) == middle.count(opening) + 1): | ||||
|                     middle = middle[:-len(closing)] | ||||
|                     trail = closing + trail | ||||
| 
 | ||||
|             # Make URL we want to point to. | ||||
|             url = None | ||||
|             if middle.startswith('http://') or middle.startswith('https://'): | ||||
|                 url = middle | ||||
|             elif middle.startswith('www.') or ('@' not in middle and \ | ||||
|                     middle and middle[0] in string.ascii_letters + string.digits and \ | ||||
|                     (middle.endswith('.org') or middle.endswith('.net') or middle.endswith('.com'))): | ||||
|                 url = 'http://%s' % middle | ||||
|             elif '@' in middle and not ':' in middle and simple_email_re.match(middle): | ||||
|                 url = 'mailto:%s' % middle | ||||
|             nofollow_attr = ' rel="nofollow"' if nofollow else '' | ||||
|             if simple_url_re.match(middle): | ||||
|                 url = smart_urlquote(middle) | ||||
|             elif simple_url_2_re.match(middle): | ||||
|                 url = smart_urlquote('http://%s' % middle) | ||||
|             elif not ':' in middle and simple_email_re.match(middle): | ||||
|                 local, domain = middle.rsplit('@', 1) | ||||
|                 try: | ||||
|                     domain = domain.encode('idna').decode('ascii') | ||||
|                 except UnicodeError: | ||||
|                     continue | ||||
|                 url = 'mailto:%s@%s' % (local, domain) | ||||
|                 nofollow_attr = '' | ||||
| 
 | ||||
|             # Make link. | ||||
|             if url: | ||||
|                 trimmed = trim_url(middle) | ||||
|  | @ -251,4 +259,4 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru | |||
|             words[i] = mark_safe(word) | ||||
|         elif autoescape: | ||||
|             words[i] = escape(word) | ||||
|     return mark_safe(''.join(words)) | ||||
|     return ''.join(words) | ||||
|  |  | |||
|  | @ -466,17 +466,13 @@ class OAuth2Tests(TestCase): | |||
|     def _create_authorization_header(self, token=None): | ||||
|         return "Bearer {0}".format(token or self.access_token.token) | ||||
| 
 | ||||
|     def _client_credentials_params(self): | ||||
|         return {'client_id': self.CLIENT_ID, 'client_secret': self.CLIENT_SECRET} | ||||
| 
 | ||||
|     @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') | ||||
|     def test_get_form_with_wrong_authorization_header_token_type_failing(self): | ||||
|         """Ensure that a wrong token type lead to the correct HTTP error status code""" | ||||
|         auth = "Wrong token-type-obsviously" | ||||
|         response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) | ||||
|         self.assertEqual(response.status_code, 401) | ||||
|         params = self._client_credentials_params() | ||||
|         response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) | ||||
|         response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) | ||||
|         self.assertEqual(response.status_code, 401) | ||||
| 
 | ||||
|     @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') | ||||
|  | @ -485,8 +481,7 @@ class OAuth2Tests(TestCase): | |||
|         auth = "Bearer wrong token format" | ||||
|         response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) | ||||
|         self.assertEqual(response.status_code, 401) | ||||
|         params = self._client_credentials_params() | ||||
|         response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) | ||||
|         response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) | ||||
|         self.assertEqual(response.status_code, 401) | ||||
| 
 | ||||
|     @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') | ||||
|  | @ -495,33 +490,21 @@ class OAuth2Tests(TestCase): | |||
|         auth = "Bearer wrong-token" | ||||
|         response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth) | ||||
|         self.assertEqual(response.status_code, 401) | ||||
|         params = self._client_credentials_params() | ||||
|         response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) | ||||
|         self.assertEqual(response.status_code, 401) | ||||
| 
 | ||||
|     @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') | ||||
|     def test_get_form_with_wrong_client_data_failing_auth(self): | ||||
|         """Ensure GETing form over OAuth with incorrect client credentials fails""" | ||||
|         auth = self._create_authorization_header() | ||||
|         params = self._client_credentials_params() | ||||
|         params['client_id'] += 'a' | ||||
|         response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) | ||||
|         response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) | ||||
|         self.assertEqual(response.status_code, 401) | ||||
| 
 | ||||
|     @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') | ||||
|     def test_get_form_passing_auth(self): | ||||
|         """Ensure GETing form over OAuth with correct client credentials succeed""" | ||||
|         auth = self._create_authorization_header() | ||||
|         params = self._client_credentials_params() | ||||
|         response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) | ||||
|         response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth) | ||||
|         self.assertEqual(response.status_code, 200) | ||||
| 
 | ||||
|     @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') | ||||
|     def test_post_form_passing_auth(self): | ||||
|         """Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF""" | ||||
|         auth = self._create_authorization_header() | ||||
|         params = self._client_credentials_params() | ||||
|         response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) | ||||
|         response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) | ||||
|         self.assertEqual(response.status_code, 200) | ||||
| 
 | ||||
|     @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') | ||||
|  | @ -529,16 +512,14 @@ class OAuth2Tests(TestCase): | |||
|         """Ensure POSTing when there is no OAuth access token in db fails""" | ||||
|         self.access_token.delete() | ||||
|         auth = self._create_authorization_header() | ||||
|         params = self._client_credentials_params() | ||||
|         response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) | ||||
|         response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) | ||||
|         self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) | ||||
| 
 | ||||
|     @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') | ||||
|     def test_post_form_with_refresh_token_failing_auth(self): | ||||
|         """Ensure POSTing with refresh token instead of access token fails""" | ||||
|         auth = self._create_authorization_header(token=self.refresh_token.token) | ||||
|         params = self._client_credentials_params() | ||||
|         response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) | ||||
|         response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) | ||||
|         self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) | ||||
| 
 | ||||
|     @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') | ||||
|  | @ -547,8 +528,7 @@ class OAuth2Tests(TestCase): | |||
|         self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10)  # 10 seconds late | ||||
|         self.access_token.save() | ||||
|         auth = self._create_authorization_header() | ||||
|         params = self._client_credentials_params() | ||||
|         response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth) | ||||
|         response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth) | ||||
|         self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN)) | ||||
|         self.assertIn('Invalid token', response.content) | ||||
| 
 | ||||
|  | @ -559,10 +539,9 @@ class OAuth2Tests(TestCase): | |||
|         read_only_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['read'] | ||||
|         read_only_access_token.save() | ||||
|         auth = self._create_authorization_header(token=read_only_access_token.token) | ||||
|         params = self._client_credentials_params() | ||||
|         response = self.csrf_client.get('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth) | ||||
|         response = self.csrf_client.get('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|         response = self.csrf_client.post('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth) | ||||
|         response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) | ||||
|         self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN) | ||||
| 
 | ||||
|     @unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed') | ||||
|  | @ -572,6 +551,5 @@ class OAuth2Tests(TestCase): | |||
|         read_write_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['write'] | ||||
|         read_write_access_token.save() | ||||
|         auth = self._create_authorization_header(token=read_write_access_token.token) | ||||
|         params = self._client_credentials_params() | ||||
|         response = self.csrf_client.post('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth) | ||||
|         response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth) | ||||
|         self.assertEqual(response.status_code, 200) | ||||
|  |  | |||
|  | @ -1,11 +1,12 @@ | |||
| from __future__ import unicode_literals | ||||
| import datetime | ||||
| from decimal import Decimal | ||||
| from django.core.urlresolvers import reverse | ||||
| from django.test import TestCase | ||||
| from django.test.client import RequestFactory | ||||
| from django.utils import unittest | ||||
| from rest_framework import generics, status, filters | ||||
| from rest_framework.compat import django_filters | ||||
| from rest_framework.compat import django_filters, patterns, url | ||||
| from rest_framework.tests.models import FilterableItem, BasicModel | ||||
| 
 | ||||
| factory = RequestFactory() | ||||
|  | @ -46,12 +47,21 @@ if django_filters: | |||
|         filter_class = MisconfiguredFilter | ||||
|         filter_backend = filters.DjangoFilterBackend | ||||
| 
 | ||||
|     class FilterClassDetailView(generics.RetrieveAPIView): | ||||
|         model = FilterableItem | ||||
|         filter_class = SeveralFieldsFilter | ||||
|         filter_backend = filters.DjangoFilterBackend | ||||
| 
 | ||||
| class IntegrationTestFiltering(TestCase): | ||||
|     """ | ||||
|     Integration tests for filtered list views. | ||||
|     """ | ||||
|     urlpatterns = patterns('', | ||||
|         url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'), | ||||
|         url(r'^$', FilterClassRootView.as_view(), name='root-view'), | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| class CommonFilteringTestCase(TestCase): | ||||
|     def _serialize_object(self, obj): | ||||
|         return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} | ||||
|      | ||||
|     def setUp(self): | ||||
|         """ | ||||
|         Create 10 FilterableItem instances. | ||||
|  | @ -65,10 +75,16 @@ class IntegrationTestFiltering(TestCase): | |||
| 
 | ||||
|         self.objects = FilterableItem.objects | ||||
|         self.data = [ | ||||
|             {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date} | ||||
|             self._serialize_object(obj) | ||||
|             for obj in self.objects.all() | ||||
|         ] | ||||
| 
 | ||||
| 
 | ||||
| class IntegrationTestFiltering(CommonFilteringTestCase): | ||||
|     """ | ||||
|     Integration tests for filtered list views. | ||||
|     """ | ||||
| 
 | ||||
|     @unittest.skipUnless(django_filters, 'django-filters not installed') | ||||
|     def test_get_filtered_fields_root_view(self): | ||||
|         """ | ||||
|  | @ -167,3 +183,50 @@ class IntegrationTestFiltering(TestCase): | |||
|         request = factory.get('/?integer=%s' % search_integer) | ||||
|         response = view(request).render() | ||||
|         self.assertEqual(response.status_code, status.HTTP_200_OK) | ||||
| 
 | ||||
| 
 | ||||
| class IntegrationTestDetailFiltering(CommonFilteringTestCase): | ||||
|     """ | ||||
|     Integration tests for filtered detail views. | ||||
|     """ | ||||
|     urls = 'rest_framework.tests.filterset' | ||||
|      | ||||
|     def _get_url(self, item): | ||||
|         return reverse('detail-view', kwargs=dict(pk=item.pk)) | ||||
| 
 | ||||
|     @unittest.skipUnless(django_filters, 'django-filters not installed') | ||||
|     def test_get_filtered_detail_view(self): | ||||
|         """ | ||||
|         GET requests to filtered RetrieveAPIView that have a filter_class set | ||||
|         should return filtered results. | ||||
|         """ | ||||
|         item = self.objects.all()[0] | ||||
|         data = self._serialize_object(item) | ||||
| 
 | ||||
|         # Basic test with no filter. | ||||
|         response = self.client.get(self._get_url(item)) | ||||
|         self.assertEqual(response.status_code, status.HTTP_200_OK) | ||||
|         self.assertEqual(response.data, data) | ||||
| 
 | ||||
|         # Tests that the decimal filter set that should fail. | ||||
|         search_decimal = Decimal('4.25') | ||||
|         high_item = self.objects.filter(decimal__gt=search_decimal)[0] | ||||
|         response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(high_item), param=search_decimal)) | ||||
|         self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND) | ||||
| 
 | ||||
|         # Tests that the decimal filter set that should succeed. | ||||
|         search_decimal = Decimal('4.25') | ||||
|         low_item = self.objects.filter(decimal__lt=search_decimal)[0] | ||||
|         low_item_data = self._serialize_object(low_item) | ||||
|         response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal)) | ||||
|         self.assertEqual(response.status_code, status.HTTP_200_OK) | ||||
|         self.assertEqual(response.data, low_item_data) | ||||
|          | ||||
|         # Tests that multiple filters works. | ||||
|         search_decimal = Decimal('5.25') | ||||
|         search_date = datetime.date(2012, 10, 2) | ||||
|         valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0] | ||||
|         valid_item_data = self._serialize_object(valid_item) | ||||
|         response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date)) | ||||
|         self.assertEqual(response.status_code, status.HTTP_200_OK) | ||||
|         self.assertEqual(response.data, valid_item_data) | ||||
|  |  | |||
|  | @ -129,16 +129,6 @@ class IntegrationTestPaginationAndFiltering(TestCase): | |||
|         view = FilterFieldsRootView.as_view() | ||||
| 
 | ||||
|         EXPECTED_NUM_QUERIES = 2 | ||||
|         if django.VERSION < (1, 4): | ||||
|             # On Django 1.3 we need to use django-filter 0.5.4 | ||||
|             # | ||||
|             # The filter objects there don't expose a `.count()` method, | ||||
|             # which means we only make a single query *but* it's a single | ||||
|             # query across *all* of the queryset, instead of a COUNT and then | ||||
|             # a SELECT with a LIMIT. | ||||
|             # | ||||
|             # Although this is fewer queries, it's actually a regression. | ||||
|             EXPECTED_NUM_QUERIES = 1 | ||||
| 
 | ||||
|         request = factory.get('/?decimal=15.20') | ||||
|         with self.assertNumQueries(EXPECTED_NUM_QUERIES): | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user