mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-26 11:33:59 +03:00
Merge branch 'master' of https://github.com/tomchristie/django-rest-framework
This commit is contained in:
commit
399ac70b83
|
@ -300,7 +300,7 @@ The only thing needed to make the `OAuth2Authentication` class work is to insert
|
||||||
|
|
||||||
The command line to test the authentication looks like:
|
The command line to test the authentication looks like:
|
||||||
|
|
||||||
curl -H "Authorization: Bearer <your-access-token>" http://localhost:8000/api/?client_id=YOUR_CLIENT_ID\&client_secret=YOUR_CLIENT_SECRET
|
curl -H "Authorization: Bearer <your-access-token>" http://localhost:8000/api/
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
|
|
@ -40,6 +40,12 @@ You can determine your currently installed version using `pip freeze`:
|
||||||
|
|
||||||
## 2.2.x series
|
## 2.2.x series
|
||||||
|
|
||||||
|
### Master
|
||||||
|
|
||||||
|
* OAuth2 authentication no longer requires unneccessary URL parameters in addition to the token.
|
||||||
|
* URL hyperlinking in browseable API now handles more cases correctly.
|
||||||
|
* Bugfix: Fix regression with DjangoFilterBackend not worthing correctly with single object views.
|
||||||
|
|
||||||
### 2.2.5
|
### 2.2.5
|
||||||
|
|
||||||
**Date**: 26th March 2013
|
**Date**: 26th March 2013
|
||||||
|
|
|
@ -2,14 +2,16 @@
|
||||||
Provides a set of pluggable authentication policies.
|
Provides a set of pluggable authentication policies.
|
||||||
"""
|
"""
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
import base64
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
from django.contrib.auth import authenticate
|
from django.contrib.auth import authenticate
|
||||||
from django.core.exceptions import ImproperlyConfigured
|
from django.core.exceptions import ImproperlyConfigured
|
||||||
from rest_framework import exceptions, HTTP_HEADER_ENCODING
|
from rest_framework import exceptions, HTTP_HEADER_ENCODING
|
||||||
from rest_framework.compat import CsrfViewMiddleware
|
from rest_framework.compat import CsrfViewMiddleware
|
||||||
from rest_framework.compat import oauth, oauth_provider, oauth_provider_store
|
from rest_framework.compat import oauth, oauth_provider, oauth_provider_store
|
||||||
from rest_framework.compat import oauth2_provider, oauth2_provider_forms, oauth2_provider_backends
|
from rest_framework.compat import oauth2_provider, oauth2_provider_forms
|
||||||
from rest_framework.authtoken.models import Token
|
from rest_framework.authtoken.models import Token
|
||||||
import base64
|
|
||||||
|
|
||||||
|
|
||||||
def get_authorization_header(request):
|
def get_authorization_header(request):
|
||||||
|
@ -315,21 +317,15 @@ class OAuth2Authentication(BaseAuthentication):
|
||||||
Authenticate the request, given the access token.
|
Authenticate the request, given the access token.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Authenticate the client
|
try:
|
||||||
oauth2_client_form = oauth2_provider_forms.ClientAuthForm(request.REQUEST)
|
token = oauth2_provider.models.AccessToken.objects.select_related('user')
|
||||||
if not oauth2_client_form.is_valid():
|
# TODO: Change to timezone aware datetime when oauth2_provider add
|
||||||
raise exceptions.AuthenticationFailed('Client could not be validated')
|
# support to it.
|
||||||
client = oauth2_client_form.cleaned_data.get('client')
|
token = token.get(token=access_token, expires__gt=datetime.now())
|
||||||
|
except oauth2_provider.models.AccessToken.DoesNotExist:
|
||||||
# Retrieve the `OAuth2AccessToken` instance from the access_token
|
|
||||||
auth_backend = oauth2_provider_backends.AccessTokenBackend()
|
|
||||||
token = auth_backend.authenticate(access_token, client)
|
|
||||||
if token is None:
|
|
||||||
raise exceptions.AuthenticationFailed('Invalid token')
|
raise exceptions.AuthenticationFailed('Invalid token')
|
||||||
|
|
||||||
user = token.user
|
if not token.user.is_active:
|
||||||
|
|
||||||
if not user.is_active:
|
|
||||||
msg = 'User inactive or deleted: %s' % user.username
|
msg = 'User inactive or deleted: %s' % user.username
|
||||||
raise exceptions.AuthenticationFailed(msg)
|
raise exceptions.AuthenticationFailed(msg)
|
||||||
|
|
||||||
|
|
|
@ -395,6 +395,37 @@ except ImportError:
|
||||||
kw = dict((k, int(v)) for k, v in kw.iteritems() if v is not None)
|
kw = dict((k, int(v)) for k, v in kw.iteritems() if v is not None)
|
||||||
return datetime.datetime(**kw)
|
return datetime.datetime(**kw)
|
||||||
|
|
||||||
|
|
||||||
|
# smart_urlquote is new on Django 1.4
|
||||||
|
try:
|
||||||
|
from django.utils.html import smart_urlquote
|
||||||
|
except ImportError:
|
||||||
|
try:
|
||||||
|
from urllib.parse import quote, urlsplit, urlunsplit
|
||||||
|
except ImportError: # Python 2
|
||||||
|
from urllib import quote
|
||||||
|
from urlparse import urlsplit, urlunsplit
|
||||||
|
|
||||||
|
def smart_urlquote(url):
|
||||||
|
"Quotes a URL if it isn't already quoted."
|
||||||
|
# Handle IDN before quoting.
|
||||||
|
scheme, netloc, path, query, fragment = urlsplit(url)
|
||||||
|
try:
|
||||||
|
netloc = netloc.encode('idna').decode('ascii') # IDN -> ACE
|
||||||
|
except UnicodeError: # invalid domain part
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
url = urlunsplit((scheme, netloc, path, query, fragment))
|
||||||
|
|
||||||
|
# An URL is considered unquoted if it contains no % characters or
|
||||||
|
# contains a % not followed by two hexadecimal digits. See #9655.
|
||||||
|
if '%' not in url or unquoted_percents_re.search(url):
|
||||||
|
# See http://bugs.python.org/issue2637
|
||||||
|
url = quote(force_bytes(url), safe=b'!*\'();:@&=+$,/?#[]~')
|
||||||
|
|
||||||
|
return force_text(url)
|
||||||
|
|
||||||
|
|
||||||
# Markdown is optional
|
# Markdown is optional
|
||||||
try:
|
try:
|
||||||
import markdown
|
import markdown
|
||||||
|
@ -445,14 +476,12 @@ except ImportError:
|
||||||
# OAuth 2 support is optional
|
# OAuth 2 support is optional
|
||||||
try:
|
try:
|
||||||
import provider.oauth2 as oauth2_provider
|
import provider.oauth2 as oauth2_provider
|
||||||
from provider.oauth2 import backends as oauth2_provider_backends
|
|
||||||
from provider.oauth2 import models as oauth2_provider_models
|
from provider.oauth2 import models as oauth2_provider_models
|
||||||
from provider.oauth2 import forms as oauth2_provider_forms
|
from provider.oauth2 import forms as oauth2_provider_forms
|
||||||
from provider import scope as oauth2_provider_scope
|
from provider import scope as oauth2_provider_scope
|
||||||
from provider import constants as oauth2_constants
|
from provider import constants as oauth2_constants
|
||||||
except ImportError:
|
except ImportError:
|
||||||
oauth2_provider = None
|
oauth2_provider = None
|
||||||
oauth2_provider_backends = None
|
|
||||||
oauth2_provider_models = None
|
oauth2_provider_models = None
|
||||||
oauth2_provider_forms = None
|
oauth2_provider_forms = None
|
||||||
oauth2_provider_scope = None
|
oauth2_provider_scope = None
|
||||||
|
|
|
@ -55,6 +55,6 @@ class DjangoFilterBackend(BaseFilterBackend):
|
||||||
filter_class = self.get_filter_class(view)
|
filter_class = self.get_filter_class(view)
|
||||||
|
|
||||||
if filter_class:
|
if filter_class:
|
||||||
return filter_class(request.QUERY_PARAMS, queryset=queryset)
|
return filter_class(request.QUERY_PARAMS, queryset=queryset).qs
|
||||||
|
|
||||||
return queryset
|
return queryset
|
||||||
|
|
|
@ -4,11 +4,8 @@ from django.core.urlresolvers import reverse, NoReverseMatch
|
||||||
from django.http import QueryDict
|
from django.http import QueryDict
|
||||||
from django.utils.html import escape
|
from django.utils.html import escape
|
||||||
from django.utils.safestring import SafeData, mark_safe
|
from django.utils.safestring import SafeData, mark_safe
|
||||||
from rest_framework.compat import urlparse
|
from rest_framework.compat import urlparse, force_text, six, smart_urlquote
|
||||||
from rest_framework.compat import force_text
|
import re, string
|
||||||
from rest_framework.compat import six
|
|
||||||
import re
|
|
||||||
import string
|
|
||||||
|
|
||||||
register = template.Library()
|
register = template.Library()
|
||||||
|
|
||||||
|
@ -112,22 +109,6 @@ def replace_query_param(url, key, val):
|
||||||
class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])')
|
class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])')
|
||||||
|
|
||||||
|
|
||||||
# Bunch of stuff cloned from urlize
|
|
||||||
LEADING_PUNCTUATION = ['(', '<', '<', '"', "'"]
|
|
||||||
TRAILING_PUNCTUATION = ['.', ',', ')', '>', '\n', '>', '"', "'"]
|
|
||||||
DOTS = ['·', '*', '\xe2\x80\xa2', '•', '•', '•']
|
|
||||||
unencoded_ampersands_re = re.compile(r'&(?!(\w+|#\d+);)')
|
|
||||||
word_split_re = re.compile(r'(\s+)')
|
|
||||||
punctuation_re = re.compile('^(?P<lead>(?:%s)*)(?P<middle>.*?)(?P<trail>(?:%s)*)$' % \
|
|
||||||
('|'.join([re.escape(x) for x in LEADING_PUNCTUATION]),
|
|
||||||
'|'.join([re.escape(x) for x in TRAILING_PUNCTUATION])))
|
|
||||||
simple_email_re = re.compile(r'^\S+@[a-zA-Z0-9._-]+\.[a-zA-Z0-9._-]+$')
|
|
||||||
link_target_attribute_re = re.compile(r'(<a [^>]*?)target=[^\s>]+')
|
|
||||||
html_gunk_re = re.compile(r'(?:<br clear="all">|<i><\/i>|<b><\/b>|<em><\/em>|<strong><\/strong>|<\/?smallcaps>|<\/?uppercase>)', re.IGNORECASE)
|
|
||||||
hard_coded_bullets_re = re.compile(r'((?:<p>(?:%s).*?[a-zA-Z].*?</p>\s*)+)' % '|'.join([re.escape(x) for x in DOTS]), re.DOTALL)
|
|
||||||
trailing_empty_content_re = re.compile(r'(?:<p>(?: |\s|<br \/>)*?</p>\s*)+\Z')
|
|
||||||
|
|
||||||
|
|
||||||
# And the template tags themselves...
|
# And the template tags themselves...
|
||||||
|
|
||||||
@register.simple_tag
|
@register.simple_tag
|
||||||
|
@ -195,15 +176,25 @@ def add_class(value, css_class):
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
# Bunch of stuff cloned from urlize
|
||||||
|
TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "'"]
|
||||||
|
WRAPPING_PUNCTUATION = [('(', ')'), ('<', '>'), ('[', ']'), ('<', '>'),
|
||||||
|
('"', '"'), ("'", "'")]
|
||||||
|
word_split_re = re.compile(r'(\s+)')
|
||||||
|
simple_url_re = re.compile(r'^https?://\w', re.IGNORECASE)
|
||||||
|
simple_url_2_re = re.compile(r'^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net|org)$', re.IGNORECASE)
|
||||||
|
simple_email_re = re.compile(r'^\S+@\S+\.\S+$')
|
||||||
|
|
||||||
|
|
||||||
@register.filter
|
@register.filter
|
||||||
def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=True):
|
def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=True):
|
||||||
"""
|
"""
|
||||||
Converts any URLs in text into clickable links.
|
Converts any URLs in text into clickable links.
|
||||||
|
|
||||||
Works on http://, https://, www. links and links ending in .org, .net or
|
Works on http://, https://, www. links, and also on links ending in one of
|
||||||
.com. Links can have trailing punctuation (periods, commas, close-parens)
|
the original seven gTLDs (.com, .edu, .gov, .int, .mil, .net, and .org).
|
||||||
and leading punctuation (opening parens) and it'll still do the right
|
Links can have trailing punctuation (periods, commas, close-parens) and
|
||||||
thing.
|
leading punctuation (opening parens) and it'll still do the right thing.
|
||||||
|
|
||||||
If trim_url_limit is not None, the URLs in link text longer than this limit
|
If trim_url_limit is not None, the URLs in link text longer than this limit
|
||||||
will truncated to trim_url_limit-3 characters and appended with an elipsis.
|
will truncated to trim_url_limit-3 characters and appended with an elipsis.
|
||||||
|
@ -216,24 +207,41 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
|
||||||
trim_url = lambda x, limit=trim_url_limit: limit is not None and (len(x) > limit and ('%s...' % x[:max(0, limit - 3)])) or x
|
trim_url = lambda x, limit=trim_url_limit: limit is not None and (len(x) > limit and ('%s...' % x[:max(0, limit - 3)])) or x
|
||||||
safe_input = isinstance(text, SafeData)
|
safe_input = isinstance(text, SafeData)
|
||||||
words = word_split_re.split(force_text(text))
|
words = word_split_re.split(force_text(text))
|
||||||
nofollow_attr = nofollow and ' rel="nofollow"' or ''
|
|
||||||
for i, word in enumerate(words):
|
for i, word in enumerate(words):
|
||||||
match = None
|
match = None
|
||||||
if '.' in word or '@' in word or ':' in word:
|
if '.' in word or '@' in word or ':' in word:
|
||||||
match = punctuation_re.match(word)
|
# Deal with punctuation.
|
||||||
if match:
|
lead, middle, trail = '', word, ''
|
||||||
lead, middle, trail = match.groups()
|
for punctuation in TRAILING_PUNCTUATION:
|
||||||
|
if middle.endswith(punctuation):
|
||||||
|
middle = middle[:-len(punctuation)]
|
||||||
|
trail = punctuation + trail
|
||||||
|
for opening, closing in WRAPPING_PUNCTUATION:
|
||||||
|
if middle.startswith(opening):
|
||||||
|
middle = middle[len(opening):]
|
||||||
|
lead = lead + opening
|
||||||
|
# Keep parentheses at the end only if they're balanced.
|
||||||
|
if (middle.endswith(closing)
|
||||||
|
and middle.count(closing) == middle.count(opening) + 1):
|
||||||
|
middle = middle[:-len(closing)]
|
||||||
|
trail = closing + trail
|
||||||
|
|
||||||
# Make URL we want to point to.
|
# Make URL we want to point to.
|
||||||
url = None
|
url = None
|
||||||
if middle.startswith('http://') or middle.startswith('https://'):
|
nofollow_attr = ' rel="nofollow"' if nofollow else ''
|
||||||
url = middle
|
if simple_url_re.match(middle):
|
||||||
elif middle.startswith('www.') or ('@' not in middle and \
|
url = smart_urlquote(middle)
|
||||||
middle and middle[0] in string.ascii_letters + string.digits and \
|
elif simple_url_2_re.match(middle):
|
||||||
(middle.endswith('.org') or middle.endswith('.net') or middle.endswith('.com'))):
|
url = smart_urlquote('http://%s' % middle)
|
||||||
url = 'http://%s' % middle
|
elif not ':' in middle and simple_email_re.match(middle):
|
||||||
elif '@' in middle and not ':' in middle and simple_email_re.match(middle):
|
local, domain = middle.rsplit('@', 1)
|
||||||
url = 'mailto:%s' % middle
|
try:
|
||||||
|
domain = domain.encode('idna').decode('ascii')
|
||||||
|
except UnicodeError:
|
||||||
|
continue
|
||||||
|
url = 'mailto:%s@%s' % (local, domain)
|
||||||
nofollow_attr = ''
|
nofollow_attr = ''
|
||||||
|
|
||||||
# Make link.
|
# Make link.
|
||||||
if url:
|
if url:
|
||||||
trimmed = trim_url(middle)
|
trimmed = trim_url(middle)
|
||||||
|
@ -251,4 +259,4 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
|
||||||
words[i] = mark_safe(word)
|
words[i] = mark_safe(word)
|
||||||
elif autoescape:
|
elif autoescape:
|
||||||
words[i] = escape(word)
|
words[i] = escape(word)
|
||||||
return mark_safe(''.join(words))
|
return ''.join(words)
|
||||||
|
|
|
@ -466,17 +466,13 @@ class OAuth2Tests(TestCase):
|
||||||
def _create_authorization_header(self, token=None):
|
def _create_authorization_header(self, token=None):
|
||||||
return "Bearer {0}".format(token or self.access_token.token)
|
return "Bearer {0}".format(token or self.access_token.token)
|
||||||
|
|
||||||
def _client_credentials_params(self):
|
|
||||||
return {'client_id': self.CLIENT_ID, 'client_secret': self.CLIENT_SECRET}
|
|
||||||
|
|
||||||
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||||
def test_get_form_with_wrong_authorization_header_token_type_failing(self):
|
def test_get_form_with_wrong_authorization_header_token_type_failing(self):
|
||||||
"""Ensure that a wrong token type lead to the correct HTTP error status code"""
|
"""Ensure that a wrong token type lead to the correct HTTP error status code"""
|
||||||
auth = "Wrong token-type-obsviously"
|
auth = "Wrong token-type-obsviously"
|
||||||
response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
|
response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
|
||||||
self.assertEqual(response.status_code, 401)
|
self.assertEqual(response.status_code, 401)
|
||||||
params = self._client_credentials_params()
|
response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
|
||||||
response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
|
|
||||||
self.assertEqual(response.status_code, 401)
|
self.assertEqual(response.status_code, 401)
|
||||||
|
|
||||||
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||||
|
@ -485,8 +481,7 @@ class OAuth2Tests(TestCase):
|
||||||
auth = "Bearer wrong token format"
|
auth = "Bearer wrong token format"
|
||||||
response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
|
response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
|
||||||
self.assertEqual(response.status_code, 401)
|
self.assertEqual(response.status_code, 401)
|
||||||
params = self._client_credentials_params()
|
response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
|
||||||
response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
|
|
||||||
self.assertEqual(response.status_code, 401)
|
self.assertEqual(response.status_code, 401)
|
||||||
|
|
||||||
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||||
|
@ -495,33 +490,21 @@ class OAuth2Tests(TestCase):
|
||||||
auth = "Bearer wrong-token"
|
auth = "Bearer wrong-token"
|
||||||
response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
|
response = self.csrf_client.get('/oauth2-test/', {}, HTTP_AUTHORIZATION=auth)
|
||||||
self.assertEqual(response.status_code, 401)
|
self.assertEqual(response.status_code, 401)
|
||||||
params = self._client_credentials_params()
|
response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
|
||||||
response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
|
|
||||||
self.assertEqual(response.status_code, 401)
|
|
||||||
|
|
||||||
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
|
||||||
def test_get_form_with_wrong_client_data_failing_auth(self):
|
|
||||||
"""Ensure GETing form over OAuth with incorrect client credentials fails"""
|
|
||||||
auth = self._create_authorization_header()
|
|
||||||
params = self._client_credentials_params()
|
|
||||||
params['client_id'] += 'a'
|
|
||||||
response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
|
|
||||||
self.assertEqual(response.status_code, 401)
|
self.assertEqual(response.status_code, 401)
|
||||||
|
|
||||||
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||||
def test_get_form_passing_auth(self):
|
def test_get_form_passing_auth(self):
|
||||||
"""Ensure GETing form over OAuth with correct client credentials succeed"""
|
"""Ensure GETing form over OAuth with correct client credentials succeed"""
|
||||||
auth = self._create_authorization_header()
|
auth = self._create_authorization_header()
|
||||||
params = self._client_credentials_params()
|
response = self.csrf_client.get('/oauth2-test/', HTTP_AUTHORIZATION=auth)
|
||||||
response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
|
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||||
def test_post_form_passing_auth(self):
|
def test_post_form_passing_auth(self):
|
||||||
"""Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
|
"""Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
|
||||||
auth = self._create_authorization_header()
|
auth = self._create_authorization_header()
|
||||||
params = self._client_credentials_params()
|
response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
|
||||||
response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
|
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||||
|
@ -529,16 +512,14 @@ class OAuth2Tests(TestCase):
|
||||||
"""Ensure POSTing when there is no OAuth access token in db fails"""
|
"""Ensure POSTing when there is no OAuth access token in db fails"""
|
||||||
self.access_token.delete()
|
self.access_token.delete()
|
||||||
auth = self._create_authorization_header()
|
auth = self._create_authorization_header()
|
||||||
params = self._client_credentials_params()
|
response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
|
||||||
response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
|
|
||||||
self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
|
self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
|
||||||
|
|
||||||
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||||
def test_post_form_with_refresh_token_failing_auth(self):
|
def test_post_form_with_refresh_token_failing_auth(self):
|
||||||
"""Ensure POSTing with refresh token instead of access token fails"""
|
"""Ensure POSTing with refresh token instead of access token fails"""
|
||||||
auth = self._create_authorization_header(token=self.refresh_token.token)
|
auth = self._create_authorization_header(token=self.refresh_token.token)
|
||||||
params = self._client_credentials_params()
|
response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
|
||||||
response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
|
|
||||||
self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
|
self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
|
||||||
|
|
||||||
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||||
|
@ -547,8 +528,7 @@ class OAuth2Tests(TestCase):
|
||||||
self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10) # 10 seconds late
|
self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10) # 10 seconds late
|
||||||
self.access_token.save()
|
self.access_token.save()
|
||||||
auth = self._create_authorization_header()
|
auth = self._create_authorization_header()
|
||||||
params = self._client_credentials_params()
|
response = self.csrf_client.post('/oauth2-test/', HTTP_AUTHORIZATION=auth)
|
||||||
response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
|
|
||||||
self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
|
self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
|
||||||
self.assertIn('Invalid token', response.content)
|
self.assertIn('Invalid token', response.content)
|
||||||
|
|
||||||
|
@ -559,10 +539,9 @@ class OAuth2Tests(TestCase):
|
||||||
read_only_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['read']
|
read_only_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['read']
|
||||||
read_only_access_token.save()
|
read_only_access_token.save()
|
||||||
auth = self._create_authorization_header(token=read_only_access_token.token)
|
auth = self._create_authorization_header(token=read_only_access_token.token)
|
||||||
params = self._client_credentials_params()
|
response = self.csrf_client.get('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
|
||||||
response = self.csrf_client.get('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth)
|
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
response = self.csrf_client.post('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth)
|
response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
|
||||||
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
|
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
|
||||||
|
|
||||||
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
@unittest.skipUnless(oauth2_provider, 'django-oauth2-provider not installed')
|
||||||
|
@ -572,6 +551,5 @@ class OAuth2Tests(TestCase):
|
||||||
read_write_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['write']
|
read_write_access_token.scope = oauth2_provider_scope.SCOPE_NAME_DICT['write']
|
||||||
read_write_access_token.save()
|
read_write_access_token.save()
|
||||||
auth = self._create_authorization_header(token=read_write_access_token.token)
|
auth = self._create_authorization_header(token=read_write_access_token.token)
|
||||||
params = self._client_credentials_params()
|
response = self.csrf_client.post('/oauth2-with-scope-test/', HTTP_AUTHORIZATION=auth)
|
||||||
response = self.csrf_client.post('/oauth2-with-scope-test/', params, HTTP_AUTHORIZATION=auth)
|
|
||||||
self.assertEqual(response.status_code, 200)
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
|
@ -1,11 +1,12 @@
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
import datetime
|
import datetime
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
|
from django.core.urlresolvers import reverse
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from django.test.client import RequestFactory
|
from django.test.client import RequestFactory
|
||||||
from django.utils import unittest
|
from django.utils import unittest
|
||||||
from rest_framework import generics, status, filters
|
from rest_framework import generics, status, filters
|
||||||
from rest_framework.compat import django_filters
|
from rest_framework.compat import django_filters, patterns, url
|
||||||
from rest_framework.tests.models import FilterableItem, BasicModel
|
from rest_framework.tests.models import FilterableItem, BasicModel
|
||||||
|
|
||||||
factory = RequestFactory()
|
factory = RequestFactory()
|
||||||
|
@ -46,12 +47,21 @@ if django_filters:
|
||||||
filter_class = MisconfiguredFilter
|
filter_class = MisconfiguredFilter
|
||||||
filter_backend = filters.DjangoFilterBackend
|
filter_backend = filters.DjangoFilterBackend
|
||||||
|
|
||||||
|
class FilterClassDetailView(generics.RetrieveAPIView):
|
||||||
|
model = FilterableItem
|
||||||
|
filter_class = SeveralFieldsFilter
|
||||||
|
filter_backend = filters.DjangoFilterBackend
|
||||||
|
|
||||||
class IntegrationTestFiltering(TestCase):
|
urlpatterns = patterns('',
|
||||||
"""
|
url(r'^(?P<pk>\d+)/$', FilterClassDetailView.as_view(), name='detail-view'),
|
||||||
Integration tests for filtered list views.
|
url(r'^$', FilterClassRootView.as_view(), name='root-view'),
|
||||||
"""
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CommonFilteringTestCase(TestCase):
|
||||||
|
def _serialize_object(self, obj):
|
||||||
|
return {'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
|
||||||
|
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
"""
|
"""
|
||||||
Create 10 FilterableItem instances.
|
Create 10 FilterableItem instances.
|
||||||
|
@ -65,10 +75,16 @@ class IntegrationTestFiltering(TestCase):
|
||||||
|
|
||||||
self.objects = FilterableItem.objects
|
self.objects = FilterableItem.objects
|
||||||
self.data = [
|
self.data = [
|
||||||
{'id': obj.id, 'text': obj.text, 'decimal': obj.decimal, 'date': obj.date}
|
self._serialize_object(obj)
|
||||||
for obj in self.objects.all()
|
for obj in self.objects.all()
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class IntegrationTestFiltering(CommonFilteringTestCase):
|
||||||
|
"""
|
||||||
|
Integration tests for filtered list views.
|
||||||
|
"""
|
||||||
|
|
||||||
@unittest.skipUnless(django_filters, 'django-filters not installed')
|
@unittest.skipUnless(django_filters, 'django-filters not installed')
|
||||||
def test_get_filtered_fields_root_view(self):
|
def test_get_filtered_fields_root_view(self):
|
||||||
"""
|
"""
|
||||||
|
@ -167,3 +183,50 @@ class IntegrationTestFiltering(TestCase):
|
||||||
request = factory.get('/?integer=%s' % search_integer)
|
request = factory.get('/?integer=%s' % search_integer)
|
||||||
response = view(request).render()
|
response = view(request).render()
|
||||||
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||||
|
|
||||||
|
|
||||||
|
class IntegrationTestDetailFiltering(CommonFilteringTestCase):
|
||||||
|
"""
|
||||||
|
Integration tests for filtered detail views.
|
||||||
|
"""
|
||||||
|
urls = 'rest_framework.tests.filterset'
|
||||||
|
|
||||||
|
def _get_url(self, item):
|
||||||
|
return reverse('detail-view', kwargs=dict(pk=item.pk))
|
||||||
|
|
||||||
|
@unittest.skipUnless(django_filters, 'django-filters not installed')
|
||||||
|
def test_get_filtered_detail_view(self):
|
||||||
|
"""
|
||||||
|
GET requests to filtered RetrieveAPIView that have a filter_class set
|
||||||
|
should return filtered results.
|
||||||
|
"""
|
||||||
|
item = self.objects.all()[0]
|
||||||
|
data = self._serialize_object(item)
|
||||||
|
|
||||||
|
# Basic test with no filter.
|
||||||
|
response = self.client.get(self._get_url(item))
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||||
|
self.assertEqual(response.data, data)
|
||||||
|
|
||||||
|
# Tests that the decimal filter set that should fail.
|
||||||
|
search_decimal = Decimal('4.25')
|
||||||
|
high_item = self.objects.filter(decimal__gt=search_decimal)[0]
|
||||||
|
response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(high_item), param=search_decimal))
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
|
||||||
|
|
||||||
|
# Tests that the decimal filter set that should succeed.
|
||||||
|
search_decimal = Decimal('4.25')
|
||||||
|
low_item = self.objects.filter(decimal__lt=search_decimal)[0]
|
||||||
|
low_item_data = self._serialize_object(low_item)
|
||||||
|
response = self.client.get('{url}?decimal={param}'.format(url=self._get_url(low_item), param=search_decimal))
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||||
|
self.assertEqual(response.data, low_item_data)
|
||||||
|
|
||||||
|
# Tests that multiple filters works.
|
||||||
|
search_decimal = Decimal('5.25')
|
||||||
|
search_date = datetime.date(2012, 10, 2)
|
||||||
|
valid_item = self.objects.filter(decimal__lt=search_decimal, date__gt=search_date)[0]
|
||||||
|
valid_item_data = self._serialize_object(valid_item)
|
||||||
|
response = self.client.get('{url}?decimal={decimal}&date={date}'.format(url=self._get_url(valid_item), decimal=search_decimal, date=search_date))
|
||||||
|
self.assertEqual(response.status_code, status.HTTP_200_OK)
|
||||||
|
self.assertEqual(response.data, valid_item_data)
|
||||||
|
|
|
@ -129,16 +129,6 @@ class IntegrationTestPaginationAndFiltering(TestCase):
|
||||||
view = FilterFieldsRootView.as_view()
|
view = FilterFieldsRootView.as_view()
|
||||||
|
|
||||||
EXPECTED_NUM_QUERIES = 2
|
EXPECTED_NUM_QUERIES = 2
|
||||||
if django.VERSION < (1, 4):
|
|
||||||
# On Django 1.3 we need to use django-filter 0.5.4
|
|
||||||
#
|
|
||||||
# The filter objects there don't expose a `.count()` method,
|
|
||||||
# which means we only make a single query *but* it's a single
|
|
||||||
# query across *all* of the queryset, instead of a COUNT and then
|
|
||||||
# a SELECT with a LIMIT.
|
|
||||||
#
|
|
||||||
# Although this is fewer queries, it's actually a regression.
|
|
||||||
EXPECTED_NUM_QUERIES = 1
|
|
||||||
|
|
||||||
request = factory.get('/?decimal=15.20')
|
request = factory.get('/?decimal=15.20')
|
||||||
with self.assertNumQueries(EXPECTED_NUM_QUERIES):
|
with self.assertNumQueries(EXPECTED_NUM_QUERIES):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user