Add tests for OAuth2 authentication

This commit is contained in:
Pierre Dulac 2013-03-01 02:06:20 +01:00
parent 02ee6e5bf0
commit 468b5e43e2

View File

@ -1,4 +1,5 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from django.core.urlresolvers import reverse
from django.contrib.auth.models import User from django.contrib.auth.models import User
from django.http import HttpResponse from django.http import HttpResponse
from django.test import Client, TestCase from django.test import Client, TestCase
@ -6,11 +7,15 @@ from rest_framework import HTTP_HEADER_ENCODING
from rest_framework import permissions from rest_framework import permissions
from rest_framework import status from rest_framework import status
from rest_framework.authtoken.models import Token from rest_framework.authtoken.models import Token
from rest_framework.authentication import TokenAuthentication, BasicAuthentication, SessionAuthentication from rest_framework.authentication import TokenAuthentication, BasicAuthentication, SessionAuthentication, OAuth2Authentication
from rest_framework.compat import patterns from rest_framework.compat import patterns, url, include
from rest_framework.compat import oauth2
from rest_framework.compat import oauth2_provider
from rest_framework.views import APIView from rest_framework.views import APIView
import json import json
import base64 import base64
import datetime
import unittest
class MockView(APIView): class MockView(APIView):
@ -22,11 +27,16 @@ class MockView(APIView):
def put(self, request): def put(self, request):
return HttpResponse({'a': 1, 'b': 2, 'c': 3}) return HttpResponse({'a': 1, 'b': 2, 'c': 3})
def get(self, request):
return HttpResponse({'a': 1, 'b': 2, 'c': 3})
urlpatterns = patterns('', urlpatterns = patterns('',
(r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])), (r'^session/$', MockView.as_view(authentication_classes=[SessionAuthentication])),
(r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])), (r'^basic/$', MockView.as_view(authentication_classes=[BasicAuthentication])),
(r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])), (r'^token/$', MockView.as_view(authentication_classes=[TokenAuthentication])),
(r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'), (r'^auth-token/$', 'rest_framework.authtoken.views.obtain_auth_token'),
url(r'^oauth2/', include('provider.oauth2.urls', namespace = 'oauth2')),
url(r'^oauth2-test/$', MockView.as_view(authentication_classes=[OAuth2Authentication])),
) )
@ -187,3 +197,99 @@ class TokenAuthTests(TestCase):
{'username': self.username, 'password': self.password}) {'username': self.username, 'password': self.password})
self.assertEqual(response.status_code, status.HTTP_200_OK) self.assertEqual(response.status_code, status.HTTP_200_OK)
self.assertEqual(json.loads(response.content.decode('ascii'))['token'], self.key) self.assertEqual(json.loads(response.content.decode('ascii'))['token'], self.key)
class OAuth2Tests(TestCase):
"""OAuth 2.0 authentication"""
urls = 'rest_framework.tests.authentication'
def setUp(self):
self.csrf_client = Client(enforce_csrf_checks=True)
self.username = 'john'
self.email = 'lennon@thebeatles.com'
self.password = 'password'
self.user = User.objects.create_user(self.username, self.email, self.password)
self.CLIENT_ID = 'client_key'
self.CLIENT_SECRET = 'client_secret'
self.ACCESS_TOKEN = "access_token"
self.REFRESH_TOKEN = "refresh_token"
self.oauth2_client = oauth2.models.Client.objects.create(
client_id=self.CLIENT_ID,
client_secret=self.CLIENT_SECRET,
redirect_uri='',
client_type=0,
name='example',
user=None,
)
self.access_token = oauth2.models.AccessToken.objects.create(
token=self.ACCESS_TOKEN,
client=self.oauth2_client,
user=self.user,
)
self.refresh_token = oauth2.models.RefreshToken.objects.create(
user=self.user,
access_token=self.access_token,
client=self.oauth2_client
)
def _create_authorization_header(self, token=None):
return "Bearer {0}".format(token or self.access_token.token)
def _client_credentials_params(self):
return {'client_id': self.CLIENT_ID, 'client_secret': self.CLIENT_SECRET}
@unittest.skipUnless(oauth2, 'django-oauth2-provider not installed')
def test_get_form_with_wrong_client_data_failing_auth(self):
"""Ensure GETing form over OAuth with incorrect client credentials fails"""
auth = self._create_authorization_header()
params = self._client_credentials_params()
params['client_id'] += 'a'
response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 401)
@unittest.skipUnless(oauth2, 'django-oauth2-provider not installed')
def test_get_form_passing_auth(self):
"""Ensure GETing form over OAuth with correct client credentials succeed"""
auth = self._create_authorization_header()
params = self._client_credentials_params()
response = self.csrf_client.get('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
@unittest.skipUnless(oauth2, 'django-oauth2-provider not installed')
def test_post_form_passing_auth(self):
"""Ensure POSTing form over OAuth with correct credentials passes and does not require CSRF"""
auth = self._create_authorization_header()
params = self._client_credentials_params()
response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
self.assertEqual(response.status_code, 200)
@unittest.skipUnless(oauth2, 'django-oauth2-provider not installed')
def test_post_form_token_removed_failing_auth(self):
"""Ensure POSTing when there is no OAuth access token in db fails"""
self.access_token.delete()
auth = self._create_authorization_header()
params = self._client_credentials_params()
response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
@unittest.skipUnless(oauth2, 'django-oauth2-provider not installed')
def test_post_form_with_refresh_token_failing_auth(self):
"""Ensure POSTing with refresh token instead of access token fails"""
auth = self._create_authorization_header(token=self.refresh_token.token)
params = self._client_credentials_params()
response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
@unittest.skipUnless(oauth2, 'django-oauth2-provider not installed')
def test_post_form_with_expired_access_token_failing_auth(self):
"""Ensure POSTing with expired access token fails with an 'Invalid token' error"""
self.access_token.expires = datetime.datetime.now() - datetime.timedelta(seconds=10) # 10 seconds late
self.access_token.save()
auth = self._create_authorization_header()
params = self._client_credentials_params()
response = self.csrf_client.post('/oauth2-test/', params, HTTP_AUTHORIZATION=auth)
self.assertIn(response.status_code, (status.HTTP_401_UNAUTHORIZED, status.HTTP_403_FORBIDDEN))
self.assertIn('Invalid token', response.content)