This commit is contained in:
Levi Payne 2017-10-05 11:40:35 +00:00 committed by GitHub
commit a1618e5aaa
11 changed files with 47 additions and 89 deletions

View File

@ -78,24 +78,7 @@ def distinct(queryset, base):
return queryset.distinct() return queryset.distinct()
# Obtaining manager instances and names from model options differs after 1.10. # TODO: Remove
def get_names_and_managers(options):
if django.VERSION >= (1, 10):
# Django 1.10 onwards provides a `.managers` property on the Options.
return [
(manager.name, manager)
for manager
in options.managers
]
# For Django 1.8 and 1.9, use the three-tuple information provided
# by .concrete_managers and .abstract_managers
return [
(manager_info[1], manager_info[2])
for manager_info
in (options.concrete_managers + options.abstract_managers)
]
# field.rel is deprecated from 1.9 onwards # field.rel is deprecated from 1.9 onwards
def get_remote_field(field, **kwargs): def get_remote_field(field, **kwargs):
if 'default' in kwargs: if 'default' in kwargs:
@ -132,44 +115,6 @@ def _resolve_model(obj):
raise ValueError("{0} is not a Django model".format(obj)) raise ValueError("{0} is not a Django model".format(obj))
def is_authenticated(user):
if django.VERSION < (1, 10):
return user.is_authenticated()
return user.is_authenticated
def is_anonymous(user):
if django.VERSION < (1, 10):
return user.is_anonymous()
return user.is_anonymous
def get_related_model(field):
if django.VERSION < (1, 9):
return _resolve_model(field.rel.to)
return field.remote_field.model
def value_from_object(field, obj):
if django.VERSION < (1, 9):
return field._get_val_from_obj(obj)
return field.value_from_object(obj)
# contrib.postgres only supported from 1.8 onwards.
try:
from django.contrib.postgres import fields as postgres_fields
except ImportError:
postgres_fields = None
# JSONField is only supported from 1.9 onwards
try:
from django.contrib.postgres.fields import JSONField
except ImportError:
JSONField = None
# coreapi is optional (Note that uritemplate is a dependency of coreapi) # coreapi is optional (Note that uritemplate is a dependency of coreapi)
try: try:
import coreapi import coreapi
@ -325,11 +270,6 @@ else:
LONG_SEPARATORS = (b', ', b': ') LONG_SEPARATORS = (b', ', b': ')
INDENT_SEPARATORS = (b',', b': ') INDENT_SEPARATORS = (b',', b': ')
try:
# DecimalValidator is unavailable in Django < 1.9
from django.core.validators import DecimalValidator
except ImportError:
DecimalValidator = None
class CustomValidatorMessage(object): class CustomValidatorMessage(object):
""" """
@ -371,6 +311,7 @@ def set_rollback():
pass pass
# TODO: Remove
def template_render(template, context=None, request=None): def template_render(template, context=None, request=None):
""" """
Passing Context or RequestContext to Template.render is deprecated in 1.9+, Passing Context or RequestContext to Template.render is deprecated in 1.9+,
@ -393,6 +334,7 @@ def template_render(template, context=None, request=None):
return template.render(context, request=request) return template.render(context, request=request)
# TODO: Remove
def set_many(instance, field, value): def set_many(instance, field, value):
if django.VERSION < (1, 10): if django.VERSION < (1, 10):
setattr(instance, field, value) setattr(instance, field, value)
@ -401,6 +343,7 @@ def set_many(instance, field, value):
field.set(value) field.set(value)
# TODO: Remove
def include(module, namespace=None, app_name=None): def include(module, namespace=None, app_name=None):
from django.conf.urls import include from django.conf.urls import include
if django.VERSION < (1,9): if django.VERSION < (1,9):

View File

