unittest compat fallback

This commit is contained in:
Tom Christie 2015-08-27 17:11:53 +01:00
parent 8264222497
commit 48540f180a
4 changed files with 15 additions and 9 deletions

View File

@ -67,6 +67,14 @@ except ImportError:
from django.utils.datastructures import SortedDict as OrderedDict from django.utils.datastructures import SortedDict as OrderedDict
# unittest.SkipUnless only available in Python 2.7.
try:
import unittest
unittest.skipUnless
except (ImportError, AttributeError):
from django.test.utils import unittest
# contrib.postgres only supported from 1.8 onwards. # contrib.postgres only supported from 1.8 onwards.
try: try:
from django.contrib.postgres import fields as postgres_fields from django.contrib.postgres import fields as postgres_fields

View File

@ -5,9 +5,9 @@ from django.db import connection, connections, transaction
from django.http import Http404 from django.http import Http404
from django.test import TestCase, TransactionTestCase from django.test import TestCase, TransactionTestCase
from django.utils.decorators import method_decorator from django.utils.decorators import method_decorator
from django.utils.unittest import skipUnless
from rest_framework import status from rest_framework import status
from rest_framework.compat import unittest
from rest_framework.exceptions import APIException from rest_framework.exceptions import APIException
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
@ -35,7 +35,7 @@ class APIExceptionView(APIView):
raise APIException raise APIException
@skipUnless(connection.features.uses_savepoints, @unittest.skipUnless(connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints.") "'atomic' requires transactions and savepoints.")
class DBTransactionTests(TestCase): class DBTransactionTests(TestCase):
def setUp(self): def setUp(self):
@ -55,7 +55,7 @@ class DBTransactionTests(TestCase):
assert BasicModel.objects.count() == 1 assert BasicModel.objects.count() == 1
@skipUnless(connection.features.uses_savepoints, @unittest.skipUnless(connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints.") "'atomic' requires transactions and savepoints.")
class DBTransactionErrorTests(TestCase): class DBTransactionErrorTests(TestCase):
def setUp(self): def setUp(self):
@ -83,7 +83,7 @@ class DBTransactionErrorTests(TestCase):
assert BasicModel.objects.count() == 1 assert BasicModel.objects.count() == 1
@skipUnless(connection.features.uses_savepoints, @unittest.skipUnless(connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints.") "'atomic' requires transactions and savepoints.")
class DBTransactionAPIExceptionTests(TestCase): class DBTransactionAPIExceptionTests(TestCase):
def setUp(self): def setUp(self):
@ -113,7 +113,7 @@ class DBTransactionAPIExceptionTests(TestCase):
assert BasicModel.objects.count() == 0 assert BasicModel.objects.count() == 0
@skipUnless(connection.features.uses_savepoints, @unittest.skipUnless(connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints.") "'atomic' requires transactions and savepoints.")
class NonAtomicDBTransactionAPIExceptionTests(TransactionTestCase): class NonAtomicDBTransactionAPIExceptionTests(TransactionTestCase):
@property @property

View File

@ -8,12 +8,11 @@ from django.core.urlresolvers import reverse
from django.db import models from django.db import models
from django.test import TestCase from django.test import TestCase
from django.test.utils import override_settings from django.test.utils import override_settings
from django.utils import unittest
from django.utils.dateparse import parse_date from django.utils.dateparse import parse_date
from django.utils.six.moves import reload_module from django.utils.six.moves import reload_module
from rest_framework import filters, generics, serializers, status from rest_framework import filters, generics, serializers, status
from rest_framework.compat import django_filters from rest_framework.compat import django_filters, unittest
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
from .models import BaseFilterableItem, BasicModel, FilterableItem from .models import BaseFilterableItem, BasicModel, FilterableItem

View File

@ -6,13 +6,12 @@ from django.contrib.auth.models import Group, Permission, User
from django.core.urlresolvers import ResolverMatch from django.core.urlresolvers import ResolverMatch
from django.db import models from django.db import models
from django.test import TestCase from django.test import TestCase
from django.utils import unittest
from rest_framework import ( from rest_framework import (
HTTP_HEADER_ENCODING, authentication, generics, permissions, serializers, HTTP_HEADER_ENCODING, authentication, generics, permissions, serializers,
status status
) )
from rest_framework.compat import get_model_name, guardian from rest_framework.compat import get_model_name, guardian, unittest
from rest_framework.filters import DjangoObjectPermissionsFilter from rest_framework.filters import DjangoObjectPermissionsFilter
from rest_framework.routers import DefaultRouter from rest_framework.routers import DefaultRouter
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory