diff --git a/tests/models.py b/tests/models.py index e5d49a0a5..08c7fbd3e 100644 --- a/tests/models.py +++ b/tests/models.py @@ -2,6 +2,7 @@ from __future__ import unicode_literals import uuid +from django.contrib.auth.models import User from django.db import models from django.utils.translation import ugettext_lazy as _ @@ -87,3 +88,8 @@ class OneToOnePKSource(RESTFrameworkModel): target = models.OneToOneField( OneToOneTarget, primary_key=True, related_name='required_source', on_delete=models.CASCADE) + + +class Hat(RESTFrameworkModel): + user = models.OneToOneField(User) + some_key = models.CharField(max_length=100) diff --git a/tests/test_testing.py b/tests/test_testing.py index 1af6ef02e..f80859f0c 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -14,6 +14,7 @@ from rest_framework.response import Response from rest_framework.test import ( APIClient, APIRequestFactory, force_authenticate ) +from tests.models import Hat @api_view(['GET', 'POST']) @@ -49,11 +50,26 @@ def post_view(request): return Response(serializer.validated_data) +@api_view(['GET']) +def get_my_hat(request): + if hasattr(request.user, 'hat'): + return Response({ + 'id': request.user.hat.id, + 'some_key': request.user.hat.some_key, + }) + else: + return Response( + {}, + status=404, + ) + + urlpatterns = [ url(r'^view/$', view), url(r'^session-view/$', session_view), url(r'^redirect-view/$', redirect_view), - url(r'^post-view/$', post_view) + url(r'^post-view/$', post_view), + url(r'^my_hat/$', get_my_hat), ] @@ -203,6 +219,42 @@ class TestAPITestClient(TestCase): assert response.status_code == 200 assert response.data == {"flag": True} + def test_get_my_hat_with_force_authenticate(self): + user = User.objects.create_user('example', 'example@example.com', 'password') + self.client.force_authenticate(user) + + response = self.client.get('/my_hat/') + self.assertEqual(response.status_code, 404) + + hat = Hat.objects.create(user=user, some_key='some_value') + response = self.client.get('/my_hat/') + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data, {'id': hat.id, 'some_key': 'some_value'}) + + another_user = User.objects.create_user('another_example', 'another_example@example.com', 'password') + hat.user = another_user + hat.save() + response = self.client.get('/my_hat/') + self.assertEqual(response.status_code, 404) + + def test_get_my_hat_with_login(self): + user = User.objects.create_user('example', 'example@example.com', 'password') + self.client.login(username='example', password='password') + + response = self.client.get('/my_hat/') + self.assertEqual(response.status_code, 404) + + hat = Hat.objects.create(user=user, some_key='some_value') + response = self.client.get('/my_hat/') + self.assertEqual(response.status_code, 200) + self.assertEqual(response.data, {'id': hat.id, 'some_key': 'some_value'}) + + another_user = User.objects.create_user('another_example', 'another_example@example.com', 'password') + hat.user = another_user + hat.save() + response = self.client.get('/my_hat/') + self.assertEqual(response.status_code, 404) + class TestAPIRequestFactory(TestCase): def test_csrf_exempt_by_default(self):