Remove Django 1.8 & 1.9 compatibility code (#5481)

* Identify code that needs to be pulled out of/removed from compat.py

* Extract modern code from get_names_and_managers in compat.py and remove compat code

* Extract modern code from is_authenticated() in compat.py and remove.

* Extract modern code from is_anonymous() in compat.py and remove

* Extract modern code from get_related_model() from compat.py and remove

* Extract modern code from value_from_object() in compat.py and remove

* Update postgres compat

JSONField now always available.

* Remove DecimalValidator compat

* Remove get_remote_field compat

* Remove template_render compat

Plus isort.

* Remove set_many compat

* Remove include compat
This commit is contained in:
Carlton Gibson 2017-10-05 20:41:38 +02:00 committed by GitHub
parent 2edeb74e0e
commit c674687782
29 changed files with 95 additions and 220 deletions

View File

@ -1,7 +1,7 @@
from django.contrib.auth import get_user_model from django.contrib.auth import get_user_model
from django.core.management.base import BaseCommand, CommandError from django.core.management.base import BaseCommand, CommandError
from rest_framework.authtoken.models import Token
from rest_framework.authtoken.models import Token
UserModel = get_user_model() UserModel = get_user_model()

View File

@ -78,36 +78,6 @@ def distinct(queryset, base):
return queryset.distinct() return queryset.distinct()
# Obtaining manager instances and names from model options differs after 1.10.
def get_names_and_managers(options):
if django.VERSION >= (1, 10):
# Django 1.10 onwards provides a `.managers` property on the Options.
return [
(manager.name, manager)
for manager
in options.managers
]
# For Django 1.8 and 1.9, use the three-tuple information provided
# by .concrete_managers and .abstract_managers
return [
(manager_info[1], manager_info[2])
for manager_info
in (options.concrete_managers + options.abstract_managers)
]
# field.rel is deprecated from 1.9 onwards
def get_remote_field(field, **kwargs):
if 'default' in kwargs:
if django.VERSION < (1, 9):
return getattr(field, 'rel', kwargs['default'])
return getattr(field, 'remote_field', kwargs['default'])
if django.VERSION < (1, 9):
return field.rel
return field.remote_field
def _resolve_model(obj): def _resolve_model(obj):
""" """
Resolve supplied `obj` to a Django model class. Resolve supplied `obj` to a Django model class.
@ -132,44 +102,13 @@ def _resolve_model(obj):
raise ValueError("{0} is not a Django model".format(obj)) raise ValueError("{0} is not a Django model".format(obj))
def is_authenticated(user): # django.contrib.postgres requires psycopg2
if django.VERSION < (1, 10):
return user.is_authenticated()
return user.is_authenticated
def is_anonymous(user):
if django.VERSION < (1, 10):
return user.is_anonymous()
return user.is_anonymous
def get_related_model(field):
if django.VERSION < (1, 9):
return _resolve_model(field.rel.to)
return field.remote_field.model
def value_from_object(field, obj):
if django.VERSION < (1, 9):
return field._get_val_from_obj(obj)
return field.value_from_object(obj)
# contrib.postgres only supported from 1.8 onwards.
try: try:
from django.contrib.postgres import fields as postgres_fields from django.contrib.postgres import fields as postgres_fields
except ImportError: except ImportError:
postgres_fields = None postgres_fields = None
# JSONField is only supported from 1.9 onwards
try:
from django.contrib.postgres.fields import JSONField
except ImportError:
JSONField = None
# coreapi is optional (Note that uritemplate is a dependency of coreapi) # coreapi is optional (Note that uritemplate is a dependency of coreapi)
try: try:
import coreapi import coreapi
@ -325,11 +264,6 @@ else:
LONG_SEPARATORS = (b', ', b': ') LONG_SEPARATORS = (b', ', b': ')
INDENT_SEPARATORS = (b',', b': ') INDENT_SEPARATORS = (b',', b': ')
try:
# DecimalValidator is unavailable in Django < 1.9
from django.core.validators import DecimalValidator
except ImportError:
DecimalValidator = None
class CustomValidatorMessage(object): class CustomValidatorMessage(object):
""" """
@ -371,44 +305,6 @@ def set_rollback():
pass pass
def template_render(template, context=None, request=None):
"""
Passing Context or RequestContext to Template.render is deprecated in 1.9+,
see https://github.com/django/django/pull/3883 and
https://github.com/django/django/blob/1.9/django/template/backends/django.py#L82-L84
:param template: Template instance
:param context: dict
:param request: Request instance
:return: rendered template as SafeText instance
"""
if isinstance(template, Template):
if request:
context = RequestContext(request, context)
else:
context = Context(context)
return template.render(context)
# backends template, e.g. django.template.backends.django.Template
else:
return template.render(context, request=request)
def set_many(instance, field, value):
if django.VERSION < (1, 10):
setattr(instance, field, value)
else:
field = getattr(instance, field)
field.set(value)
def include(module, namespace=None, app_name=None):
from django.conf.urls import include
if django.VERSION < (1,9):
return include(module, namespace, app_name)
else:
return include((module, app_name), namespace)
def authenticate(request=None, **credentials): def authenticate(request=None, **credentials):
from django.contrib.auth import authenticate from django.contrib.auth import authenticate
if django.VERSION < (1, 11): if django.VERSION < (1, 11):

View File

@ -32,8 +32,7 @@ from django.utils.translation import ugettext_lazy as _
from rest_framework import ISO_8601 from rest_framework import ISO_8601
from rest_framework.compat import ( from rest_framework.compat import (
InvalidTimeError, MaxLengthValidator, MaxValueValidator, InvalidTimeError, MaxLengthValidator, MaxValueValidator,
MinLengthValidator, MinValueValidator, get_remote_field, unicode_repr, MinLengthValidator, MinValueValidator, unicode_repr, unicode_to_repr
unicode_to_repr, value_from_object
) )
from rest_framework.exceptions import ErrorDetail, ValidationError from rest_framework.exceptions import ErrorDetail, ValidationError
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -1829,7 +1828,7 @@ class ModelField(Field):
MaxLengthValidator(self.max_length, message=message)) MaxLengthValidator(self.max_length, message=message))
def to_internal_value(self, data): def to_internal_value(self, data):
rel = get_remote_field(self.model_field, default=None) rel = self.model_field.remote_field
if rel is not None: if rel is not None:
return rel.model._meta.get_field(rel.field_name).to_python(data) return rel.model._meta.get_field(rel.field_name).to_python(data)
return self.model_field.to_python(data) return self.model_field.to_python(data)
@ -1840,7 +1839,7 @@ class ModelField(Field):
return obj return obj
def to_representation(self, obj): def to_representation(self, obj):
value = value_from_object(self.model_field, obj) value = self.model_field.value_from_object(obj)
if is_protected_type(value): if is_protected_type(value):
return value return value
return self.model_field.value_to_string(obj) return self.model_field.value_to_string(obj)