@ -33,7 +33,7 @@ from rest_framework import ISO_8601
from rest_framework.compat import ( from rest_framework.compat import (
InvalidTimeError, MaxLengthValidator, MaxValueValidator, InvalidTimeError, MaxLengthValidator, MaxValueValidator,
MinLengthValidator, MinValueValidator, get_remote_field, unicode_repr, MinLengthValidator, MinValueValidator, get_remote_field, unicode_repr,
unicode_to_repr, value_from_object unicode_to_repr
) )
from rest_framework.exceptions import ErrorDetail, ValidationError from rest_framework.exceptions import ErrorDetail, ValidationError
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -1840,7 +1840,7 @@ class ModelField(Field):
return obj return obj
def to_representation(self, obj): def to_representation(self, obj):
value = value_from_object(self.model_field, obj) value = self.model_field.value_from_object(obj)
if is_protected_type(value): if is_protected_type(value):
return value return value
return self.model_field.value_to_string(obj) return self.model_field.value_to_string(obj)

View File

@ -6,7 +6,6 @@ from __future__ import unicode_literals
from django.http import Http404 from django.http import Http404
from rest_framework import exceptions from rest_framework import exceptions
from rest_framework.compat import is_authenticated
SAFE_METHODS = ('GET', 'HEAD', 'OPTIONS') SAFE_METHODS = ('GET', 'HEAD', 'OPTIONS')
@ -47,7 +46,7 @@ class IsAuthenticated(BasePermission):
""" """
def has_permission(self, request, view): def has_permission(self, request, view):
return request.user and is_authenticated(request.user) return request.user and request.user.is_authenticated
class IsAdminUser(BasePermission): class IsAdminUser(BasePermission):
@ -68,7 +67,7 @@ class IsAuthenticatedOrReadOnly(BasePermission):
return ( return (
request.method in SAFE_METHODS or request.method in SAFE_METHODS or
request.user and request.user and
is_authenticated(request.user) request.user.is_authenticated
) )
@ -136,7 +135,7 @@ class DjangoModelPermissions(BasePermission):
return True return True
if not request.user or ( if not request.user or (
not is_authenticated(request.user) and self.authenticated_users_only): not request.user.is_authenticated and self.authenticated_users_only):
return False return False
queryset = self._queryset(view) queryset = self._queryset(view)

View File

@ -27,8 +27,7 @@ from django.utils import six, timezone
from django.utils.functional import cached_property from django.utils.functional import cached_property
from django.utils.translation import ugettext_lazy as _ from django.utils.translation import ugettext_lazy as _
from rest_framework.compat import JSONField as ModelJSONField from rest_framework.compat import set_many, unicode_to_repr
from rest_framework.compat import postgres_fields, set_many, unicode_to_repr
from rest_framework.exceptions import ErrorDetail, ValidationError from rest_framework.exceptions import ErrorDetail, ValidationError
from rest_framework.fields import get_error_detail, set_value from rest_framework.fields import get_error_detail, set_value
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -46,6 +45,17 @@ from rest_framework.validators import (
UniqueTogetherValidator UniqueTogetherValidator
) )
try:
from django.contrib.postgres import fields as postgres_fields
except ImportError:
postgres_fields = None
try:
from django.contrib.postgres.fields import JSONField as ModelJSONField
except ImportError:
ModelJSONField = None
# Note: We do the following so that users of the framework can use this style: # Note: We do the following so that users of the framework can use this style:
# #
# example_field = serializers.CharField(...) # example_field = serializers.CharField(...)

View File

