diff --git a/rest_framework/test.py b/rest_framework/test.py index ebad19a4e..3b745bd62 100644 --- a/rest_framework/test.py +++ b/rest_framework/test.py @@ -5,11 +5,12 @@ from __future__ import unicode_literals import io +from importlib import import_module from django.conf import settings from django.core.exceptions import ImproperlyConfigured from django.core.handlers.wsgi import WSGIHandler -from django.test import testcases +from django.test import override_settings, testcases from django.test.client import Client as DjangoClient from django.test.client import RequestFactory as DjangoRequestFactory from django.test.client import ClientHandler @@ -358,3 +359,44 @@ class APISimpleTestCase(testcases.SimpleTestCase): class APILiveServerTestCase(testcases.LiveServerTestCase): client_class = APIClient + + +class URLPatternsTestCase(testcases.SimpleTestCase): + """ + Isolate URL patterns on a per-TestCase basis. For example, + + class ATestCase(URLPatternsTestCase): + urlpatterns = [...] + + def test_something(self): + ... + + class AnotherTestCase(URLPatternsTestCase): + urlpatterns = [...] + + def test_something_else(self): + ... + """ + @classmethod + def setUpClass(cls): + # Get the module of the TestCase subclass + cls._module = import_module(cls.__module__) + cls._override = override_settings(ROOT_URLCONF=cls.__module__) + + if hasattr(cls._module, 'urlpatterns'): + cls._module_urlpatterns = cls._module.urlpatterns + + cls._module.urlpatterns = cls.urlpatterns + + cls._override.enable() + super(URLPatternsTestCase, cls).setUpClass() + + @classmethod + def tearDownClass(cls): + super(URLPatternsTestCase, cls).tearDownClass() + cls._override.disable() + + if hasattr(cls, '_module_urlpatterns'): + cls._module.urlpatterns = cls._module_urlpatterns + else: + del cls._module.urlpatterns diff --git a/tests/test_testing.py b/tests/test_testing.py index 1af6ef02e..7868f724c 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -12,7 +12,7 @@ from rest_framework import fields, serializers from rest_framework.decorators import api_view from rest_framework.response import Response from rest_framework.test import ( - APIClient, APIRequestFactory, force_authenticate + APIClient, APIRequestFactory, URLPatternsTestCase, force_authenticate ) @@ -283,3 +283,30 @@ class TestAPIRequestFactory(TestCase): content_type='application/json', ) assert request.META['CONTENT_TYPE'] == 'application/json' + + +class TestUrlPatternTestCase(URLPatternsTestCase): + urlpatterns = [ + url(r'^$', view), + ] + + @classmethod + def setUpClass(cls): + assert urlpatterns is not cls.urlpatterns + super(TestUrlPatternTestCase, cls).setUpClass() + assert urlpatterns is cls.urlpatterns + + @classmethod + def tearDownClass(cls): + assert urlpatterns is cls.urlpatterns + super(TestUrlPatternTestCase, cls).tearDownClass() + assert urlpatterns is not cls.urlpatterns + + def test_urlpatterns(self): + assert self.client.get('/').status_code == 200 + + +class TestExistingPatterns(TestCase): + def test_urlpatterns(self): + # sanity test to ensure that this test module does not have a '/' route + assert self.client.get('/').status_code == 404