View File

@ -16,9 +16,7 @@ from django.utils import six
from django.utils.encoding import force_text from django.utils.encoding import force_text
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework.compat import ( from rest_framework.compat import coreapi, coreschema, distinct, guardian
coreapi, coreschema, distinct, guardian, template_render
)
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -129,7 +127,7 @@ class SearchFilter(BaseFilterBackend):
'term': term 'term': term
} }
template = loader.get_template(self.template) template = loader.get_template(self.template)
return template_render(template, context) return template.render(context)
def get_schema_fields(self, view): def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
@ -260,7 +258,7 @@ class OrderingFilter(BaseFilterBackend):
def to_html(self, request, queryset, view): def to_html(self, request, queryset, view):
template = loader.get_template(self.template) template = loader.get_template(self.template)
context = self.get_template_context(request, queryset, view) context = self.get_template_context(request, queryset, view)
return template_render(template, context) return template.render(context)
def get_schema_fields(self, view): def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'

View File

@ -16,7 +16,7 @@ from django.utils.encoding import force_text
from django.utils.six.moves.urllib import parse as urlparse from django.utils.six.moves.urllib import parse as urlparse
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework.compat import coreapi, coreschema, template_render from rest_framework.compat import coreapi, coreschema
from rest_framework.exceptions import NotFound from rest_framework.exceptions import NotFound
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -285,7 +285,7 @@ class PageNumberPagination(BasePagination):
def to_html(self): def to_html(self):
template = loader.get_template(self.template) template = loader.get_template(self.template)
context = self.get_html_context() context = self.get_html_context()
return template_render(template, context) return template.render(context)
def get_schema_fields(self, view): def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
@ -442,7 +442,7 @@ class LimitOffsetPagination(BasePagination):
def to_html(self): def to_html(self):
template = loader.get_template(self.template) template = loader.get_template(self.template)
context = self.get_html_context() context = self.get_html_context()
return template_render(template, context) return template.render(context)
def get_schema_fields(self, view): def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
@ -793,7 +793,7 @@ class CursorPagination(BasePagination):
def to_html(self): def to_html(self):
template = loader.get_template(self.template) template = loader.get_template(self.template)
context = self.get_html_context() context = self.get_html_context()
return template_render(template, context) return template.render(context)
def get_schema_fields(self, view): def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`' assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'

View File

@ -6,7 +6,6 @@ from __future__ import unicode_literals
from django.http import Http404 from django.http import Http404
from rest_framework import exceptions from rest_framework import exceptions
from rest_framework.compat import is_authenticated
SAFE_METHODS = ('GET', 'HEAD', 'OPTIONS') SAFE_METHODS = ('GET', 'HEAD', 'OPTIONS')
@ -47,7 +46,7 @@ class IsAuthenticated(BasePermission):
""" """
def has_permission(self, request, view): def has_permission(self, request, view):
return request.user and is_authenticated(request.user) return request.user and request.user.is_authenticated
class IsAdminUser(BasePermission): class IsAdminUser(BasePermission):
@ -68,7 +67,7 @@ class IsAuthenticatedOrReadOnly(BasePermission):
return ( return (
request.method in SAFE_METHODS or request.method in SAFE_METHODS or
request.user and request.user and
is_authenticated(request.user) request.user.is_authenticated
) )
@ -136,7 +135,7 @@ class DjangoModelPermissions(BasePermission):
return True return True
if not request.user or ( if not request.user or (
not is_authenticated(request.user) and self.authenticated_users_only): not request.user.is_authenticated and self.authenticated_users_only):
return False return False
queryset = self._queryset(view) queryset = self._queryset(view)

View File

@ -16,7 +16,7 @@ from django.conf import settings
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.core.paginator import Page from django.core.paginator import Page
from django.http.multipartparser import parse_header from django.http.multipartparser import parse_header
from django.template import Template, loader from django.template import engines, loader
from django.test.client import encode_multipart from django.test.client import encode_multipart
from django.utils import six from django.utils import six
from django.utils.html import mark_safe from django.utils.html import mark_safe
@ -24,7 +24,7 @@ from django.utils.html import mark_safe
from rest_framework import VERSION, exceptions, serializers, status from rest_framework import VERSION, exceptions, serializers, status
from rest_framework.compat import ( from rest_framework.compat import (
INDENT_SEPARATORS, LONG_SEPARATORS, SHORT_SEPARATORS, coreapi, INDENT_SEPARATORS, LONG_SEPARATORS, SHORT_SEPARATORS, coreapi,
pygments_css, template_render pygments_css
) )
from rest_framework.exceptions import ParseError from rest_framework.exceptions import ParseError
from rest_framework.request import is_form_media_type, override_method from rest_framework.request import is_form_media_type, override_method
@ -173,7 +173,7 @@ class TemplateHTMLRenderer(BaseRenderer):
context = self.resolve_context(data, request, response) context = self.resolve_context(data, request, response)
else: else:
context = self.get_template_context(data, renderer_context) context = self.get_template_context(data, renderer_context)
return template_render(template, context, request=request) return template.render(context, request=request)
def resolve_template(self, template_names): def resolve_template(self, template_names):
return loader.select_template(template_names) return loader.select_template(template_names)
@ -206,8 +206,9 @@ class TemplateHTMLRenderer(BaseRenderer):
return self.resolve_template(template_names) return self.resolve_template(template_names)
except Exception: except Exception:
# Fall back to using eg '404 Not Found' # Fall back to using eg '404 Not Found'
return Template('%d %s' % (response.status_code, body = '%d %s' % (response.status_code, response.status_text.title())
response.status_text.title())) template = engines['django'].from_string(body)
return template
# Note, subclass TemplateHTMLRenderer simply for the exception behavior # Note, subclass TemplateHTMLRenderer simply for the exception behavior
@ -239,7 +240,7 @@ class StaticHTMLRenderer(TemplateHTMLRenderer):
context = self.resolve_context(data, request, response) context = self.resolve_context(data, request, response)
else: else:
context = self.get_template_context(data, renderer_context) context = self.get_template_context(data, renderer_context)
return template_render(template, context, request=request) return template.render(context, request=request)
return data return data
@ -347,7 +348,7 @@ class HTMLFormRenderer(BaseRenderer):
template = loader.get_template(template_name) template = loader.get_template(template_name)
context = {'field': field, 'style': style} context = {'field': field, 'style': style}
return template_render(template, context) return template.render(context)
def render(self, data, accepted_media_type=None, renderer_context=None): def render(self, data, accepted_media_type=None, renderer_context=None):
""" """
@ -368,7 +369,7 @@ class HTMLFormRenderer(BaseRenderer):
'form': form, 'form': form,
'style': style 'style': style
} }
return template_render(template, context) return template.render(context)
class BrowsableAPIRenderer(BaseRenderer): class BrowsableAPIRenderer(BaseRenderer):
@ -625,7 +626,7 @@ class BrowsableAPIRenderer(BaseRenderer):
template = loader.get_template(self.filter_template) template = loader.get_template(self.filter_template)
context = {'elements': elements} context = {'elements': elements}
return template_render(template, context) return template.render(context)
def get_context(self, data, accepted_media_type, renderer_context): def get_context(self, data, accepted_media_type, renderer_context):
""" """
@ -705,7 +706,7 @@ class BrowsableAPIRenderer(BaseRenderer):
template = loader.get_template(self.template) template = loader.get_template(self.template)
context = self.get_context(data, accepted_media_type, renderer_context) context = self.get_context(data, accepted_media_type, renderer_context)
ret = template_render(template, context, request=renderer_context['request']) ret = template.render(context, request=renderer_context['request'])
# Munge DELETE Response code to allow us to return content # Munge DELETE Response code to allow us to return content
# (Do this *after* we've rendered the template so that we include # (Do this *after* we've rendered the template so that we include
@ -741,7 +742,7 @@ class AdminRenderer(BrowsableAPIRenderer):
template = loader.get_template(self.template) template = loader.get_template(self.template)
context = self.get_context(data, accepted_media_type, renderer_context) context = self.get_context(data, accepted_media_type, renderer_context)
ret = template_render(template, context, request=renderer_context['request']) ret = template.render(context, request=renderer_context['request'])
# Creation and deletion should use redirects in the admin style. # Creation and deletion should use redirects in the admin style.
if response.status_code == status.HTTP_201_CREATED and 'Location' in response: if response.status_code == status.HTTP_201_CREATED and 'Location' in response:
@ -819,7 +820,7 @@ class DocumentationRenderer(BaseRenderer):
def render(self, data, accepted_media_type=None, renderer_context=None): def render(self, data, accepted_media_type=None, renderer_context=None):
template = loader.get_template(self.template) template = loader.get_template(self.template)
context = self.get_context(data, renderer_context['request']) context = self.get_context(data, renderer_context['request'])
return template_render(template, context, request=renderer_context['request']) return template.render(context, request=renderer_context['request'])
class SchemaJSRenderer(BaseRenderer): class SchemaJSRenderer(BaseRenderer):
@ -835,7 +836,7 @@ class SchemaJSRenderer(BaseRenderer):
template = loader.get_template(self.template) template = loader.get_template(self.template)
context = {'schema': mark_safe(schema)} context = {'schema': mark_safe(schema)}
request = renderer_context['request'] request = renderer_context['request']
return template_render(template, context, request=request) return template.render(context, request=request)
class MultiPartRenderer(BaseRenderer): class MultiPartRenderer(BaseRenderer):

View File

@ -27,8 +27,7 @@ from django.utils import six, timezone
from django.utils.functional import cached_property from django.utils.functional import cached_property
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework.compat import JSONField as ModelJSONField from rest_framework.compat import postgres_fields, unicode_to_repr
from rest_framework.compat import postgres_fields, set_many, unicode_to_repr
from rest_framework.exceptions import ErrorDetail, ValidationError from rest_framework.exceptions import ErrorDetail, ValidationError
from rest_framework.fields import get_error_detail, set_value from rest_framework.fields import get_error_detail, set_value
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -861,8 +860,6 @@ class ModelSerializer(Serializer):
} }
if ModelDurationField is not None: if ModelDurationField is not None:
serializer_field_mapping[ModelDurationField] = DurationField serializer_field_mapping[ModelDurationField] = DurationField
if ModelJSONField is not None:
serializer_field_mapping[ModelJSONField] = JSONField
serializer_related_field = PrimaryKeyRelatedField serializer_related_field = PrimaryKeyRelatedField
serializer_related_to_field = SlugRelatedField serializer_related_to_field = SlugRelatedField
serializer_url_field = HyperlinkedIdentityField serializer_url_field = HyperlinkedIdentityField
@ -935,7 +932,8 @@ class ModelSerializer(Serializer):
# Save many-to-many relationships after the instance is created. # Save many-to-many relationships after the instance is created.
if many_to_many: if many_to_many:
for field_name, value in many_to_many.items(): for field_name, value in many_to_many.items():
set_many(instance, field_name, value) field = getattr(instance, field_name)
field.set(value)
return instance return instance
@ -949,7 +947,8 @@ class ModelSerializer(Serializer):
# have an instance pk for the relationships to be associated with. # have an instance pk for the relationships to be associated with.
for attr, value in validated_data.items(): for attr, value in validated_data.items():
if attr in info.relations and info.relations[attr].to_many: if attr in info.relations and info.relations[attr].to_many:
set_many(instance, attr, value) field = getattr(instance, attr)
field.set(value)
else: else:
setattr(instance, attr, value) setattr(instance, attr, value)
instance.save() instance.save()
@ -1532,6 +1531,7 @@ if postgres_fields:
ModelSerializer.serializer_field_mapping[postgres_fields.HStoreField] = CharMappingField ModelSerializer.serializer_field_mapping[postgres_fields.HStoreField] = CharMappingField
ModelSerializer.serializer_field_mapping[postgres_fields.ArrayField] = ListField ModelSerializer.serializer_field_mapping[postgres_fields.ArrayField] = ListField
ModelSerializer.serializer_field_mapping[postgres_fields.JSONField] = JSONField
class HyperlinkedModelSerializer(ModelSerializer): class HyperlinkedModelSerializer(ModelSerializer):

View File

@ -19,6 +19,7 @@ REST framework settings, checking for user settings first, then falling
back to the defaults. back to the defaults.
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
from importlib import import_module from importlib import import_module
from django.conf import settings from django.conf import settings

View File

@ -11,8 +11,7 @@ from django.utils.html import escape, format_html, smart_urlquote
from django.utils.safestring import SafeData, mark_safe from django.utils.safestring import SafeData, mark_safe
from rest_framework.compat import ( from rest_framework.compat import (
NoReverseMatch, apply_markdown, pygments_highlight, reverse, NoReverseMatch, apply_markdown, pygments_highlight, reverse
template_render
) )
from rest_framework.renderers import HTMLFormRenderer from rest_framework.renderers import HTMLFormRenderer
from rest_framework.utils.urls import replace_query_param from rest_framework.utils.urls import replace_query_param
@ -216,11 +215,11 @@ def format_value(value):
else: else:
template = loader.get_template('rest_framework/admin/simple_list_value.html') template = loader.get_template('rest_framework/admin/simple_list_value.html')
context = {'value': value} context = {'value': value}
return template_render(template, context) return template.render(context)
elif isinstance(value, dict): elif isinstance(value, dict):
template = loader.get_template('rest_framework/admin/dict_value.html') template = loader.get_template('rest_framework/admin/dict_value.html')
context = {'value': value} context = {'value': value}
return template_render(template, context) return template.render(context)
elif isinstance(value, six.string_types): elif isinstance(value, six.string_types):
if ( if (
(value.startswith('http:') or value.startswith('https:')) and not (value.startswith('http:') or value.startswith('https:')) and not

View File

@ -8,7 +8,6 @@ import time
from django.core.cache import cache as default_cache from django.core.cache import cache as default_cache
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from rest_framework.compat import is_authenticated
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -174,7 +173,7 @@ class AnonRateThrottle(SimpleRateThrottle):
scope = 'anon' scope = 'anon'
def get_cache_key(self, request, view): def get_cache_key(self, request, view):
if is_authenticated(request.user): if request.user.is_authenticated:
return None # Only throttle unauthenticated requests. return None # Only throttle unauthenticated requests.
return self.cache_format % { return self.cache_format % {
@ -194,7 +193,7 @@ class UserRateThrottle(SimpleRateThrottle):
scope = 'user' scope = 'user'
def get_cache_key(self, request, view): def get_cache_key(self, request, view):
if is_authenticated(request.user): if request.user.is_authenticated:
ident = request.user.pk ident = request.user.pk
else: else:
ident = self.get_ident(request) ident = self.get_ident(request)
@ -242,7 +241,7 @@ class ScopedRateThrottle(SimpleRateThrottle):
Otherwise generate the unique cache key by concatenating the user id Otherwise generate the unique cache key by concatenating the user id
with the '.throttle_scope` property of the view. with the '.throttle_scope` property of the view.
""" """
if is_authenticated(request.user): if request.user.is_authenticated:
ident = request.user.pk ident = request.user.pk
else: else:
ident = self.get_ident(request) ident = self.get_ident(request)