@ -8,7 +8,6 @@ import time
from django.core.cache import cache as default_cache from django.core.cache import cache as default_cache
from django.core.exceptions import ImproperlyConfigured from django.core.exceptions import ImproperlyConfigured
from rest_framework.compat import is_authenticated
from rest_framework.settings import api_settings from rest_framework.settings import api_settings
@ -174,7 +173,7 @@ class AnonRateThrottle(SimpleRateThrottle):
scope = 'anon' scope = 'anon'
def get_cache_key(self, request, view): def get_cache_key(self, request, view):
if is_authenticated(request.user): if request.user.is_authenticated:
return None # Only throttle unauthenticated requests. return None # Only throttle unauthenticated requests.
return self.cache_format % { return self.cache_format % {
@ -194,7 +193,7 @@ class UserRateThrottle(SimpleRateThrottle):
scope = 'user' scope = 'user'
def get_cache_key(self, request, view): def get_cache_key(self, request, view):
if is_authenticated(request.user): if request.user.is_authenticated:
ident = request.user.pk ident = request.user.pk
else: else:
ident = self.get_ident(request) ident = self.get_ident(request)
@ -242,7 +241,7 @@ class ScopedRateThrottle(SimpleRateThrottle):
Otherwise generate the unique cache key by concatenating the user id Otherwise generate the unique cache key by concatenating the user id
with the '.throttle_scope` property of the view. with the '.throttle_scope` property of the view.
""" """
if is_authenticated(request.user): if request.user.is_authenticated:
ident = request.user.pk ident = request.user.pk
else: else:
ident = self.get_ident(request) ident = self.get_ident(request)

View File

@ -8,9 +8,13 @@ from django.core import validators
from django.db import models from django.db import models
from django.utils.text import capfirst from django.utils.text import capfirst
from rest_framework.compat import DecimalValidator, JSONField
from rest_framework.validators import UniqueValidator from rest_framework.validators import UniqueValidator
try:
from django.contrib.postgres.fields import JSONField
except ImportError:
JSONField = None
NUMERIC_FIELD_TYPES = ( NUMERIC_FIELD_TYPES = (
models.IntegerField, models.FloatField, models.DecimalField models.IntegerField, models.FloatField, models.DecimalField
) )
@ -182,10 +186,10 @@ def get_field_kwargs(field_name, model_field):
] ]
# Our decimal validation is handled in the field code, not validator code. # Our decimal validation is handled in the field code, not validator code.
# (In Django 1.9+ this differs from previous style) # (In Django 1.9+ this differs from previous style)
if isinstance(model_field, models.DecimalField) and DecimalValidator: if isinstance(model_field, models.DecimalField):
validator_kwarg = [ validator_kwarg = [
validator for validator in validator_kwarg validator for validator in validator_kwarg
if not isinstance(validator, DecimalValidator) if not isinstance(validator, validators.DecimalValidator)
] ]
# Ensure that max_length is passed explicitly as a keyword arg, # Ensure that max_length is passed explicitly as a keyword arg,

View File

@ -7,7 +7,7 @@ Usage: `get_field_info(model)` returns a `FieldInfo` instance.
""" """
from collections import OrderedDict, namedtuple from collections import OrderedDict, namedtuple
from rest_framework.compat import get_related_model, get_remote_field from rest_framework.compat import get_remote_field
FieldInfo = namedtuple('FieldResult', [ FieldInfo = namedtuple('FieldResult', [
'pk', # Model field instance 'pk', # Model field instance
@ -53,7 +53,7 @@ def _get_pk(opts):
while rel and rel.parent_link: while rel and rel.parent_link:
# If model is a child via multi-table inheritance, use parent's pk. # If model is a child via multi-table inheritance, use parent's pk.
pk = get_related_model(pk)._meta.pk pk = pk.remote_field.model._meta.pk
rel = get_remote_field(pk) rel = get_remote_field(pk)
return pk return pk
@ -79,7 +79,7 @@ def _get_forward_relationships(opts):
for field in [field for field in opts.fields if field.serialize and get_remote_field(field)]: for field in [field for field in opts.fields if field.serialize and get_remote_field(field)]:
forward_relations[field.name] = RelationInfo( forward_relations[field.name] = RelationInfo(
model_field=field, model_field=field,
related_model=get_related_model(field), related_model=field.remote_field.model,
to_many=False, to_many=False,
to_field=_get_to_field(field), to_field=_get_to_field(field),
has_through_model=False, has_through_model=False,
@ -90,7 +90,7 @@ def _get_forward_relationships(opts):
for field in [field for field in opts.many_to_many if field.serialize]: for field in [field for field in opts.many_to_many if field.serialize]:
forward_relations[field.name] = RelationInfo( forward_relations[field.name] = RelationInfo(
model_field=field, model_field=field,
related_model=get_related_model(field), related_model=field.remote_field.model,
to_many=True, to_many=True,
# manytomany do not have to_fields # manytomany do not have to_fields
to_field=None, to_field=None,

View File

@ -10,13 +10,18 @@ from django.db import models
from django.utils.encoding import force_text from django.utils.encoding import force_text
from django.utils.functional import Promise from django.utils.functional import Promise
from rest_framework.compat import get_names_and_managers, unicode_repr from rest_framework.compat import unicode_repr
def manager_repr(value): def manager_repr(value):
model = value.model model = value.model
opts = model._meta opts = model._meta
for manager_name, manager_instance in get_names_and_managers(opts): names_and_managers = [
(manager.name, manager)
for manager
in opts.managers
]
for manager_name, manager_instance in names_and_managers:
if manager_instance == value: if manager_instance == value:
return '%s.%s.all()' % (model._meta.object_name, manager_name) return '%s.%s.all()' % (model._meta.object_name, manager_name)
return repr(value) return repr(value)

View File

@ -20,7 +20,6 @@ from rest_framework.authentication import (
TokenAuthentication) TokenAuthentication)
from rest_framework.authtoken.models import Token from rest_framework.authtoken.models import Token
from rest_framework.authtoken.views import obtain_auth_token from rest_framework.authtoken.views import obtain_auth_token
from rest_framework.compat import is_authenticated
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.test import APIClient, APIRequestFactory from rest_framework.test import APIClient, APIRequestFactory
from rest_framework.views import APIView from rest_framework.views import APIView
@ -450,7 +449,7 @@ class FailingAuthAccessedInRenderer(TestCase):
def render(self, data, media_type=None, renderer_context=None): def render(self, data, media_type=None, renderer_context=None):
request = renderer_context['request'] request = renderer_context['request']
if is_authenticated(request.user): if request.user.is_authenticated:
return b'authenticated' return b'authenticated'
return b'not authenticated' return b'not authenticated'

View File

@ -16,7 +16,6 @@ from django.utils import six
from rest_framework import status from rest_framework import status
from rest_framework.authentication import SessionAuthentication from rest_framework.authentication import SessionAuthentication
from rest_framework.compat import is_anonymous
from rest_framework.parsers import BaseParser, FormParser, MultiPartParser from rest_framework.parsers import BaseParser, FormParser, MultiPartParser
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
@ -201,9 +200,9 @@ class TestUserSetter(TestCase):
def test_user_can_logout(self): def test_user_can_logout(self):
self.request.user = self.user self.request.user = self.user
self.assertFalse(is_anonymous(self.request.user)) self.assertFalse(self.request.user.is_anonymous)
logout(self.request) logout(self.request)
self.assertTrue(is_anonymous(self.request.user)) self.assertTrue(self.request.user.is_anonymous)
def test_logged_in_user_is_set_on_wrapped_request(self): def test_logged_in_user_is_set_on_wrapped_request(self):
login(self.request, self.user) login(self.request, self.user)

View File

@ -10,7 +10,7 @@ from django.test import override_settings
from django.utils.decorators import method_decorator from django.utils.decorators import method_decorator
from django.views.decorators.csrf import csrf_protect, ensure_csrf_cookie from django.views.decorators.csrf import csrf_protect, ensure_csrf_cookie
from rest_framework.compat import is_authenticated, requests from rest_framework.compat import requests
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.test import APITestCase, RequestsClient from rest_framework.test import APITestCase, RequestsClient
from rest_framework.views import APIView from rest_framework.views import APIView
@ -72,7 +72,7 @@ class SessionView(APIView):
class AuthView(APIView): class AuthView(APIView):
@method_decorator(ensure_csrf_cookie) @method_decorator(ensure_csrf_cookie)
def get(self, request): def get(self, request):
if is_authenticated(request.user): if request.user.is_authenticated:
username = request.user.username username = request.user.username
else: else:
username = None username = None