From 48540f180a34345a6278e527cd4e494826f1b8f2 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Thu, 27 Aug 2015 17:11:53 +0100 Subject: [PATCH] unittest compat fallback --- rest_framework/compat.py | 8 ++++++++ tests/test_atomic_requests.py | 10 +++++----- tests/test_filters.py | 3 +-- tests/test_permissions.py | 3 +-- 4 files changed, 15 insertions(+), 9 deletions(-) diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 2cff61088..164cf2003 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -67,6 +67,14 @@ except ImportError: from django.utils.datastructures import SortedDict as OrderedDict +# unittest.SkipUnless only available in Python 2.7. +try: + import unittest + unittest.skipUnless +except (ImportError, AttributeError): + from django.test.utils import unittest + + # contrib.postgres only supported from 1.8 onwards. try: from django.contrib.postgres import fields as postgres_fields diff --git a/tests/test_atomic_requests.py b/tests/test_atomic_requests.py index d0d088f52..8f6830663 100644 --- a/tests/test_atomic_requests.py +++ b/tests/test_atomic_requests.py @@ -5,9 +5,9 @@ from django.db import connection, connections, transaction from django.http import Http404 from django.test import TestCase, TransactionTestCase from django.utils.decorators import method_decorator -from django.utils.unittest import skipUnless from rest_framework import status +from rest_framework.compat import unittest from rest_framework.exceptions import APIException from rest_framework.response import Response from rest_framework.test import APIRequestFactory @@ -35,7 +35,7 @@ class APIExceptionView(APIView): raise APIException -@skipUnless(connection.features.uses_savepoints, +@unittest.skipUnless(connection.features.uses_savepoints, "'atomic' requires transactions and savepoints.") class DBTransactionTests(TestCase): def setUp(self): @@ -55,7 +55,7 @@ class DBTransactionTests(TestCase): assert BasicModel.objects.count() == 1 -@skipUnless(connection.features.uses_savepoints, +@unittest.skipUnless(connection.features.uses_savepoints, "'atomic' requires transactions and savepoints.") class DBTransactionErrorTests(TestCase): def setUp(self): @@ -83,7 +83,7 @@ class DBTransactionErrorTests(TestCase): assert BasicModel.objects.count() == 1 -@skipUnless(connection.features.uses_savepoints, +@unittest.skipUnless(connection.features.uses_savepoints, "'atomic' requires transactions and savepoints.") class DBTransactionAPIExceptionTests(TestCase): def setUp(self): @@ -113,7 +113,7 @@ class DBTransactionAPIExceptionTests(TestCase): assert BasicModel.objects.count() == 0 -@skipUnless(connection.features.uses_savepoints, +@unittest.skipUnless(connection.features.uses_savepoints, "'atomic' requires transactions and savepoints.") class NonAtomicDBTransactionAPIExceptionTests(TransactionTestCase): @property diff --git a/tests/test_filters.py b/tests/test_filters.py index 0610b0855..bce6e08fa 100644 --- a/tests/test_filters.py +++ b/tests/test_filters.py @@ -8,12 +8,11 @@ from django.core.urlresolvers import reverse from django.db import models from django.test import TestCase from django.test.utils import override_settings -from django.utils import unittest from django.utils.dateparse import parse_date from django.utils.six.moves import reload_module from rest_framework import filters, generics, serializers, status -from rest_framework.compat import django_filters +from rest_framework.compat import django_filters, unittest from rest_framework.test import APIRequestFactory from .models import BaseFilterableItem, BasicModel, FilterableItem diff --git a/tests/test_permissions.py b/tests/test_permissions.py index 398020002..ffc262a41 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -6,13 +6,12 @@ from django.contrib.auth.models import Group, Permission, User from django.core.urlresolvers import ResolverMatch from django.db import models from django.test import TestCase -from django.utils import unittest from rest_framework import ( HTTP_HEADER_ENCODING, authentication, generics, permissions, serializers, status ) -from rest_framework.compat import get_model_name, guardian +from rest_framework.compat import get_model_name, guardian, unittest from rest_framework.filters import DjangoObjectPermissionsFilter from rest_framework.routers import DefaultRouter from rest_framework.test import APIRequestFactory