View File

@ -1,8 +1,8 @@
from __future__ import unicode_literals from __future__ import unicode_literals
from django.conf.urls import url from django.conf.urls import include, url
from rest_framework.compat import RegexURLResolver, include from rest_framework.compat import RegexURLResolver
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -19,8 +19,7 @@ def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required):
patterns = apply_suffix_patterns(urlpattern.url_patterns, patterns = apply_suffix_patterns(urlpattern.url_patterns,
suffix_pattern, suffix_pattern,
suffix_required) suffix_required)
ret.append(url(regex, include(patterns, namespace, app_name), kwargs)) ret.append(url(regex, include((patterns, app_name), namespace), kwargs))
else: else:
# Regular URL pattern # Regular URL pattern
regex = urlpattern.regex.pattern.rstrip('$').rstrip('/') + suffix_pattern regex = urlpattern.regex.pattern.rstrip('$').rstrip('/') + suffix_pattern

View File

@ -8,7 +8,7 @@ from django.core import validators
from django.db import models from django.db import models
from django.utils.text import capfirst from django.utils.text import capfirst
from rest_framework.compat import DecimalValidator, JSONField from rest_framework.compat import postgres_fields
from rest_framework.validators import UniqueValidator from rest_framework.validators import UniqueValidator
NUMERIC_FIELD_TYPES = ( NUMERIC_FIELD_TYPES = (
@ -88,7 +88,7 @@ def get_field_kwargs(field_name, model_field):
if decimal_places is not None: if decimal_places is not None:
kwargs['decimal_places'] = decimal_places kwargs['decimal_places'] = decimal_places
if isinstance(model_field, models.TextField) or (JSONField and isinstance(model_field, JSONField)): if isinstance(model_field, models.TextField) or (postgres_fields and isinstance(model_field, postgres_fields.JSONField)):
kwargs['style'] = {'base_template': 'textarea.html'} kwargs['style'] = {'base_template': 'textarea.html'}
if isinstance(model_field, models.AutoField) or not model_field.editable: if isinstance(model_field, models.AutoField) or not model_field.editable:
@ -181,11 +181,10 @@ def get_field_kwargs(field_name, model_field):
if validator is not validators.validate_ipv46_address if validator is not validators.validate_ipv46_address
] ]
# Our decimal validation is handled in the field code, not validator code. # Our decimal validation is handled in the field code, not validator code.
# (In Django 1.9+ this differs from previous style) if isinstance(model_field, models.DecimalField):
if isinstance(model_field, models.DecimalField) and DecimalValidator:
validator_kwarg = [ validator_kwarg = [
validator for validator in validator_kwarg validator for validator in validator_kwarg
if not isinstance(validator, DecimalValidator) if not isinstance(validator, validators.DecimalValidator)
] ]
# Ensure that max_length is passed explicitly as a keyword arg, # Ensure that max_length is passed explicitly as a keyword arg,

