diff --git a/rest_framework/authtoken/management/commands/drf_create_token.py b/rest_framework/authtoken/management/commands/drf_create_token.py index da10bfc90..8e06812db 100644 --- a/rest_framework/authtoken/management/commands/drf_create_token.py +++ b/rest_framework/authtoken/management/commands/drf_create_token.py @@ -1,7 +1,7 @@ from django.contrib.auth import get_user_model from django.core.management.base import BaseCommand, CommandError -from rest_framework.authtoken.models import Token +from rest_framework.authtoken.models import Token UserModel = get_user_model() diff --git a/rest_framework/compat.py b/rest_framework/compat.py index 8f2b7c127..b164f09c1 100644 --- a/rest_framework/compat.py +++ b/rest_framework/compat.py @@ -78,36 +78,6 @@ def distinct(queryset, base): return queryset.distinct() -# Obtaining manager instances and names from model options differs after 1.10. -def get_names_and_managers(options): - if django.VERSION >= (1, 10): - # Django 1.10 onwards provides a `.managers` property on the Options. - return [ - (manager.name, manager) - for manager - in options.managers - ] - # For Django 1.8 and 1.9, use the three-tuple information provided - # by .concrete_managers and .abstract_managers - return [ - (manager_info[1], manager_info[2]) - for manager_info - in (options.concrete_managers + options.abstract_managers) - ] - - -# field.rel is deprecated from 1.9 onwards -def get_remote_field(field, **kwargs): - if 'default' in kwargs: - if django.VERSION < (1, 9): - return getattr(field, 'rel', kwargs['default']) - return getattr(field, 'remote_field', kwargs['default']) - - if django.VERSION < (1, 9): - return field.rel - return field.remote_field - - def _resolve_model(obj): """ Resolve supplied `obj` to a Django model class. @@ -132,44 +102,13 @@ def _resolve_model(obj): raise ValueError("{0} is not a Django model".format(obj)) -def is_authenticated(user): - if django.VERSION < (1, 10): - return user.is_authenticated() - return user.is_authenticated - - -def is_anonymous(user): - if django.VERSION < (1, 10): - return user.is_anonymous() - return user.is_anonymous - - -def get_related_model(field): - if django.VERSION < (1, 9): - return _resolve_model(field.rel.to) - return field.remote_field.model - - -def value_from_object(field, obj): - if django.VERSION < (1, 9): - return field._get_val_from_obj(obj) - return field.value_from_object(obj) - - -# contrib.postgres only supported from 1.8 onwards. +# django.contrib.postgres requires psycopg2 try: from django.contrib.postgres import fields as postgres_fields except ImportError: postgres_fields = None -# JSONField is only supported from 1.9 onwards -try: - from django.contrib.postgres.fields import JSONField -except ImportError: - JSONField = None - - # coreapi is optional (Note that uritemplate is a dependency of coreapi) try: import coreapi @@ -325,17 +264,12 @@ else: LONG_SEPARATORS = (b', ', b': ') INDENT_SEPARATORS = (b',', b': ') -try: - # DecimalValidator is unavailable in Django < 1.9 - from django.core.validators import DecimalValidator -except ImportError: - DecimalValidator = None class CustomValidatorMessage(object): """ We need to avoid evaluation of `lazy` translated `message` in `django.core.validators.BaseValidator.__init__`. https://github.com/django/django/blob/75ed5900321d170debef4ac452b8b3cf8a1c2384/django/core/validators.py#L297 - + Ref: https://github.com/encode/django-rest-framework/pull/5452 """ def __init__(self, *args, **kwargs): @@ -371,44 +305,6 @@ def set_rollback(): pass -def template_render(template, context=None, request=None): - """ - Passing Context or RequestContext to Template.render is deprecated in 1.9+, - see https://github.com/django/django/pull/3883 and - https://github.com/django/django/blob/1.9/django/template/backends/django.py#L82-L84 - - :param template: Template instance - :param context: dict - :param request: Request instance - :return: rendered template as SafeText instance - """ - if isinstance(template, Template): - if request: - context = RequestContext(request, context) - else: - context = Context(context) - return template.render(context) - # backends template, e.g. django.template.backends.django.Template - else: - return template.render(context, request=request) - - -def set_many(instance, field, value): - if django.VERSION < (1, 10): - setattr(instance, field, value) - else: - field = getattr(instance, field) - field.set(value) - - -def include(module, namespace=None, app_name=None): - from django.conf.urls import include - if django.VERSION < (1,9): - return include(module, namespace, app_name) - else: - return include((module, app_name), namespace) - - def authenticate(request=None, **credentials): from django.contrib.auth import authenticate if django.VERSION < (1, 11): diff --git a/rest_framework/fields.py b/rest_framework/fields.py index 7a79ae93c..adb002689 100644 --- a/rest_framework/fields.py +++ b/rest_framework/fields.py @@ -32,8 +32,7 @@ from django.utils.translation import ugettext_lazy as _ from rest_framework import ISO_8601 from rest_framework.compat import ( InvalidTimeError, MaxLengthValidator, MaxValueValidator, - MinLengthValidator, MinValueValidator, get_remote_field, unicode_repr, - unicode_to_repr, value_from_object + MinLengthValidator, MinValueValidator, unicode_repr, unicode_to_repr ) from rest_framework.exceptions import ErrorDetail, ValidationError from rest_framework.settings import api_settings @@ -1829,7 +1828,7 @@ class ModelField(Field): MaxLengthValidator(self.max_length, message=message)) def to_internal_value(self, data): - rel = get_remote_field(self.model_field, default=None) + rel = self.model_field.remote_field if rel is not None: return rel.model._meta.get_field(rel.field_name).to_python(data) return self.model_field.to_python(data) @@ -1840,7 +1839,7 @@ class ModelField(Field): return obj def to_representation(self, obj): - value = value_from_object(self.model_field, obj) + value = self.model_field.value_from_object(obj) if is_protected_type(value): return value return self.model_field.value_to_string(obj) diff --git a/rest_framework/filters.py b/rest_framework/filters.py index 0473787bb..28b6995ec 100644 --- a/rest_framework/filters.py +++ b/rest_framework/filters.py @@ -16,9 +16,7 @@ from django.utils import six from django.utils.encoding import force_text from django.utils.translation import ugettext_lazy as _ -from rest_framework.compat import ( - coreapi, coreschema, distinct, guardian, template_render -) +from rest_framework.compat import coreapi, coreschema, distinct, guardian from rest_framework.settings import api_settings @@ -129,7 +127,7 @@ class SearchFilter(BaseFilterBackend): 'term': term } template = loader.get_template(self.template) - return template_render(template, context) + return template.render(context) def get_schema_fields(self, view): assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' @@ -260,7 +258,7 @@ class OrderingFilter(BaseFilterBackend): def to_html(self, request, queryset, view): template = loader.get_template(self.template) context = self.get_template_context(request, queryset, view) - return template_render(template, context) + return template.render(context) def get_schema_fields(self, view): assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' diff --git a/rest_framework/pagination.py b/rest_framework/pagination.py index 2c3b0d2a5..861e8cf2a 100644 --- a/rest_framework/pagination.py +++ b/rest_framework/pagination.py @@ -16,7 +16,7 @@ from django.utils.encoding import force_text from django.utils.six.moves.urllib import parse as urlparse from django.utils.translation import ugettext_lazy as _ -from rest_framework.compat import coreapi, coreschema, template_render +from rest_framework.compat import coreapi, coreschema from rest_framework.exceptions import NotFound from rest_framework.response import Response from rest_framework.settings import api_settings @@ -285,7 +285,7 @@ class PageNumberPagination(BasePagination): def to_html(self): template = loader.get_template(self.template) context = self.get_html_context() - return template_render(template, context) + return template.render(context) def get_schema_fields(self, view): assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' @@ -442,7 +442,7 @@ class LimitOffsetPagination(BasePagination): def to_html(self): template = loader.get_template(self.template) context = self.get_html_context() - return template_render(template, context) + return template.render(context) def get_schema_fields(self, view): assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' @@ -793,7 +793,7 @@ class CursorPagination(BasePagination): def to_html(self): template = loader.get_template(self.template) context = self.get_html_context() - return template_render(template, context) + return template.render(context) def get_schema_fields(self, view): assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' diff --git a/rest_framework/permissions.py b/rest_framework/permissions.py index dee0032f9..a48058e66 100644 --- a/rest_framework/permissions.py +++ b/rest_framework/permissions.py @@ -6,7 +6,6 @@ from __future__ import unicode_literals from django.http import Http404 from rest_framework import exceptions -from rest_framework.compat import is_authenticated SAFE_METHODS = ('GET', 'HEAD', 'OPTIONS') @@ -47,7 +46,7 @@ class IsAuthenticated(BasePermission): """ def has_permission(self, request, view): - return request.user and is_authenticated(request.user) + return request.user and request.user.is_authenticated class IsAdminUser(BasePermission): @@ -68,7 +67,7 @@ class IsAuthenticatedOrReadOnly(BasePermission): return ( request.method in SAFE_METHODS or request.user and - is_authenticated(request.user) + request.user.is_authenticated ) @@ -136,7 +135,7 @@ class DjangoModelPermissions(BasePermission): return True if not request.user or ( - not is_authenticated(request.user) and self.authenticated_users_only): + not request.user.is_authenticated and self.authenticated_users_only): return False queryset = self._queryset(view) diff --git a/rest_framework/renderers.py b/rest_framework/renderers.py index 90f516b35..3bc520f53 100644 --- a/rest_framework/renderers.py +++ b/rest_framework/renderers.py @@ -16,7 +16,7 @@ from django.conf import settings from django.core.exceptions import ImproperlyConfigured from django.core.paginator import Page from django.http.multipartparser import parse_header -from django.template import Template, loader +from django.template import engines, loader from django.test.client import encode_multipart from django.utils import six from django.utils.html import mark_safe @@ -24,7 +24,7 @@ from django.utils.html import mark_safe from rest_framework import VERSION, exceptions, serializers, status from rest_framework.compat import ( INDENT_SEPARATORS, LONG_SEPARATORS, SHORT_SEPARATORS, coreapi, - pygments_css, template_render + pygments_css ) from rest_framework.exceptions import ParseError from rest_framework.request import is_form_media_type, override_method @@ -173,7 +173,7 @@ class TemplateHTMLRenderer(BaseRenderer): context = self.resolve_context(data, request, response) else: context = self.get_template_context(data, renderer_context) - return template_render(template, context, request=request) + return template.render(context, request=request) def resolve_template(self, template_names): return loader.select_template(template_names) @@ -206,8 +206,9 @@ class TemplateHTMLRenderer(BaseRenderer): return self.resolve_template(template_names) except Exception: # Fall back to using eg '404 Not Found' - return Template('%d %s' % (response.status_code, - response.status_text.title())) + body = '%d %s' % (response.status_code, response.status_text.title()) + template = engines['django'].from_string(body) + return template # Note, subclass TemplateHTMLRenderer simply for the exception behavior @@ -239,7 +240,7 @@ class StaticHTMLRenderer(TemplateHTMLRenderer): context = self.resolve_context(data, request, response) else: context = self.get_template_context(data, renderer_context) - return template_render(template, context, request=request) + return template.render(context, request=request) return data @@ -347,7 +348,7 @@ class HTMLFormRenderer(BaseRenderer): template = loader.get_template(template_name) context = {'field': field, 'style': style} - return template_render(template, context) + return template.render(context) def render(self, data, accepted_media_type=None, renderer_context=None): """ @@ -368,7 +369,7 @@ class HTMLFormRenderer(BaseRenderer): 'form': form, 'style': style } - return template_render(template, context) + return template.render(context) class BrowsableAPIRenderer(BaseRenderer): @@ -625,7 +626,7 @@ class BrowsableAPIRenderer(BaseRenderer): template = loader.get_template(self.filter_template) context = {'elements': elements} - return template_render(template, context) + return template.render(context) def get_context(self, data, accepted_media_type, renderer_context): """ @@ -705,7 +706,7 @@ class BrowsableAPIRenderer(BaseRenderer): template = loader.get_template(self.template) context = self.get_context(data, accepted_media_type, renderer_context) - ret = template_render(template, context, request=renderer_context['request']) + ret = template.render(context, request=renderer_context['request']) # Munge DELETE Response code to allow us to return content # (Do this *after* we've rendered the template so that we include @@ -741,7 +742,7 @@ class AdminRenderer(BrowsableAPIRenderer): template = loader.get_template(self.template) context = self.get_context(data, accepted_media_type, renderer_context) - ret = template_render(template, context, request=renderer_context['request']) + ret = template.render(context, request=renderer_context['request']) # Creation and deletion should use redirects in the admin style. if response.status_code == status.HTTP_201_CREATED and 'Location' in response: @@ -819,7 +820,7 @@ class DocumentationRenderer(BaseRenderer): def render(self, data, accepted_media_type=None, renderer_context=None): template = loader.get_template(self.template) context = self.get_context(data, renderer_context['request']) - return template_render(template, context, request=renderer_context['request']) + return template.render(context, request=renderer_context['request']) class SchemaJSRenderer(BaseRenderer): @@ -835,7 +836,7 @@ class SchemaJSRenderer(BaseRenderer): template = loader.get_template(self.template) context = {'schema': mark_safe(schema)} request = renderer_context['request'] - return template_render(template, context, request=request) + return template.render(context, request=request) class MultiPartRenderer(BaseRenderer): diff --git a/rest_framework/serializers.py b/rest_framework/serializers.py index b1c34b92a..59533be1e 100644 --- a/rest_framework/serializers.py +++ b/rest_framework/serializers.py @@ -27,8 +27,7 @@ from django.utils import six, timezone from django.utils.functional import cached_property from django.utils.translation import ugettext_lazy as _ -from rest_framework.compat import JSONField as ModelJSONField -from rest_framework.compat import postgres_fields, set_many, unicode_to_repr +from rest_framework.compat import postgres_fields, unicode_to_repr from rest_framework.exceptions import ErrorDetail, ValidationError from rest_framework.fields import get_error_detail, set_value from rest_framework.settings import api_settings @@ -861,8 +860,6 @@ class ModelSerializer(Serializer): } if ModelDurationField is not None: serializer_field_mapping[ModelDurationField] = DurationField - if ModelJSONField is not None: - serializer_field_mapping[ModelJSONField] = JSONField serializer_related_field = PrimaryKeyRelatedField serializer_related_to_field = SlugRelatedField serializer_url_field = HyperlinkedIdentityField @@ -935,7 +932,8 @@ class ModelSerializer(Serializer): # Save many-to-many relationships after the instance is created. if many_to_many: for field_name, value in many_to_many.items(): - set_many(instance, field_name, value) + field = getattr(instance, field_name) + field.set(value) return instance @@ -949,7 +947,8 @@ class ModelSerializer(Serializer): # have an instance pk for the relationships to be associated with. for attr, value in validated_data.items(): if attr in info.relations and info.relations[attr].to_many: - set_many(instance, attr, value) + field = getattr(instance, attr) + field.set(value) else: setattr(instance, attr, value) instance.save() @@ -1532,6 +1531,7 @@ if postgres_fields: ModelSerializer.serializer_field_mapping[postgres_fields.HStoreField] = CharMappingField ModelSerializer.serializer_field_mapping[postgres_fields.ArrayField] = ListField + ModelSerializer.serializer_field_mapping[postgres_fields.JSONField] = JSONField class HyperlinkedModelSerializer(ModelSerializer): diff --git a/rest_framework/settings.py b/rest_framework/settings.py index 86b577219..d9e3a2942 100644 --- a/rest_framework/settings.py +++ b/rest_framework/settings.py @@ -19,6 +19,7 @@ REST framework settings, checking for user settings first, then falling back to the defaults. """ from __future__ import unicode_literals + from importlib import import_module from django.conf import settings diff --git a/rest_framework/templatetags/rest_framework.py b/rest_framework/templatetags/rest_framework.py index 398528dd9..20e0f7a67 100644 --- a/rest_framework/templatetags/rest_framework.py +++ b/rest_framework/templatetags/rest_framework.py @@ -11,8 +11,7 @@ from django.utils.html import escape, format_html, smart_urlquote from django.utils.safestring import SafeData, mark_safe from rest_framework.compat import ( - NoReverseMatch, apply_markdown, pygments_highlight, reverse, - template_render + NoReverseMatch, apply_markdown, pygments_highlight, reverse ) from rest_framework.renderers import HTMLFormRenderer from rest_framework.utils.urls import replace_query_param @@ -216,11 +215,11 @@ def format_value(value): else: template = loader.get_template('rest_framework/admin/simple_list_value.html') context = {'value': value} - return template_render(template, context) + return template.render(context) elif isinstance(value, dict): template = loader.get_template('rest_framework/admin/dict_value.html') context = {'value': value} - return template_render(template, context) + return template.render(context) elif isinstance(value, six.string_types): if ( (value.startswith('http:') or value.startswith('https:')) and not diff --git a/rest_framework/throttling.py b/rest_framework/throttling.py index 57f24d13f..422431566 100644 --- a/rest_framework/throttling.py +++ b/rest_framework/throttling.py @@ -8,7 +8,6 @@ import time from django.core.cache import cache as default_cache from django.core.exceptions import ImproperlyConfigured -from rest_framework.compat import is_authenticated from rest_framework.settings import api_settings @@ -174,7 +173,7 @@ class AnonRateThrottle(SimpleRateThrottle): scope = 'anon' def get_cache_key(self, request, view): - if is_authenticated(request.user): + if request.user.is_authenticated: return None # Only throttle unauthenticated requests. return self.cache_format % { @@ -194,7 +193,7 @@ class UserRateThrottle(SimpleRateThrottle): scope = 'user' def get_cache_key(self, request, view): - if is_authenticated(request.user): + if request.user.is_authenticated: ident = request.user.pk else: ident = self.get_ident(request) @@ -242,7 +241,7 @@ class ScopedRateThrottle(SimpleRateThrottle): Otherwise generate the unique cache key by concatenating the user id with the '.throttle_scope` property of the view. """ - if is_authenticated(request.user): + if request.user.is_authenticated: ident = request.user.pk else: ident = self.get_ident(request) diff --git a/rest_framework/urlpatterns.py b/rest_framework/urlpatterns.py index 2ce4ba52d..90f97f27d 100644 --- a/rest_framework/urlpatterns.py +++ b/rest_framework/urlpatterns.py @@ -1,8 +1,8 @@ from __future__ import unicode_literals -from django.conf.urls import url +from django.conf.urls import include, url -from rest_framework.compat import RegexURLResolver, include +from rest_framework.compat import RegexURLResolver from rest_framework.settings import api_settings @@ -19,8 +19,7 @@ def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required): patterns = apply_suffix_patterns(urlpattern.url_patterns, suffix_pattern, suffix_required) - ret.append(url(regex, include(patterns, namespace, app_name), kwargs)) - + ret.append(url(regex, include((patterns, app_name), namespace), kwargs)) else: # Regular URL pattern regex = urlpattern.regex.pattern.rstrip('$').rstrip('/') + suffix_pattern diff --git a/rest_framework/utils/field_mapping.py b/rest_framework/utils/field_mapping.py index dff33d8b3..722981b20 100644 --- a/rest_framework/utils/field_mapping.py +++ b/rest_framework/utils/field_mapping.py @@ -8,7 +8,7 @@ from django.core import validators from django.db import models from django.utils.text import capfirst -from rest_framework.compat import DecimalValidator, JSONField +from rest_framework.compat import postgres_fields from rest_framework.validators import UniqueValidator NUMERIC_FIELD_TYPES = ( @@ -88,7 +88,7 @@ def get_field_kwargs(field_name, model_field): if decimal_places is not None: kwargs['decimal_places'] = decimal_places - if isinstance(model_field, models.TextField) or (JSONField and isinstance(model_field, JSONField)): + if isinstance(model_field, models.TextField) or (postgres_fields and isinstance(model_field, postgres_fields.JSONField)): kwargs['style'] = {'base_template': 'textarea.html'} if isinstance(model_field, models.AutoField) or not model_field.editable: @@ -181,11 +181,10 @@ def get_field_kwargs(field_name, model_field): if validator is not validators.validate_ipv46_address ] # Our decimal validation is handled in the field code, not validator code. - # (In Django 1.9+ this differs from previous style) - if isinstance(model_field, models.DecimalField) and DecimalValidator: + if isinstance(model_field, models.DecimalField): validator_kwarg = [ validator for validator in validator_kwarg - if not isinstance(validator, DecimalValidator) + if not isinstance(validator, validators.DecimalValidator) ] # Ensure that max_length is passed explicitly as a keyword arg, diff --git a/rest_framework/utils/model_meta.py b/rest_framework/utils/model_meta.py index 3e3e434e6..f0ae02bb2 100644 --- a/rest_framework/utils/model_meta.py +++ b/rest_framework/utils/model_meta.py @@ -7,8 +7,6 @@ Usage: `get_field_info(model)` returns a `FieldInfo` instance. """ from collections import OrderedDict, namedtuple -from rest_framework.compat import get_related_model, get_remote_field - FieldInfo = namedtuple('FieldResult', [ 'pk', # Model field instance 'fields', # Dict of field name -> model field instance @@ -49,19 +47,19 @@ def get_field_info(model): def _get_pk(opts): pk = opts.pk - rel = get_remote_field(pk) + rel = pk.remote_field while rel and rel.parent_link: # If model is a child via multi-table inheritance, use parent's pk. - pk = get_related_model(pk)._meta.pk - rel = get_remote_field(pk) + pk = pk.remote_field.model._meta.pk + rel = pk.remote_field return pk def _get_fields(opts): fields = OrderedDict() - for field in [field for field in opts.fields if field.serialize and not get_remote_field(field)]: + for field in [field for field in opts.fields if field.serialize and not field.remote_field]: fields[field.name] = field return fields @@ -76,10 +74,10 @@ def _get_forward_relationships(opts): Returns an `OrderedDict` of field names to `RelationInfo`. """ forward_relations = OrderedDict() - for field in [field for field in opts.fields if field.serialize and get_remote_field(field)]: + for field in [field for field in opts.fields if field.serialize and field.remote_field]: forward_relations[field.name] = RelationInfo( model_field=field, - related_model=get_related_model(field), + related_model=field.remote_field.model, to_many=False, to_field=_get_to_field(field), has_through_model=False, @@ -90,12 +88,12 @@ def _get_forward_relationships(opts): for field in [field for field in opts.many_to_many if field.serialize]: forward_relations[field.name] = RelationInfo( model_field=field, - related_model=get_related_model(field), + related_model=field.remote_field.model, to_many=True, # manytomany do not have to_fields to_field=None, has_through_model=( - not get_remote_field(field).through._meta.auto_created + not field.remote_field.through._meta.auto_created ), reverse=False ) @@ -119,7 +117,7 @@ def _get_reverse_relationships(opts): reverse_relations[accessor_name] = RelationInfo( model_field=None, related_model=related, - to_many=get_remote_field(relation.field).multiple, + to_many=relation.field.remote_field.multiple, to_field=_get_to_field(relation.field), has_through_model=False, reverse=True @@ -137,8 +135,8 @@ def _get_reverse_relationships(opts): # manytomany do not have to_fields to_field=None, has_through_model=( - (getattr(get_remote_field(relation.field), 'through', None) is not None) and - not get_remote_field(relation.field).through._meta.auto_created + (getattr(relation.field.remote_field, 'through', None) is not None) and + not relation.field.remote_field.through._meta.auto_created ), reverse=True ) diff --git a/rest_framework/utils/representation.py b/rest_framework/utils/representation.py index 32e6d246a..deeaf1f63 100644 --- a/rest_framework/utils/representation.py +++ b/rest_framework/utils/representation.py @@ -10,13 +10,18 @@ from django.db import models from django.utils.encoding import force_text from django.utils.functional import Promise -from rest_framework.compat import get_names_and_managers, unicode_repr +from rest_framework.compat import unicode_repr def manager_repr(value): model = value.model opts = model._meta - for manager_name, manager_instance in get_names_and_managers(opts): + names_and_managers = [ + (manager.name, manager) + for manager + in opts.managers + ] + for manager_name, manager_instance in names_and_managers: if manager_instance == value: return '%s.%s.all()' % (model._meta.object_name, manager_name) return repr(value) diff --git a/tests/test_authentication.py b/tests/test_authentication.py index fdbc28a2a..c08ea6970 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -16,11 +16,11 @@ from rest_framework import ( HTTP_HEADER_ENCODING, exceptions, permissions, renderers, status ) from rest_framework.authentication import ( - BaseAuthentication, BasicAuthentication, RemoteUserAuthentication, SessionAuthentication, - TokenAuthentication) + BaseAuthentication, BasicAuthentication, RemoteUserAuthentication, + SessionAuthentication, TokenAuthentication +) from rest_framework.authtoken.models import Token from rest_framework.authtoken.views import obtain_auth_token -from rest_framework.compat import is_authenticated from rest_framework.response import Response from rest_framework.test import APIClient, APIRequestFactory from rest_framework.views import APIView @@ -450,7 +450,7 @@ class FailingAuthAccessedInRenderer(TestCase): def render(self, data, media_type=None, renderer_context=None): request = renderer_context['request'] - if is_authenticated(request.user): + if request.user.is_authenticated: return b'authenticated' return b'not authenticated' diff --git a/tests/test_compat.py b/tests/test_compat.py index 4c1a5e94d..aa1107617 100644 --- a/tests/test_compat.py +++ b/tests/test_compat.py @@ -22,20 +22,6 @@ class CompatTests(TestCase): expected = (timedelta.days * 86400.0) + float(timedelta.seconds) + (timedelta.microseconds / 1000000.0) assert compat.total_seconds(timedelta) == expected - def test_get_remote_field_with_old_django_version(self): - class MockField(object): - rel = 'example_rel' - compat.django.VERSION = (1, 8) - assert compat.get_remote_field(MockField(), default='default_value') == 'example_rel' - assert compat.get_remote_field(object(), default='default_value') == 'default_value' - - def test_get_remote_field_with_new_django_version(self): - class MockField(object): - remote_field = 'example_remote_field' - compat.django.VERSION = (1, 10) - assert compat.get_remote_field(MockField(), default='default_value') == 'example_remote_field' - assert compat.get_remote_field(object(), default='default_value') == 'default_value' - def test_set_rollback_for_transaction_in_managed_mode(self): class MockTransaction(object): called_rollback = False diff --git a/tests/test_htmlrenderer.py b/tests/test_htmlrenderer.py index c49fc96d4..decd25a3f 100644 --- a/tests/test_htmlrenderer.py +++ b/tests/test_htmlrenderer.py @@ -5,7 +5,7 @@ import pytest from django.conf.urls import url from django.core.exceptions import ImproperlyConfigured, PermissionDenied from django.http import Http404 -from django.template import Template, TemplateDoesNotExist +from django.template import TemplateDoesNotExist, engines from django.test import TestCase, override_settings from django.utils import six @@ -60,12 +60,12 @@ class TemplateHTMLRendererTests(TestCase): def get_template(template_name, dirs=None): if template_name == 'example.html': - return Template("example: {{ object }}") + return engines['django'].from_string("example: {{ object }}") raise TemplateDoesNotExist(template_name) def select_template(template_name_list, dirs=None, using=None): if template_name_list == ['example.html']: - return Template("example: {{ object }}") + return engines['django'].from_string("example: {{ object }}") raise TemplateDoesNotExist(template_name_list[0]) django.template.loader.get_template = get_template @@ -139,9 +139,9 @@ class TemplateHTMLRendererExceptionTests(TestCase): def get_template(template_name): if template_name == '404.html': - return Template("404: {{ detail }}") + return engines['django'].from_string("404: {{ detail }}") if template_name == '403.html': - return Template("403: {{ detail }}") + return engines['django'].from_string("403: {{ detail }}") raise TemplateDoesNotExist(template_name) django.template.loader.get_template = get_template diff --git a/tests/test_model_serializer.py b/tests/test_model_serializer.py index 3411c44b5..203e1fe7f 100644 --- a/tests/test_model_serializer.py +++ b/tests/test_model_serializer.py @@ -21,7 +21,7 @@ from django.test import TestCase from django.utils import six from rest_framework import serializers -from rest_framework.compat import set_many, unicode_repr +from rest_framework.compat import unicode_repr def dedent(blocktext): @@ -703,8 +703,7 @@ class TestIntegration(TestCase): foreign_key=self.foreign_key_target, one_to_one=self.one_to_one_target, ) - set_many(self.instance, 'many_to_many', self.many_to_many_targets) - self.instance.save() + self.instance.many_to_many.set(self.many_to_many_targets) def test_pk_retrival(self): class TestSerializer(serializers.ModelSerializer): diff --git a/tests/test_one_to_one_with_inheritance.py b/tests/test_one_to_one_with_inheritance.py index 9c489c1df..aa527a318 100644 --- a/tests/test_one_to_one_with_inheritance.py +++ b/tests/test_one_to_one_with_inheritance.py @@ -5,8 +5,6 @@ from django.test import TestCase from rest_framework import serializers from tests.models import RESTFrameworkModel - - # Models from tests.test_multitable_inheritance import ChildModel diff --git a/tests/test_permissions.py b/tests/test_permissions.py index f673c3671..7ccd43613 100644 --- a/tests/test_permissions.py +++ b/tests/test_permissions.py @@ -11,7 +11,7 @@ from rest_framework import ( HTTP_HEADER_ENCODING, authentication, generics, permissions, serializers, status, views ) -from rest_framework.compat import ResolverMatch, guardian, set_many +from rest_framework.compat import ResolverMatch, guardian from rest_framework.filters import DjangoObjectPermissionsFilter from rest_framework.routers import DefaultRouter from rest_framework.test import APIRequestFactory @@ -73,13 +73,14 @@ class ModelPermissionsIntegrationTests(TestCase): def setUp(self): User.objects.create_user('disallowed', 'disallowed@example.com', 'password') user = User.objects.create_user('permitted', 'permitted@example.com', 'password') - set_many(user, 'user_permissions', [ + user.user_permissions.set([ Permission.objects.get(codename='add_basicmodel'), Permission.objects.get(codename='change_basicmodel'), Permission.objects.get(codename='delete_basicmodel') ]) + user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password') - set_many(user, 'user_permissions', [ + user.user_permissions.set([ Permission.objects.get(codename='change_basicmodel'), ]) diff --git a/tests/test_prefetch_related.py b/tests/test_prefetch_related.py index b87bc2f66..b07087c97 100644 --- a/tests/test_prefetch_related.py +++ b/tests/test_prefetch_related.py @@ -2,7 +2,6 @@ from django.contrib.auth.models import Group, User from django.test import TestCase from rest_framework import generics, serializers -from rest_framework.compat import set_many from rest_framework.test import APIRequestFactory factory = APIRequestFactory() @@ -23,8 +22,7 @@ class TestPrefetchRelatedUpdates(TestCase): def setUp(self): self.user = User.objects.create(username='tom', email='tom@example.com') self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')] - set_many(self.user, 'groups', self.groups) - self.user.save() + self.user.groups.set(self.groups) def test_prefetch_related_updates(self): view = UserUpdate.as_view() diff --git a/tests/test_relations_pk.py b/tests/test_relations_pk.py index 2eebe1b5c..3317d6251 100644 --- a/tests/test_relations_pk.py +++ b/tests/test_relations_pk.py @@ -6,8 +6,9 @@ from django.utils import six from rest_framework import serializers from tests.models import ( ForeignKeySource, ForeignKeyTarget, ManyToManySource, ManyToManyTarget, - NullableForeignKeySource, NullableOneToOneSource, NullableUUIDForeignKeySource, - OneToOnePKSource, OneToOneTarget, UUIDForeignKeyTarget + NullableForeignKeySource, NullableOneToOneSource, + NullableUUIDForeignKeySource, OneToOnePKSource, OneToOneTarget, + UUIDForeignKeyTarget ) diff --git a/tests/test_request.py b/tests/test_request.py index 208d2737e..a87060df1 100644 --- a/tests/test_request.py +++ b/tests/test_request.py @@ -16,7 +16,6 @@ from django.utils import six from rest_framework import status from rest_framework.authentication import SessionAuthentication -from rest_framework.compat import is_anonymous from rest_framework.parsers import BaseParser, FormParser, MultiPartParser from rest_framework.request import Request from rest_framework.response import Response @@ -201,9 +200,9 @@ class TestUserSetter(TestCase): def test_user_can_logout(self): self.request.user = self.user - self.assertFalse(is_anonymous(self.request.user)) + self.assertFalse(self.request.user.is_anonymous) logout(self.request) - self.assertTrue(is_anonymous(self.request.user)) + self.assertTrue(self.request.user.is_anonymous) def test_logged_in_user_is_set_on_wrapped_request(self): login(self.request, self.user) diff --git a/tests/test_requests_client.py b/tests/test_requests_client.py index 791ca4ff2..161429f73 100644 --- a/tests/test_requests_client.py +++ b/tests/test_requests_client.py @@ -10,7 +10,7 @@ from django.test import override_settings from django.utils.decorators import method_decorator from django.views.decorators.csrf import csrf_protect, ensure_csrf_cookie -from rest_framework.compat import is_authenticated, requests +from rest_framework.compat import requests from rest_framework.response import Response from rest_framework.test import APITestCase, RequestsClient from rest_framework.views import APIView @@ -72,7 +72,7 @@ class SessionView(APIView): class AuthView(APIView): @method_decorator(ensure_csrf_cookie) def get(self, request): - if is_authenticated(request.user): + if request.user.is_authenticated: username = request.user.username else: username = None diff --git a/tests/test_routers.py b/tests/test_routers.py index cbc7c0554..46d54ed9f 100644 --- a/tests/test_routers.py +++ b/tests/test_routers.py @@ -3,13 +3,12 @@ from __future__ import unicode_literals from collections import namedtuple import pytest -from django.conf.urls import url +from django.conf.urls import include, url from django.core.exceptions import ImproperlyConfigured from django.db import models from django.test import TestCase, override_settings from rest_framework import permissions, serializers, viewsets -from rest_framework.compat import include from rest_framework.decorators import detail_route, list_route from rest_framework.response import Response from rest_framework.routers import DefaultRouter, SimpleRouter diff --git a/tests/test_versioning.py b/tests/test_versioning.py index 098b09b65..ab64dfab7 100644 --- a/tests/test_versioning.py +++ b/tests/test_versioning.py @@ -1,9 +1,8 @@ import pytest -from django.conf.urls import url +from django.conf.urls import include, url from django.test import override_settings from rest_framework import serializers, status, versioning -from rest_framework.compat import include from rest_framework.decorators import APIView from rest_framework.relations import PKOnlyObject from rest_framework.response import Response diff --git a/tests/urls.py b/tests/urls.py index a237ec219..930c1f217 100644 --- a/tests/urls.py +++ b/tests/urls.py @@ -4,6 +4,7 @@ URLConf for test suite. We need only the docs urls for DocumentationRenderer tests. """ from django.conf.urls import url + from rest_framework.documentation import include_docs_urls urlpatterns = [ diff --git a/tests/utils.py b/tests/utils.py index 5fb0723f8..0ef37016d 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -1,4 +1,5 @@ from django.core.exceptions import ObjectDoesNotExist + from rest_framework.compat import NoReverseMatch