View File

@ -7,8 +7,6 @@ Usage: `get_field_info(model)` returns a `FieldInfo` instance.
""" """
from collections import OrderedDict, namedtuple from collections import OrderedDict, namedtuple
from rest_framework.compat import get_related_model, get_remote_field
FieldInfo = namedtuple('FieldResult', [ FieldInfo = namedtuple('FieldResult', [
'pk', # Model field instance 'pk', # Model field instance
'fields', # Dict of field name -> model field instance 'fields', # Dict of field name -> model field instance
@ -49,19 +47,19 @@ def get_field_info(model):
def _get_pk(opts): def _get_pk(opts):
pk = opts.pk pk = opts.pk
rel = get_remote_field(pk) rel = pk.remote_field
while rel and rel.parent_link: while rel and rel.parent_link:
# If model is a child via multi-table inheritance, use parent's pk. # If model is a child via multi-table inheritance, use parent's pk.
pk = get_related_model(pk)._meta.pk pk = pk.remote_field.model._meta.pk
rel = get_remote_field(pk) rel = pk.remote_field
return pk return pk
def _get_fields(opts): def _get_fields(opts):
fields = OrderedDict() fields = OrderedDict()
for field in [field for field in opts.fields if field.serialize and not get_remote_field(field)]: for field in [field for field in opts.fields if field.serialize and not field.remote_field]:
fields[field.name] = field fields[field.name] = field
return fields return fields
@ -76,10 +74,10 @@ def _get_forward_relationships(opts):
Returns an `OrderedDict` of field names to `RelationInfo`. Returns an `OrderedDict` of field names to `RelationInfo`.
""" """
forward_relations = OrderedDict() forward_relations = OrderedDict()
for field in [field for field in opts.fields if field.serialize and get_remote_field(field)]: for field in [field for field in opts.fields if field.serialize and field.remote_field]:
forward_relations[field.name] = RelationInfo( forward_relations[field.name] = RelationInfo(
model_field=field, model_field=field,
related_model=get_related_model(field), related_model=field.remote_field.model,
to_many=False, to_many=False,
to_field=_get_to_field(field), to_field=_get_to_field(field),
has_through_model=False, has_through_model=False,
@ -90,12 +88,12 @@ def _get_forward_relationships(opts):
for field in [field for field in opts.many_to_many if field.serialize]: for field in [field for field in opts.many_to_many if field.serialize]:
forward_relations[field.name] = RelationInfo( forward_relations[field.name] = RelationInfo(
model_field=field, model_field=field,
related_model=get_related_model(field), related_model=field.remote_field.model,
to_many=True, to_many=True,
# manytomany do not have to_fields # manytomany do not have to_fields
to_field=None, to_field=None,
has_through_model=( has_through_model=(
not get_remote_field(field).through._meta.auto_created not field.remote_field.through._meta.auto_created
), ),
reverse=False reverse=False
) )
@ -119,7 +117,7 @@ def _get_reverse_relationships(opts):
reverse_relations[accessor_name] = RelationInfo( reverse_relations[accessor_name] = RelationInfo(
model_field=None, model_field=None,
related_model=related, related_model=related,
to_many=get_remote_field(relation.field).multiple, to_many=relation.field.remote_field.multiple,
to_field=_get_to_field(relation.field), to_field=_get_to_field(relation.field),
has_through_model=False, has_through_model=False,
reverse=True reverse=True
@ -137,8 +135,8 @@ def _get_reverse_relationships(opts):
# manytomany do not have to_fields # manytomany do not have to_fields
to_field=None, to_field=None,
has_through_model=( has_through_model=(
(getattr(get_remote_field(relation.field), 'through', None) is not None) and (getattr(relation.field.remote_field, 'through', None) is not None) and
not get_remote_field(relation.field).through._meta.auto_created not relation.field.remote_field.through._meta.auto_created
), ),
reverse=True reverse=True
) )

View File

@ -10,13 +10,18 @@ from django.db import models
from django.utils.encoding import force_text from django.utils.encoding import force_text
from django.utils.functional import Promise from django.utils.functional import Promise
from rest_framework.compat import get_names_and_managers, unicode_repr from rest_framework.compat import unicode_repr
def manager_repr(value): def manager_repr(value):
model = value.model model = value.model
opts = model._meta opts = model._meta
for manager_name, manager_instance in get_names_and_managers(opts): names_and_managers = [
(manager.name, manager)
for manager
in opts.managers
]
for manager_name, manager_instance in names_and_managers:
if manager_instance == value: if manager_instance == value:
return '%s.%s.all()' % (model._meta.object_name, manager_name) return '%s.%s.all()' % (model._meta.object_name, manager_name)
return repr(value) return repr(value)

View File

@ -16,11 +16,11 @@ from rest_framework import (
HTTP_HEADER_ENCODING, exceptions, permissions, renderers, status HTTP_HEADER_ENCODING, exceptions, permissions, renderers, status
) )
from rest_framework.authentication import ( from rest_framework.authentication import (
BaseAuthentication, BasicAuthentication, RemoteUserAuthentication, SessionAuthentication, BaseAuthentication, BasicAuthentication, RemoteUserAuthentication,
TokenAuthentication) SessionAuthentication, TokenAuthentication
)
from rest_framework.authtoken.models import Token from rest_framework.authtoken.models import Token
from rest_framework.authtoken.views import obtain_auth_token from rest_framework.authtoken.views import obtain_auth_token
from rest_framework.compat import is_authenticated
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.test import APIClient, APIRequestFactory from rest_framework.test import APIClient, APIRequestFactory
from rest_framework.views import APIView from rest_framework.views import APIView
@ -450,7 +450,7 @@ class FailingAuthAccessedInRenderer(TestCase):
def render(self, data, media_type=None, renderer_context=None): def render(self, data, media_type=None, renderer_context=None):
request = renderer_context['request'] request = renderer_context['request']
if is_authenticated(request.user): if request.user.is_authenticated:
return b'authenticated' return b'authenticated'
return b'not authenticated' return b'not authenticated'

View File

@ -22,20 +22,6 @@ class CompatTests(TestCase):
expected = (timedelta.days * 86400.0) + float(timedelta.seconds) + (timedelta.microseconds / 1000000.0) expected = (timedelta.days * 86400.0) + float(timedelta.seconds) + (timedelta.microseconds / 1000000.0)
assert compat.total_seconds(timedelta) == expected assert compat.total_seconds(timedelta) == expected
def test_get_remote_field_with_old_django_version(self):
class MockField(object):
rel = 'example_rel'
compat.django.VERSION = (1, 8)
assert compat.get_remote_field(MockField(), default='default_value') == 'example_rel'
assert compat.get_remote_field(object(), default='default_value') == 'default_value'
def test_get_remote_field_with_new_django_version(self):
class MockField(object):
remote_field = 'example_remote_field'
compat.django.VERSION = (1, 10)
assert compat.get_remote_field(MockField(), default='default_value') == 'example_remote_field'
assert compat.get_remote_field(object(), default='default_value') == 'default_value'
def test_set_rollback_for_transaction_in_managed_mode(self): def test_set_rollback_for_transaction_in_managed_mode(self):
class MockTransaction(object): class MockTransaction(object):
called_rollback = False called_rollback = False

View File

@ -5,7 +5,7 @@ import pytest
from django.conf.urls import url from django.conf.urls import url
from django.core.exceptions import ImproperlyConfigured, PermissionDenied from django.core.exceptions import ImproperlyConfigured, PermissionDenied
from django.http import Http404 from django.http import Http404
from django.template import Template, TemplateDoesNotExist from django.template import TemplateDoesNotExist, engines
from django.test import TestCase, override_settings from django.test import TestCase, override_settings
from django.utils import six from django.utils import six
@ -60,12 +60,12 @@ class TemplateHTMLRendererTests(TestCase):
def get_template(template_name, dirs=None): def get_template(template_name, dirs=None):
if template_name == 'example.html': if template_name == 'example.html':
return Template("example: {{ object }}") return engines['django'].from_string("example: {{ object }}")
raise TemplateDoesNotExist(template_name) raise TemplateDoesNotExist(template_name)
def select_template(template_name_list, dirs=None, using=None): def select_template(template_name_list, dirs=None, using=None):
if template_name_list == ['example.html']: if template_name_list == ['example.html']:
return Template("example: {{ object }}") return engines['django'].from_string("example: {{ object }}")
raise TemplateDoesNotExist(template_name_list[0]) raise TemplateDoesNotExist(template_name_list[0])
django.template.loader.get_template = get_template django.template.loader.get_template = get_template
@ -139,9 +139,9 @@ class TemplateHTMLRendererExceptionTests(TestCase):
def get_template(template_name): def get_template(template_name):
if template_name == '404.html': if template_name == '404.html':
return Template("404: {{ detail }}") return engines['django'].from_string("404: {{ detail }}")
if template_name == '403.html': if template_name == '403.html':
return Template("403: {{ detail }}") return engines['django'].from_string("403: {{ detail }}")
raise TemplateDoesNotExist(template_name) raise TemplateDoesNotExist(template_name)
django.template.loader.get_template = get_template django.template.loader.get_template = get_template

View File

@ -21,7 +21,7 @@ from django.test import TestCase
from django.utils import six from django.utils import six
from rest_framework import serializers from rest_framework import serializers
from rest_framework.compat import set_many, unicode_repr from rest_framework.compat import unicode_repr
def dedent(blocktext): def dedent(blocktext):
@ -703,8 +703,7 @@ class TestIntegration(TestCase):
foreign_key=self.foreign_key_target, foreign_key=self.foreign_key_target,
one_to_one=self.one_to_one_target, one_to_one=self.one_to_one_target,
) )
set_many(self.instance, 'many_to_many', self.many_to_many_targets) self.instance.many_to_many.set(self.many_to_many_targets)
self.instance.save()
def test_pk_retrival(self): def test_pk_retrival(self):
class TestSerializer(serializers.ModelSerializer): class TestSerializer(serializers.ModelSerializer):

View File

@ -5,8 +5,6 @@ from django.test import TestCase
from rest_framework import serializers from rest_framework import serializers
from tests.models import RESTFrameworkModel from tests.models import RESTFrameworkModel
# Models # Models
from tests.test_multitable_inheritance import ChildModel from tests.test_multitable_inheritance import ChildModel

View File

@ -11,7 +11,7 @@ from rest_framework import (
HTTP_HEADER_ENCODING, authentication, generics, permissions, serializers, HTTP_HEADER_ENCODING, authentication, generics, permissions, serializers,
status, views status, views
) )
from rest_framework.compat import ResolverMatch, guardian, set_many from rest_framework.compat import ResolverMatch, guardian
from rest_framework.filters import DjangoObjectPermissionsFilter from rest_framework.filters import DjangoObjectPermissionsFilter
from rest_framework.routers import DefaultRouter from rest_framework.routers import DefaultRouter
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
@ -73,13 +73,14 @@ class ModelPermissionsIntegrationTests(TestCase):
def setUp(self): def setUp(self):
User.objects.create_user('disallowed', 'disallowed@example.com', 'password') User.objects.create_user('disallowed', 'disallowed@example.com', 'password')
user = User.objects.create_user('permitted', 'permitted@example.com', 'password') user = User.objects.create_user('permitted', 'permitted@example.com', 'password')
set_many(user, 'user_permissions', [ user.user_permissions.set([
Permission.objects.get(codename='add_basicmodel'), Permission.objects.get(codename='add_basicmodel'),
Permission.objects.get(codename='change_basicmodel'), Permission.objects.get(codename='change_basicmodel'),
Permission.objects.get(codename='delete_basicmodel') Permission.objects.get(codename='delete_basicmodel')
]) ])
user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password') user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password')
set_many(user, 'user_permissions', [ user.user_permissions.set([
Permission.objects.get(codename='change_basicmodel'), Permission.objects.get(codename='change_basicmodel'),
]) ])

View File

@ -2,7 +2,6 @@ from django.contrib.auth.models import Group, User
from django.test import TestCase from django.test import TestCase
from rest_framework import generics, serializers from rest_framework import generics, serializers
from rest_framework.compat import set_many
from rest_framework.test import APIRequestFactory from rest_framework.test import APIRequestFactory
factory = APIRequestFactory() factory = APIRequestFactory()
@ -23,8 +22,7 @@ class TestPrefetchRelatedUpdates(TestCase):
def setUp(self): def setUp(self):
self.user = User.objects.create(username='tom', email='tom@example.com') self.user = User.objects.create(username='tom', email='tom@example.com')
self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')] self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')]
set_many(self.user, 'groups', self.groups) self.user.groups.set(self.groups)
self.user.save()
def test_prefetch_related_updates(self): def test_prefetch_related_updates(self):
view = UserUpdate.as_view() view = UserUpdate.as_view()

View File

@ -6,8 +6,9 @@ from django.utils import six
from rest_framework import serializers from rest_framework import serializers
from tests.models import ( from tests.models import (
ForeignKeySource, ForeignKeyTarget, ManyToManySource, ManyToManyTarget, ForeignKeySource, ForeignKeyTarget, ManyToManySource, ManyToManyTarget,
NullableForeignKeySource, NullableOneToOneSource, NullableUUIDForeignKeySource, NullableForeignKeySource, NullableOneToOneSource,
OneToOnePKSource, OneToOneTarget, UUIDForeignKeyTarget NullableUUIDForeignKeySource, OneToOnePKSource, OneToOneTarget,
UUIDForeignKeyTarget
) )

View File

@ -16,7 +16,6 @@ from django.utils import six
from rest_framework import status from rest_framework import status
from rest_framework.authentication import SessionAuthentication from rest_framework.authentication import SessionAuthentication
from rest_framework.compat import is_anonymous
from rest_framework.parsers import BaseParser, FormParser, MultiPartParser from rest_framework.parsers import BaseParser, FormParser, MultiPartParser
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
@ -201,9 +200,9 @@ class TestUserSetter(TestCase):
def test_user_can_logout(self): def test_user_can_logout(self):
self.request.user = self.user self.request.user = self.user
self.assertFalse(is_anonymous(self.request.user)) self.assertFalse(self.request.user.is_anonymous)
logout(self.request) logout(self.request)
self.assertTrue(is_anonymous(self.request.user)) self.assertTrue(self.request.user.is_anonymous)
def test_logged_in_user_is_set_on_wrapped_request(self): def test_logged_in_user_is_set_on_wrapped_request(self):
login(self.request, self.user) login(self.request, self.user)

View File

@ -10,7 +10,7 @@ from django.test import override_settings
from django.utils.decorators import method_decorator from django.utils.decorators import method_decorator
from django.views.decorators.csrf import csrf_protect, ensure_csrf_cookie from django.views.decorators.csrf import csrf_protect, ensure_csrf_cookie
from rest_framework.compat import is_authenticated, requests from rest_framework.compat import requests
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.test import APITestCase, RequestsClient from rest_framework.test import APITestCase, RequestsClient
from rest_framework.views import APIView from rest_framework.views import APIView
@ -72,7 +72,7 @@ class SessionView(APIView):
class AuthView(APIView): class AuthView(APIView):
@method_decorator(ensure_csrf_cookie) @method_decorator(ensure_csrf_cookie)
def get(self, request): def get(self, request):
if is_authenticated(request.user): if request.user.is_authenticated:
username = request.user.username username = request.user.username
else: else:
username = None username = None

View File

@ -3,13 +3,12 @@ from __future__ import unicode_literals
from collections import namedtuple from collections import namedtuple
import pytest import pytest
from django.conf.urls import url from django.conf.urls import include, url
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from django.db import models from django.db import models
from django.test import TestCase, override_settings from django.test import TestCase, override_settings
from rest_framework import permissions, serializers, viewsets from rest_framework import permissions, serializers, viewsets
from rest_framework.compat import include
from rest_framework.decorators import detail_route, list_route from rest_framework.decorators import detail_route, list_route
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.routers import DefaultRouter, SimpleRouter from rest_framework.routers import DefaultRouter, SimpleRouter

View File

@ -1,9 +1,8 @@
import pytest import pytest
from django.conf.urls import url from django.conf.urls import include, url
from django.test import override_settings from django.test import override_settings
from rest_framework import serializers, status, versioning from rest_framework import serializers, status, versioning
from rest_framework.compat import include
from rest_framework.decorators import APIView from rest_framework.decorators import APIView
from rest_framework.relations import PKOnlyObject from rest_framework.relations import PKOnlyObject
from rest_framework.response import Response from rest_framework.response import Response

View File

@ -4,6 +4,7 @@ URLConf for test suite.
We need only the docs urls for DocumentationRenderer tests. We need only the docs urls for DocumentationRenderer tests.
""" """
from django.conf.urls import url from django.conf.urls import url
from rest_framework.documentation import include_docs_urls from rest_framework.documentation import include_docs_urls
urlpatterns = [ urlpatterns = [

View File

@ -1,4 +1,5 @@
from django.core.exceptions import ObjectDoesNotExist from django.core.exceptions import ObjectDoesNotExist
from rest_framework.compat import NoReverseMatch from rest_framework.compat import NoReverseMatch