mirror of
https://github.com/encode/django-rest-framework.git
synced 2024-11-28 20:44:03 +03:00
Remove Django 1.8 & 1.9 compatibility code (#5481)
* Identify code that needs to be pulled out of/removed from compat.py * Extract modern code from get_names_and_managers in compat.py and remove compat code * Extract modern code from is_authenticated() in compat.py and remove. * Extract modern code from is_anonymous() in compat.py and remove * Extract modern code from get_related_model() from compat.py and remove * Extract modern code from value_from_object() in compat.py and remove * Update postgres compat JSONField now always available. * Remove DecimalValidator compat * Remove get_remote_field compat * Remove template_render compat Plus isort. * Remove set_many compat * Remove include compat
This commit is contained in:
parent
2edeb74e0e
commit
c674687782
|
@ -1,7 +1,7 @@
|
||||||
from django.contrib.auth import get_user_model
|
from django.contrib.auth import get_user_model
|
||||||
from django.core.management.base import BaseCommand, CommandError
|
from django.core.management.base import BaseCommand, CommandError
|
||||||
from rest_framework.authtoken.models import Token
|
|
||||||
|
|
||||||
|
from rest_framework.authtoken.models import Token
|
||||||
|
|
||||||
UserModel = get_user_model()
|
UserModel = get_user_model()
|
||||||
|
|
||||||
|
|
|
@ -78,36 +78,6 @@ def distinct(queryset, base):
|
||||||
return queryset.distinct()
|
return queryset.distinct()
|
||||||
|
|
||||||
|
|
||||||
# Obtaining manager instances and names from model options differs after 1.10.
|
|
||||||
def get_names_and_managers(options):
|
|
||||||
if django.VERSION >= (1, 10):
|
|
||||||
# Django 1.10 onwards provides a `.managers` property on the Options.
|
|
||||||
return [
|
|
||||||
(manager.name, manager)
|
|
||||||
for manager
|
|
||||||
in options.managers
|
|
||||||
]
|
|
||||||
# For Django 1.8 and 1.9, use the three-tuple information provided
|
|
||||||
# by .concrete_managers and .abstract_managers
|
|
||||||
return [
|
|
||||||
(manager_info[1], manager_info[2])
|
|
||||||
for manager_info
|
|
||||||
in (options.concrete_managers + options.abstract_managers)
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# field.rel is deprecated from 1.9 onwards
|
|
||||||
def get_remote_field(field, **kwargs):
|
|
||||||
if 'default' in kwargs:
|
|
||||||
if django.VERSION < (1, 9):
|
|
||||||
return getattr(field, 'rel', kwargs['default'])
|
|
||||||
return getattr(field, 'remote_field', kwargs['default'])
|
|
||||||
|
|
||||||
if django.VERSION < (1, 9):
|
|
||||||
return field.rel
|
|
||||||
return field.remote_field
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_model(obj):
|
def _resolve_model(obj):
|
||||||
"""
|
"""
|
||||||
Resolve supplied `obj` to a Django model class.
|
Resolve supplied `obj` to a Django model class.
|
||||||
|
@ -132,44 +102,13 @@ def _resolve_model(obj):
|
||||||
raise ValueError("{0} is not a Django model".format(obj))
|
raise ValueError("{0} is not a Django model".format(obj))
|
||||||
|
|
||||||
|
|
||||||
def is_authenticated(user):
|
# django.contrib.postgres requires psycopg2
|
||||||
if django.VERSION < (1, 10):
|
|
||||||
return user.is_authenticated()
|
|
||||||
return user.is_authenticated
|
|
||||||
|
|
||||||
|
|
||||||
def is_anonymous(user):
|
|
||||||
if django.VERSION < (1, 10):
|
|
||||||
return user.is_anonymous()
|
|
||||||
return user.is_anonymous
|
|
||||||
|
|
||||||
|
|
||||||
def get_related_model(field):
|
|
||||||
if django.VERSION < (1, 9):
|
|
||||||
return _resolve_model(field.rel.to)
|
|
||||||
return field.remote_field.model
|
|
||||||
|
|
||||||
|
|
||||||
def value_from_object(field, obj):
|
|
||||||
if django.VERSION < (1, 9):
|
|
||||||
return field._get_val_from_obj(obj)
|
|
||||||
return field.value_from_object(obj)
|
|
||||||
|
|
||||||
|
|
||||||
# contrib.postgres only supported from 1.8 onwards.
|
|
||||||
try:
|
try:
|
||||||
from django.contrib.postgres import fields as postgres_fields
|
from django.contrib.postgres import fields as postgres_fields
|
||||||
except ImportError:
|
except ImportError:
|
||||||
postgres_fields = None
|
postgres_fields = None
|
||||||
|
|
||||||
|
|
||||||
# JSONField is only supported from 1.9 onwards
|
|
||||||
try:
|
|
||||||
from django.contrib.postgres.fields import JSONField
|
|
||||||
except ImportError:
|
|
||||||
JSONField = None
|
|
||||||
|
|
||||||
|
|
||||||
# coreapi is optional (Note that uritemplate is a dependency of coreapi)
|
# coreapi is optional (Note that uritemplate is a dependency of coreapi)
|
||||||
try:
|
try:
|
||||||
import coreapi
|
import coreapi
|
||||||
|
@ -325,11 +264,6 @@ else:
|
||||||
LONG_SEPARATORS = (b', ', b': ')
|
LONG_SEPARATORS = (b', ', b': ')
|
||||||
INDENT_SEPARATORS = (b',', b': ')
|
INDENT_SEPARATORS = (b',', b': ')
|
||||||
|
|
||||||
try:
|
|
||||||
# DecimalValidator is unavailable in Django < 1.9
|
|
||||||
from django.core.validators import DecimalValidator
|
|
||||||
except ImportError:
|
|
||||||
DecimalValidator = None
|
|
||||||
|
|
||||||
class CustomValidatorMessage(object):
|
class CustomValidatorMessage(object):
|
||||||
"""
|
"""
|
||||||
|
@ -371,44 +305,6 @@ def set_rollback():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
def template_render(template, context=None, request=None):
|
|
||||||
"""
|
|
||||||
Passing Context or RequestContext to Template.render is deprecated in 1.9+,
|
|
||||||
see https://github.com/django/django/pull/3883 and
|
|
||||||
https://github.com/django/django/blob/1.9/django/template/backends/django.py#L82-L84
|
|
||||||
|
|
||||||
:param template: Template instance
|
|
||||||
:param context: dict
|
|
||||||
:param request: Request instance
|
|
||||||
:return: rendered template as SafeText instance
|
|
||||||
"""
|
|
||||||
if isinstance(template, Template):
|
|
||||||
if request:
|
|
||||||
context = RequestContext(request, context)
|
|
||||||
else:
|
|
||||||
context = Context(context)
|
|
||||||
return template.render(context)
|
|
||||||
# backends template, e.g. django.template.backends.django.Template
|
|
||||||
else:
|
|
||||||
return template.render(context, request=request)
|
|
||||||
|
|
||||||
|
|
||||||
def set_many(instance, field, value):
|
|
||||||
if django.VERSION < (1, 10):
|
|
||||||
setattr(instance, field, value)
|
|
||||||
else:
|
|
||||||
field = getattr(instance, field)
|
|
||||||
field.set(value)
|
|
||||||
|
|
||||||
|
|
||||||
def include(module, namespace=None, app_name=None):
|
|
||||||
from django.conf.urls import include
|
|
||||||
if django.VERSION < (1,9):
|
|
||||||
return include(module, namespace, app_name)
|
|
||||||
else:
|
|
||||||
return include((module, app_name), namespace)
|
|
||||||
|
|
||||||
|
|
||||||
def authenticate(request=None, **credentials):
|
def authenticate(request=None, **credentials):
|
||||||
from django.contrib.auth import authenticate
|
from django.contrib.auth import authenticate
|
||||||
if django.VERSION < (1, 11):
|
if django.VERSION < (1, 11):
|
||||||
|
|
|
@ -32,8 +32,7 @@ from django.utils.translation import ugettext_lazy as _
|
||||||
from rest_framework import ISO_8601
|
from rest_framework import ISO_8601
|
||||||
from rest_framework.compat import (
|
from rest_framework.compat import (
|
||||||
InvalidTimeError, MaxLengthValidator, MaxValueValidator,
|
InvalidTimeError, MaxLengthValidator, MaxValueValidator,
|
||||||
MinLengthValidator, MinValueValidator, get_remote_field, unicode_repr,
|
MinLengthValidator, MinValueValidator, unicode_repr, unicode_to_repr
|
||||||
unicode_to_repr, value_from_object
|
|
||||||
)
|
)
|
||||||
from rest_framework.exceptions import ErrorDetail, ValidationError
|
from rest_framework.exceptions import ErrorDetail, ValidationError
|
||||||
from rest_framework.settings import api_settings
|
from rest_framework.settings import api_settings
|
||||||
|
@ -1829,7 +1828,7 @@ class ModelField(Field):
|
||||||
MaxLengthValidator(self.max_length, message=message))
|
MaxLengthValidator(self.max_length, message=message))
|
||||||
|
|
||||||
def to_internal_value(self, data):
|
def to_internal_value(self, data):
|
||||||
rel = get_remote_field(self.model_field, default=None)
|
rel = self.model_field.remote_field
|
||||||
if rel is not None:
|
if rel is not None:
|
||||||
return rel.model._meta.get_field(rel.field_name).to_python(data)
|
return rel.model._meta.get_field(rel.field_name).to_python(data)
|
||||||
return self.model_field.to_python(data)
|
return self.model_field.to_python(data)
|
||||||
|
@ -1840,7 +1839,7 @@ class ModelField(Field):
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
def to_representation(self, obj):
|
def to_representation(self, obj):
|
||||||
value = value_from_object(self.model_field, obj)
|
value = self.model_field.value_from_object(obj)
|
||||||
if is_protected_type(value):
|
if is_protected_type(value):
|
||||||
return value
|
return value
|
||||||
return self.model_field.value_to_string(obj)
|
return self.model_field.value_to_string(obj)
|
||||||
|
|
|
@ -16,9 +16,7 @@ from django.utils import six
|
||||||
from django.utils.encoding import force_text
|
from django.utils.encoding import force_text
|
||||||
from django.utils.translation import ugettext_lazy as _
|
from django.utils.translation import ugettext_lazy as _
|
||||||
|
|
||||||
from rest_framework.compat import (
|
from rest_framework.compat import coreapi, coreschema, distinct, guardian
|
||||||
coreapi, coreschema, distinct, guardian, template_render
|
|
||||||
)
|
|
||||||
from rest_framework.settings import api_settings
|
from rest_framework.settings import api_settings
|
||||||
|
|
||||||
|
|
||||||
|
@ -129,7 +127,7 @@ class SearchFilter(BaseFilterBackend):
|
||||||
'term': term
|
'term': term
|
||||||
}
|
}
|
||||||
template = loader.get_template(self.template)
|
template = loader.get_template(self.template)
|
||||||
return template_render(template, context)
|
return template.render(context)
|
||||||
|
|
||||||
def get_schema_fields(self, view):
|
def get_schema_fields(self, view):
|
||||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||||
|
@ -260,7 +258,7 @@ class OrderingFilter(BaseFilterBackend):
|
||||||
def to_html(self, request, queryset, view):
|
def to_html(self, request, queryset, view):
|
||||||
template = loader.get_template(self.template)
|
template = loader.get_template(self.template)
|
||||||
context = self.get_template_context(request, queryset, view)
|
context = self.get_template_context(request, queryset, view)
|
||||||
return template_render(template, context)
|
return template.render(context)
|
||||||
|
|
||||||
def get_schema_fields(self, view):
|
def get_schema_fields(self, view):
|
||||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||||
|
|
|
@ -16,7 +16,7 @@ from django.utils.encoding import force_text
|
||||||
from django.utils.six.moves.urllib import parse as urlparse
|
from django.utils.six.moves.urllib import parse as urlparse
|
||||||
from django.utils.translation import ugettext_lazy as _
|
from django.utils.translation import ugettext_lazy as _
|
||||||
|
|
||||||
from rest_framework.compat import coreapi, coreschema, template_render
|
from rest_framework.compat import coreapi, coreschema
|
||||||
from rest_framework.exceptions import NotFound
|
from rest_framework.exceptions import NotFound
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.settings import api_settings
|
from rest_framework.settings import api_settings
|
||||||
|
@ -285,7 +285,7 @@ class PageNumberPagination(BasePagination):
|
||||||
def to_html(self):
|
def to_html(self):
|
||||||
template = loader.get_template(self.template)
|
template = loader.get_template(self.template)
|
||||||
context = self.get_html_context()
|
context = self.get_html_context()
|
||||||
return template_render(template, context)
|
return template.render(context)
|
||||||
|
|
||||||
def get_schema_fields(self, view):
|
def get_schema_fields(self, view):
|
||||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||||
|
@ -442,7 +442,7 @@ class LimitOffsetPagination(BasePagination):
|
||||||
def to_html(self):
|
def to_html(self):
|
||||||
template = loader.get_template(self.template)
|
template = loader.get_template(self.template)
|
||||||
context = self.get_html_context()
|
context = self.get_html_context()
|
||||||
return template_render(template, context)
|
return template.render(context)
|
||||||
|
|
||||||
def get_schema_fields(self, view):
|
def get_schema_fields(self, view):
|
||||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||||
|
@ -793,7 +793,7 @@ class CursorPagination(BasePagination):
|
||||||
def to_html(self):
|
def to_html(self):
|
||||||
template = loader.get_template(self.template)
|
template = loader.get_template(self.template)
|
||||||
context = self.get_html_context()
|
context = self.get_html_context()
|
||||||
return template_render(template, context)
|
return template.render(context)
|
||||||
|
|
||||||
def get_schema_fields(self, view):
|
def get_schema_fields(self, view):
|
||||||
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
|
||||||
|
|
|
@ -6,7 +6,6 @@ from __future__ import unicode_literals
|
||||||
from django.http import Http404
|
from django.http import Http404
|
||||||
|
|
||||||
from rest_framework import exceptions
|
from rest_framework import exceptions
|
||||||
from rest_framework.compat import is_authenticated
|
|
||||||
|
|
||||||
SAFE_METHODS = ('GET', 'HEAD', 'OPTIONS')
|
SAFE_METHODS = ('GET', 'HEAD', 'OPTIONS')
|
||||||
|
|
||||||
|
@ -47,7 +46,7 @@ class IsAuthenticated(BasePermission):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def has_permission(self, request, view):
|
def has_permission(self, request, view):
|
||||||
return request.user and is_authenticated(request.user)
|
return request.user and request.user.is_authenticated
|
||||||
|
|
||||||
|
|
||||||
class IsAdminUser(BasePermission):
|
class IsAdminUser(BasePermission):
|
||||||
|
@ -68,7 +67,7 @@ class IsAuthenticatedOrReadOnly(BasePermission):
|
||||||
return (
|
return (
|
||||||
request.method in SAFE_METHODS or
|
request.method in SAFE_METHODS or
|
||||||
request.user and
|
request.user and
|
||||||
is_authenticated(request.user)
|
request.user.is_authenticated
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -136,7 +135,7 @@ class DjangoModelPermissions(BasePermission):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if not request.user or (
|
if not request.user or (
|
||||||
not is_authenticated(request.user) and self.authenticated_users_only):
|
not request.user.is_authenticated and self.authenticated_users_only):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
queryset = self._queryset(view)
|
queryset = self._queryset(view)
|
||||||
|
|
|
@ -16,7 +16,7 @@ from django.conf import settings
|
||||||
from django.core.exceptions import ImproperlyConfigured
|
from django.core.exceptions import ImproperlyConfigured
|
||||||
from django.core.paginator import Page
|
from django.core.paginator import Page
|
||||||
from django.http.multipartparser import parse_header
|
from django.http.multipartparser import parse_header
|
||||||
from django.template import Template, loader
|
from django.template import engines, loader
|
||||||
from django.test.client import encode_multipart
|
from django.test.client import encode_multipart
|
||||||
from django.utils import six
|
from django.utils import six
|
||||||
from django.utils.html import mark_safe
|
from django.utils.html import mark_safe
|
||||||
|
@ -24,7 +24,7 @@ from django.utils.html import mark_safe
|
||||||
from rest_framework import VERSION, exceptions, serializers, status
|
from rest_framework import VERSION, exceptions, serializers, status
|
||||||
from rest_framework.compat import (
|
from rest_framework.compat import (
|
||||||
INDENT_SEPARATORS, LONG_SEPARATORS, SHORT_SEPARATORS, coreapi,
|
INDENT_SEPARATORS, LONG_SEPARATORS, SHORT_SEPARATORS, coreapi,
|
||||||
pygments_css, template_render
|
pygments_css
|
||||||
)
|
)
|
||||||
from rest_framework.exceptions import ParseError
|
from rest_framework.exceptions import ParseError
|
||||||
from rest_framework.request import is_form_media_type, override_method
|
from rest_framework.request import is_form_media_type, override_method
|
||||||
|
@ -173,7 +173,7 @@ class TemplateHTMLRenderer(BaseRenderer):
|
||||||
context = self.resolve_context(data, request, response)
|
context = self.resolve_context(data, request, response)
|
||||||
else:
|
else:
|
||||||
context = self.get_template_context(data, renderer_context)
|
context = self.get_template_context(data, renderer_context)
|
||||||
return template_render(template, context, request=request)
|
return template.render(context, request=request)
|
||||||
|
|
||||||
def resolve_template(self, template_names):
|
def resolve_template(self, template_names):
|
||||||
return loader.select_template(template_names)
|
return loader.select_template(template_names)
|
||||||
|
@ -206,8 +206,9 @@ class TemplateHTMLRenderer(BaseRenderer):
|
||||||
return self.resolve_template(template_names)
|
return self.resolve_template(template_names)
|
||||||
except Exception:
|
except Exception:
|
||||||
# Fall back to using eg '404 Not Found'
|
# Fall back to using eg '404 Not Found'
|
||||||
return Template('%d %s' % (response.status_code,
|
body = '%d %s' % (response.status_code, response.status_text.title())
|
||||||
response.status_text.title()))
|
template = engines['django'].from_string(body)
|
||||||
|
return template
|
||||||
|
|
||||||
|
|
||||||
# Note, subclass TemplateHTMLRenderer simply for the exception behavior
|
# Note, subclass TemplateHTMLRenderer simply for the exception behavior
|
||||||
|
@ -239,7 +240,7 @@ class StaticHTMLRenderer(TemplateHTMLRenderer):
|
||||||
context = self.resolve_context(data, request, response)
|
context = self.resolve_context(data, request, response)
|
||||||
else:
|
else:
|
||||||
context = self.get_template_context(data, renderer_context)
|
context = self.get_template_context(data, renderer_context)
|
||||||
return template_render(template, context, request=request)
|
return template.render(context, request=request)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@ -347,7 +348,7 @@ class HTMLFormRenderer(BaseRenderer):
|
||||||
|
|
||||||
template = loader.get_template(template_name)
|
template = loader.get_template(template_name)
|
||||||
context = {'field': field, 'style': style}
|
context = {'field': field, 'style': style}
|
||||||
return template_render(template, context)
|
return template.render(context)
|
||||||
|
|
||||||
def render(self, data, accepted_media_type=None, renderer_context=None):
|
def render(self, data, accepted_media_type=None, renderer_context=None):
|
||||||
"""
|
"""
|
||||||
|
@ -368,7 +369,7 @@ class HTMLFormRenderer(BaseRenderer):
|
||||||
'form': form,
|
'form': form,
|
||||||
'style': style
|
'style': style
|
||||||
}
|
}
|
||||||
return template_render(template, context)
|
return template.render(context)
|
||||||
|
|
||||||
|
|
||||||
class BrowsableAPIRenderer(BaseRenderer):
|
class BrowsableAPIRenderer(BaseRenderer):
|
||||||
|
@ -625,7 +626,7 @@ class BrowsableAPIRenderer(BaseRenderer):
|
||||||
|
|
||||||
template = loader.get_template(self.filter_template)
|
template = loader.get_template(self.filter_template)
|
||||||
context = {'elements': elements}
|
context = {'elements': elements}
|
||||||
return template_render(template, context)
|
return template.render(context)
|
||||||
|
|
||||||
def get_context(self, data, accepted_media_type, renderer_context):
|
def get_context(self, data, accepted_media_type, renderer_context):
|
||||||
"""
|
"""
|
||||||
|
@ -705,7 +706,7 @@ class BrowsableAPIRenderer(BaseRenderer):
|
||||||
|
|
||||||
template = loader.get_template(self.template)
|
template = loader.get_template(self.template)
|
||||||
context = self.get_context(data, accepted_media_type, renderer_context)
|
context = self.get_context(data, accepted_media_type, renderer_context)
|
||||||
ret = template_render(template, context, request=renderer_context['request'])
|
ret = template.render(context, request=renderer_context['request'])
|
||||||
|
|
||||||
# Munge DELETE Response code to allow us to return content
|
# Munge DELETE Response code to allow us to return content
|
||||||
# (Do this *after* we've rendered the template so that we include
|
# (Do this *after* we've rendered the template so that we include
|
||||||
|
@ -741,7 +742,7 @@ class AdminRenderer(BrowsableAPIRenderer):
|
||||||
|
|
||||||
template = loader.get_template(self.template)
|
template = loader.get_template(self.template)
|
||||||
context = self.get_context(data, accepted_media_type, renderer_context)
|
context = self.get_context(data, accepted_media_type, renderer_context)
|
||||||
ret = template_render(template, context, request=renderer_context['request'])
|
ret = template.render(context, request=renderer_context['request'])
|
||||||
|
|
||||||
# Creation and deletion should use redirects in the admin style.
|
# Creation and deletion should use redirects in the admin style.
|
||||||
if response.status_code == status.HTTP_201_CREATED and 'Location' in response:
|
if response.status_code == status.HTTP_201_CREATED and 'Location' in response:
|
||||||
|
@ -819,7 +820,7 @@ class DocumentationRenderer(BaseRenderer):
|
||||||
def render(self, data, accepted_media_type=None, renderer_context=None):
|
def render(self, data, accepted_media_type=None, renderer_context=None):
|
||||||
template = loader.get_template(self.template)
|
template = loader.get_template(self.template)
|
||||||
context = self.get_context(data, renderer_context['request'])
|
context = self.get_context(data, renderer_context['request'])
|
||||||
return template_render(template, context, request=renderer_context['request'])
|
return template.render(context, request=renderer_context['request'])
|
||||||
|
|
||||||
|
|
||||||
class SchemaJSRenderer(BaseRenderer):
|
class SchemaJSRenderer(BaseRenderer):
|
||||||
|
@ -835,7 +836,7 @@ class SchemaJSRenderer(BaseRenderer):
|
||||||
template = loader.get_template(self.template)
|
template = loader.get_template(self.template)
|
||||||
context = {'schema': mark_safe(schema)}
|
context = {'schema': mark_safe(schema)}
|
||||||
request = renderer_context['request']
|
request = renderer_context['request']
|
||||||
return template_render(template, context, request=request)
|
return template.render(context, request=request)
|
||||||
|
|
||||||
|
|
||||||
class MultiPartRenderer(BaseRenderer):
|
class MultiPartRenderer(BaseRenderer):
|
||||||
|
|
|
@ -27,8 +27,7 @@ from django.utils import six, timezone
|
||||||
from django.utils.functional import cached_property
|
from django.utils.functional import cached_property
|
||||||
from django.utils.translation import ugettext_lazy as _
|
from django.utils.translation import ugettext_lazy as _
|
||||||
|
|
||||||
from rest_framework.compat import JSONField as ModelJSONField
|
from rest_framework.compat import postgres_fields, unicode_to_repr
|
||||||
from rest_framework.compat import postgres_fields, set_many, unicode_to_repr
|
|
||||||
from rest_framework.exceptions import ErrorDetail, ValidationError
|
from rest_framework.exceptions import ErrorDetail, ValidationError
|
||||||
from rest_framework.fields import get_error_detail, set_value
|
from rest_framework.fields import get_error_detail, set_value
|
||||||
from rest_framework.settings import api_settings
|
from rest_framework.settings import api_settings
|
||||||
|
@ -861,8 +860,6 @@ class ModelSerializer(Serializer):
|
||||||
}
|
}
|
||||||
if ModelDurationField is not None:
|
if ModelDurationField is not None:
|
||||||
serializer_field_mapping[ModelDurationField] = DurationField
|
serializer_field_mapping[ModelDurationField] = DurationField
|
||||||
if ModelJSONField is not None:
|
|
||||||
serializer_field_mapping[ModelJSONField] = JSONField
|
|
||||||
serializer_related_field = PrimaryKeyRelatedField
|
serializer_related_field = PrimaryKeyRelatedField
|
||||||
serializer_related_to_field = SlugRelatedField
|
serializer_related_to_field = SlugRelatedField
|
||||||
serializer_url_field = HyperlinkedIdentityField
|
serializer_url_field = HyperlinkedIdentityField
|
||||||
|
@ -935,7 +932,8 @@ class ModelSerializer(Serializer):
|
||||||
# Save many-to-many relationships after the instance is created.
|
# Save many-to-many relationships after the instance is created.
|
||||||
if many_to_many:
|
if many_to_many:
|
||||||
for field_name, value in many_to_many.items():
|
for field_name, value in many_to_many.items():
|
||||||
set_many(instance, field_name, value)
|
field = getattr(instance, field_name)
|
||||||
|
field.set(value)
|
||||||
|
|
||||||
return instance
|
return instance
|
||||||
|
|
||||||
|
@ -949,7 +947,8 @@ class ModelSerializer(Serializer):
|
||||||
# have an instance pk for the relationships to be associated with.
|
# have an instance pk for the relationships to be associated with.
|
||||||
for attr, value in validated_data.items():
|
for attr, value in validated_data.items():
|
||||||
if attr in info.relations and info.relations[attr].to_many:
|
if attr in info.relations and info.relations[attr].to_many:
|
||||||
set_many(instance, attr, value)
|
field = getattr(instance, attr)
|
||||||
|
field.set(value)
|
||||||
else:
|
else:
|
||||||
setattr(instance, attr, value)
|
setattr(instance, attr, value)
|
||||||
instance.save()
|
instance.save()
|
||||||
|
@ -1532,6 +1531,7 @@ if postgres_fields:
|
||||||
|
|
||||||
ModelSerializer.serializer_field_mapping[postgres_fields.HStoreField] = CharMappingField
|
ModelSerializer.serializer_field_mapping[postgres_fields.HStoreField] = CharMappingField
|
||||||
ModelSerializer.serializer_field_mapping[postgres_fields.ArrayField] = ListField
|
ModelSerializer.serializer_field_mapping[postgres_fields.ArrayField] = ListField
|
||||||
|
ModelSerializer.serializer_field_mapping[postgres_fields.JSONField] = JSONField
|
||||||
|
|
||||||
|
|
||||||
class HyperlinkedModelSerializer(ModelSerializer):
|
class HyperlinkedModelSerializer(ModelSerializer):
|
||||||
|
|
|
@ -19,6 +19,7 @@ REST framework settings, checking for user settings first, then falling
|
||||||
back to the defaults.
|
back to the defaults.
|
||||||
"""
|
"""
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
from importlib import import_module
|
from importlib import import_module
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
|
|
|
@ -11,8 +11,7 @@ from django.utils.html import escape, format_html, smart_urlquote
|
||||||
from django.utils.safestring import SafeData, mark_safe
|
from django.utils.safestring import SafeData, mark_safe
|
||||||
|
|
||||||
from rest_framework.compat import (
|
from rest_framework.compat import (
|
||||||
NoReverseMatch, apply_markdown, pygments_highlight, reverse,
|
NoReverseMatch, apply_markdown, pygments_highlight, reverse
|
||||||
template_render
|
|
||||||
)
|
)
|
||||||
from rest_framework.renderers import HTMLFormRenderer
|
from rest_framework.renderers import HTMLFormRenderer
|
||||||
from rest_framework.utils.urls import replace_query_param
|
from rest_framework.utils.urls import replace_query_param
|
||||||
|
@ -216,11 +215,11 @@ def format_value(value):
|
||||||
else:
|
else:
|
||||||
template = loader.get_template('rest_framework/admin/simple_list_value.html')
|
template = loader.get_template('rest_framework/admin/simple_list_value.html')
|
||||||
context = {'value': value}
|
context = {'value': value}
|
||||||
return template_render(template, context)
|
return template.render(context)
|
||||||
elif isinstance(value, dict):
|
elif isinstance(value, dict):
|
||||||
template = loader.get_template('rest_framework/admin/dict_value.html')
|
template = loader.get_template('rest_framework/admin/dict_value.html')
|
||||||
context = {'value': value}
|
context = {'value': value}
|
||||||
return template_render(template, context)
|
return template.render(context)
|
||||||
elif isinstance(value, six.string_types):
|
elif isinstance(value, six.string_types):
|
||||||
if (
|
if (
|
||||||
(value.startswith('http:') or value.startswith('https:')) and not
|
(value.startswith('http:') or value.startswith('https:')) and not
|
||||||
|
|
|
@ -8,7 +8,6 @@ import time
|
||||||
from django.core.cache import cache as default_cache
|
from django.core.cache import cache as default_cache
|
||||||
from django.core.exceptions import ImproperlyConfigured
|
from django.core.exceptions import ImproperlyConfigured
|
||||||
|
|
||||||
from rest_framework.compat import is_authenticated
|
|
||||||
from rest_framework.settings import api_settings
|
from rest_framework.settings import api_settings
|
||||||
|
|
||||||
|
|
||||||
|
@ -174,7 +173,7 @@ class AnonRateThrottle(SimpleRateThrottle):
|
||||||
scope = 'anon'
|
scope = 'anon'
|
||||||
|
|
||||||
def get_cache_key(self, request, view):
|
def get_cache_key(self, request, view):
|
||||||
if is_authenticated(request.user):
|
if request.user.is_authenticated:
|
||||||
return None # Only throttle unauthenticated requests.
|
return None # Only throttle unauthenticated requests.
|
||||||
|
|
||||||
return self.cache_format % {
|
return self.cache_format % {
|
||||||
|
@ -194,7 +193,7 @@ class UserRateThrottle(SimpleRateThrottle):
|
||||||
scope = 'user'
|
scope = 'user'
|
||||||
|
|
||||||
def get_cache_key(self, request, view):
|
def get_cache_key(self, request, view):
|
||||||
if is_authenticated(request.user):
|
if request.user.is_authenticated:
|
||||||
ident = request.user.pk
|
ident = request.user.pk
|
||||||
else:
|
else:
|
||||||
ident = self.get_ident(request)
|
ident = self.get_ident(request)
|
||||||
|
@ -242,7 +241,7 @@ class ScopedRateThrottle(SimpleRateThrottle):
|
||||||
Otherwise generate the unique cache key by concatenating the user id
|
Otherwise generate the unique cache key by concatenating the user id
|
||||||
with the '.throttle_scope` property of the view.
|
with the '.throttle_scope` property of the view.
|
||||||
"""
|
"""
|
||||||
if is_authenticated(request.user):
|
if request.user.is_authenticated:
|
||||||
ident = request.user.pk
|
ident = request.user.pk
|
||||||
else:
|
else:
|
||||||
ident = self.get_ident(request)
|
ident = self.get_ident(request)
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
from django.conf.urls import url
|
from django.conf.urls import include, url
|
||||||
|
|
||||||
from rest_framework.compat import RegexURLResolver, include
|
from rest_framework.compat import RegexURLResolver
|
||||||
from rest_framework.settings import api_settings
|
from rest_framework.settings import api_settings
|
||||||
|
|
||||||
|
|
||||||
|
@ -19,8 +19,7 @@ def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required):
|
||||||
patterns = apply_suffix_patterns(urlpattern.url_patterns,
|
patterns = apply_suffix_patterns(urlpattern.url_patterns,
|
||||||
suffix_pattern,
|
suffix_pattern,
|
||||||
suffix_required)
|
suffix_required)
|
||||||
ret.append(url(regex, include(patterns, namespace, app_name), kwargs))
|
ret.append(url(regex, include((patterns, app_name), namespace), kwargs))
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Regular URL pattern
|
# Regular URL pattern
|
||||||
regex = urlpattern.regex.pattern.rstrip('$').rstrip('/') + suffix_pattern
|
regex = urlpattern.regex.pattern.rstrip('$').rstrip('/') + suffix_pattern
|
||||||
|
|
|
@ -8,7 +8,7 @@ from django.core import validators
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.utils.text import capfirst
|
from django.utils.text import capfirst
|
||||||
|
|
||||||
from rest_framework.compat import DecimalValidator, JSONField
|
from rest_framework.compat import postgres_fields
|
||||||
from rest_framework.validators import UniqueValidator
|
from rest_framework.validators import UniqueValidator
|
||||||
|
|
||||||
NUMERIC_FIELD_TYPES = (
|
NUMERIC_FIELD_TYPES = (
|
||||||
|
@ -88,7 +88,7 @@ def get_field_kwargs(field_name, model_field):
|
||||||
if decimal_places is not None:
|
if decimal_places is not None:
|
||||||
kwargs['decimal_places'] = decimal_places
|
kwargs['decimal_places'] = decimal_places
|
||||||
|
|
||||||
if isinstance(model_field, models.TextField) or (JSONField and isinstance(model_field, JSONField)):
|
if isinstance(model_field, models.TextField) or (postgres_fields and isinstance(model_field, postgres_fields.JSONField)):
|
||||||
kwargs['style'] = {'base_template': 'textarea.html'}
|
kwargs['style'] = {'base_template': 'textarea.html'}
|
||||||
|
|
||||||
if isinstance(model_field, models.AutoField) or not model_field.editable:
|
if isinstance(model_field, models.AutoField) or not model_field.editable:
|
||||||
|
@ -181,11 +181,10 @@ def get_field_kwargs(field_name, model_field):
|
||||||
if validator is not validators.validate_ipv46_address
|
if validator is not validators.validate_ipv46_address
|
||||||
]
|
]
|
||||||
# Our decimal validation is handled in the field code, not validator code.
|
# Our decimal validation is handled in the field code, not validator code.
|
||||||
# (In Django 1.9+ this differs from previous style)
|
if isinstance(model_field, models.DecimalField):
|
||||||
if isinstance(model_field, models.DecimalField) and DecimalValidator:
|
|
||||||
validator_kwarg = [
|
validator_kwarg = [
|
||||||
validator for validator in validator_kwarg
|
validator for validator in validator_kwarg
|
||||||
if not isinstance(validator, DecimalValidator)
|
if not isinstance(validator, validators.DecimalValidator)
|
||||||
]
|
]
|
||||||
|
|
||||||
# Ensure that max_length is passed explicitly as a keyword arg,
|
# Ensure that max_length is passed explicitly as a keyword arg,
|
||||||
|
|
|
@ -7,8 +7,6 @@ Usage: `get_field_info(model)` returns a `FieldInfo` instance.
|
||||||
"""
|
"""
|
||||||
from collections import OrderedDict, namedtuple
|
from collections import OrderedDict, namedtuple
|
||||||
|
|
||||||
from rest_framework.compat import get_related_model, get_remote_field
|
|
||||||
|
|
||||||
FieldInfo = namedtuple('FieldResult', [
|
FieldInfo = namedtuple('FieldResult', [
|
||||||
'pk', # Model field instance
|
'pk', # Model field instance
|
||||||
'fields', # Dict of field name -> model field instance
|
'fields', # Dict of field name -> model field instance
|
||||||
|
@ -49,19 +47,19 @@ def get_field_info(model):
|
||||||
|
|
||||||
def _get_pk(opts):
|
def _get_pk(opts):
|
||||||
pk = opts.pk
|
pk = opts.pk
|
||||||
rel = get_remote_field(pk)
|
rel = pk.remote_field
|
||||||
|
|
||||||
while rel and rel.parent_link:
|
while rel and rel.parent_link:
|
||||||
# If model is a child via multi-table inheritance, use parent's pk.
|
# If model is a child via multi-table inheritance, use parent's pk.
|
||||||
pk = get_related_model(pk)._meta.pk
|
pk = pk.remote_field.model._meta.pk
|
||||||
rel = get_remote_field(pk)
|
rel = pk.remote_field
|
||||||
|
|
||||||
return pk
|
return pk
|
||||||
|
|
||||||
|
|
||||||
def _get_fields(opts):
|
def _get_fields(opts):
|
||||||
fields = OrderedDict()
|
fields = OrderedDict()
|
||||||
for field in [field for field in opts.fields if field.serialize and not get_remote_field(field)]:
|
for field in [field for field in opts.fields if field.serialize and not field.remote_field]:
|
||||||
fields[field.name] = field
|
fields[field.name] = field
|
||||||
|
|
||||||
return fields
|
return fields
|
||||||
|
@ -76,10 +74,10 @@ def _get_forward_relationships(opts):
|
||||||
Returns an `OrderedDict` of field names to `RelationInfo`.
|
Returns an `OrderedDict` of field names to `RelationInfo`.
|
||||||
"""
|
"""
|
||||||
forward_relations = OrderedDict()
|
forward_relations = OrderedDict()
|
||||||
for field in [field for field in opts.fields if field.serialize and get_remote_field(field)]:
|
for field in [field for field in opts.fields if field.serialize and field.remote_field]:
|
||||||
forward_relations[field.name] = RelationInfo(
|
forward_relations[field.name] = RelationInfo(
|
||||||
model_field=field,
|
model_field=field,
|
||||||
related_model=get_related_model(field),
|
related_model=field.remote_field.model,
|
||||||
to_many=False,
|
to_many=False,
|
||||||
to_field=_get_to_field(field),
|
to_field=_get_to_field(field),
|
||||||
has_through_model=False,
|
has_through_model=False,
|
||||||
|
@ -90,12 +88,12 @@ def _get_forward_relationships(opts):
|
||||||
for field in [field for field in opts.many_to_many if field.serialize]:
|
for field in [field for field in opts.many_to_many if field.serialize]:
|
||||||
forward_relations[field.name] = RelationInfo(
|
forward_relations[field.name] = RelationInfo(
|
||||||
model_field=field,
|
model_field=field,
|
||||||
related_model=get_related_model(field),
|
related_model=field.remote_field.model,
|
||||||
to_many=True,
|
to_many=True,
|
||||||
# manytomany do not have to_fields
|
# manytomany do not have to_fields
|
||||||
to_field=None,
|
to_field=None,
|
||||||
has_through_model=(
|
has_through_model=(
|
||||||
not get_remote_field(field).through._meta.auto_created
|
not field.remote_field.through._meta.auto_created
|
||||||
),
|
),
|
||||||
reverse=False
|
reverse=False
|
||||||
)
|
)
|
||||||
|
@ -119,7 +117,7 @@ def _get_reverse_relationships(opts):
|
||||||
reverse_relations[accessor_name] = RelationInfo(
|
reverse_relations[accessor_name] = RelationInfo(
|
||||||
model_field=None,
|
model_field=None,
|
||||||
related_model=related,
|
related_model=related,
|
||||||
to_many=get_remote_field(relation.field).multiple,
|
to_many=relation.field.remote_field.multiple,
|
||||||
to_field=_get_to_field(relation.field),
|
to_field=_get_to_field(relation.field),
|
||||||
has_through_model=False,
|
has_through_model=False,
|
||||||
reverse=True
|
reverse=True
|
||||||
|
@ -137,8 +135,8 @@ def _get_reverse_relationships(opts):
|
||||||
# manytomany do not have to_fields
|
# manytomany do not have to_fields
|
||||||
to_field=None,
|
to_field=None,
|
||||||
has_through_model=(
|
has_through_model=(
|
||||||
(getattr(get_remote_field(relation.field), 'through', None) is not None) and
|
(getattr(relation.field.remote_field, 'through', None) is not None) and
|
||||||
not get_remote_field(relation.field).through._meta.auto_created
|
not relation.field.remote_field.through._meta.auto_created
|
||||||
),
|
),
|
||||||
reverse=True
|
reverse=True
|
||||||
)
|
)
|
||||||
|
|
|
@ -10,13 +10,18 @@ from django.db import models
|
||||||
from django.utils.encoding import force_text
|
from django.utils.encoding import force_text
|
||||||
from django.utils.functional import Promise
|
from django.utils.functional import Promise
|
||||||
|
|
||||||
from rest_framework.compat import get_names_and_managers, unicode_repr
|
from rest_framework.compat import unicode_repr
|
||||||
|
|
||||||
|
|
||||||
def manager_repr(value):
|
def manager_repr(value):
|
||||||
model = value.model
|
model = value.model
|
||||||
opts = model._meta
|
opts = model._meta
|
||||||
for manager_name, manager_instance in get_names_and_managers(opts):
|
names_and_managers = [
|
||||||
|
(manager.name, manager)
|
||||||
|
for manager
|
||||||
|
in opts.managers
|
||||||
|
]
|
||||||
|
for manager_name, manager_instance in names_and_managers:
|
||||||
if manager_instance == value:
|
if manager_instance == value:
|
||||||
return '%s.%s.all()' % (model._meta.object_name, manager_name)
|
return '%s.%s.all()' % (model._meta.object_name, manager_name)
|
||||||
return repr(value)
|
return repr(value)
|
||||||
|
|
|
@ -16,11 +16,11 @@ from rest_framework import (
|
||||||
HTTP_HEADER_ENCODING, exceptions, permissions, renderers, status
|
HTTP_HEADER_ENCODING, exceptions, permissions, renderers, status
|
||||||
)
|
)
|
||||||
from rest_framework.authentication import (
|
from rest_framework.authentication import (
|
||||||
BaseAuthentication, BasicAuthentication, RemoteUserAuthentication, SessionAuthentication,
|
BaseAuthentication, BasicAuthentication, RemoteUserAuthentication,
|
||||||
TokenAuthentication)
|
SessionAuthentication, TokenAuthentication
|
||||||
|
)
|
||||||
from rest_framework.authtoken.models import Token
|
from rest_framework.authtoken.models import Token
|
||||||
from rest_framework.authtoken.views import obtain_auth_token
|
from rest_framework.authtoken.views import obtain_auth_token
|
||||||
from rest_framework.compat import is_authenticated
|
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.test import APIClient, APIRequestFactory
|
from rest_framework.test import APIClient, APIRequestFactory
|
||||||
from rest_framework.views import APIView
|
from rest_framework.views import APIView
|
||||||
|
@ -450,7 +450,7 @@ class FailingAuthAccessedInRenderer(TestCase):
|
||||||
|
|
||||||
def render(self, data, media_type=None, renderer_context=None):
|
def render(self, data, media_type=None, renderer_context=None):
|
||||||
request = renderer_context['request']
|
request = renderer_context['request']
|
||||||
if is_authenticated(request.user):
|
if request.user.is_authenticated:
|
||||||
return b'authenticated'
|
return b'authenticated'
|
||||||
return b'not authenticated'
|
return b'not authenticated'
|
||||||
|
|
||||||
|
|
|
@ -22,20 +22,6 @@ class CompatTests(TestCase):
|
||||||
expected = (timedelta.days * 86400.0) + float(timedelta.seconds) + (timedelta.microseconds / 1000000.0)
|
expected = (timedelta.days * 86400.0) + float(timedelta.seconds) + (timedelta.microseconds / 1000000.0)
|
||||||
assert compat.total_seconds(timedelta) == expected
|
assert compat.total_seconds(timedelta) == expected
|
||||||
|
|
||||||
def test_get_remote_field_with_old_django_version(self):
|
|
||||||
class MockField(object):
|
|
||||||
rel = 'example_rel'
|
|
||||||
compat.django.VERSION = (1, 8)
|
|
||||||
assert compat.get_remote_field(MockField(), default='default_value') == 'example_rel'
|
|
||||||
assert compat.get_remote_field(object(), default='default_value') == 'default_value'
|
|
||||||
|
|
||||||
def test_get_remote_field_with_new_django_version(self):
|
|
||||||
class MockField(object):
|
|
||||||
remote_field = 'example_remote_field'
|
|
||||||
compat.django.VERSION = (1, 10)
|
|
||||||
assert compat.get_remote_field(MockField(), default='default_value') == 'example_remote_field'
|
|
||||||
assert compat.get_remote_field(object(), default='default_value') == 'default_value'
|
|
||||||
|
|
||||||
def test_set_rollback_for_transaction_in_managed_mode(self):
|
def test_set_rollback_for_transaction_in_managed_mode(self):
|
||||||
class MockTransaction(object):
|
class MockTransaction(object):
|
||||||
called_rollback = False
|
called_rollback = False
|
||||||
|
|
|
@ -5,7 +5,7 @@ import pytest
|
||||||
from django.conf.urls import url
|
from django.conf.urls import url
|
||||||
from django.core.exceptions import ImproperlyConfigured, PermissionDenied
|
from django.core.exceptions import ImproperlyConfigured, PermissionDenied
|
||||||
from django.http import Http404
|
from django.http import Http404
|
||||||
from django.template import Template, TemplateDoesNotExist
|
from django.template import TemplateDoesNotExist, engines
|
||||||
from django.test import TestCase, override_settings
|
from django.test import TestCase, override_settings
|
||||||
from django.utils import six
|
from django.utils import six
|
||||||
|
|
||||||
|
@ -60,12 +60,12 @@ class TemplateHTMLRendererTests(TestCase):
|
||||||
|
|
||||||
def get_template(template_name, dirs=None):
|
def get_template(template_name, dirs=None):
|
||||||
if template_name == 'example.html':
|
if template_name == 'example.html':
|
||||||
return Template("example: {{ object }}")
|
return engines['django'].from_string("example: {{ object }}")
|
||||||
raise TemplateDoesNotExist(template_name)
|
raise TemplateDoesNotExist(template_name)
|
||||||
|
|
||||||
def select_template(template_name_list, dirs=None, using=None):
|
def select_template(template_name_list, dirs=None, using=None):
|
||||||
if template_name_list == ['example.html']:
|
if template_name_list == ['example.html']:
|
||||||
return Template("example: {{ object }}")
|
return engines['django'].from_string("example: {{ object }}")
|
||||||
raise TemplateDoesNotExist(template_name_list[0])
|
raise TemplateDoesNotExist(template_name_list[0])
|
||||||
|
|
||||||
django.template.loader.get_template = get_template
|
django.template.loader.get_template = get_template
|
||||||
|
@ -139,9 +139,9 @@ class TemplateHTMLRendererExceptionTests(TestCase):
|
||||||
|
|
||||||
def get_template(template_name):
|
def get_template(template_name):
|
||||||
if template_name == '404.html':
|
if template_name == '404.html':
|
||||||
return Template("404: {{ detail }}")
|
return engines['django'].from_string("404: {{ detail }}")
|
||||||
if template_name == '403.html':
|
if template_name == '403.html':
|
||||||
return Template("403: {{ detail }}")
|
return engines['django'].from_string("403: {{ detail }}")
|
||||||
raise TemplateDoesNotExist(template_name)
|
raise TemplateDoesNotExist(template_name)
|
||||||
|
|
||||||
django.template.loader.get_template = get_template
|
django.template.loader.get_template = get_template
|
||||||
|
|
|
@ -21,7 +21,7 @@ from django.test import TestCase
|
||||||
from django.utils import six
|
from django.utils import six
|
||||||
|
|
||||||
from rest_framework import serializers
|
from rest_framework import serializers
|
||||||
from rest_framework.compat import set_many, unicode_repr
|
from rest_framework.compat import unicode_repr
|
||||||
|
|
||||||
|
|
||||||
def dedent(blocktext):
|
def dedent(blocktext):
|
||||||
|
@ -703,8 +703,7 @@ class TestIntegration(TestCase):
|
||||||
foreign_key=self.foreign_key_target,
|
foreign_key=self.foreign_key_target,
|
||||||
one_to_one=self.one_to_one_target,
|
one_to_one=self.one_to_one_target,
|
||||||
)
|
)
|
||||||
set_many(self.instance, 'many_to_many', self.many_to_many_targets)
|
self.instance.many_to_many.set(self.many_to_many_targets)
|
||||||
self.instance.save()
|
|
||||||
|
|
||||||
def test_pk_retrival(self):
|
def test_pk_retrival(self):
|
||||||
class TestSerializer(serializers.ModelSerializer):
|
class TestSerializer(serializers.ModelSerializer):
|
||||||
|
|
|
@ -5,8 +5,6 @@ from django.test import TestCase
|
||||||
|
|
||||||
from rest_framework import serializers
|
from rest_framework import serializers
|
||||||
from tests.models import RESTFrameworkModel
|
from tests.models import RESTFrameworkModel
|
||||||
|
|
||||||
|
|
||||||
# Models
|
# Models
|
||||||
from tests.test_multitable_inheritance import ChildModel
|
from tests.test_multitable_inheritance import ChildModel
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,7 @@ from rest_framework import (
|
||||||
HTTP_HEADER_ENCODING, authentication, generics, permissions, serializers,
|
HTTP_HEADER_ENCODING, authentication, generics, permissions, serializers,
|
||||||
status, views
|
status, views
|
||||||
)
|
)
|
||||||
from rest_framework.compat import ResolverMatch, guardian, set_many
|
from rest_framework.compat import ResolverMatch, guardian
|
||||||
from rest_framework.filters import DjangoObjectPermissionsFilter
|
from rest_framework.filters import DjangoObjectPermissionsFilter
|
||||||
from rest_framework.routers import DefaultRouter
|
from rest_framework.routers import DefaultRouter
|
||||||
from rest_framework.test import APIRequestFactory
|
from rest_framework.test import APIRequestFactory
|
||||||
|
@ -73,13 +73,14 @@ class ModelPermissionsIntegrationTests(TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
User.objects.create_user('disallowed', 'disallowed@example.com', 'password')
|
User.objects.create_user('disallowed', 'disallowed@example.com', 'password')
|
||||||
user = User.objects.create_user('permitted', 'permitted@example.com', 'password')
|
user = User.objects.create_user('permitted', 'permitted@example.com', 'password')
|
||||||
set_many(user, 'user_permissions', [
|
user.user_permissions.set([
|
||||||
Permission.objects.get(codename='add_basicmodel'),
|
Permission.objects.get(codename='add_basicmodel'),
|
||||||
Permission.objects.get(codename='change_basicmodel'),
|
Permission.objects.get(codename='change_basicmodel'),
|
||||||
Permission.objects.get(codename='delete_basicmodel')
|
Permission.objects.get(codename='delete_basicmodel')
|
||||||
])
|
])
|
||||||
|
|
||||||
user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password')
|
user = User.objects.create_user('updateonly', 'updateonly@example.com', 'password')
|
||||||
set_many(user, 'user_permissions', [
|
user.user_permissions.set([
|
||||||
Permission.objects.get(codename='change_basicmodel'),
|
Permission.objects.get(codename='change_basicmodel'),
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,6 @@ from django.contrib.auth.models import Group, User
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
|
|
||||||
from rest_framework import generics, serializers
|
from rest_framework import generics, serializers
|
||||||
from rest_framework.compat import set_many
|
|
||||||
from rest_framework.test import APIRequestFactory
|
from rest_framework.test import APIRequestFactory
|
||||||
|
|
||||||
factory = APIRequestFactory()
|
factory = APIRequestFactory()
|
||||||
|
@ -23,8 +22,7 @@ class TestPrefetchRelatedUpdates(TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.user = User.objects.create(username='tom', email='tom@example.com')
|
self.user = User.objects.create(username='tom', email='tom@example.com')
|
||||||
self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')]
|
self.groups = [Group.objects.create(name='a'), Group.objects.create(name='b')]
|
||||||
set_many(self.user, 'groups', self.groups)
|
self.user.groups.set(self.groups)
|
||||||
self.user.save()
|
|
||||||
|
|
||||||
def test_prefetch_related_updates(self):
|
def test_prefetch_related_updates(self):
|
||||||
view = UserUpdate.as_view()
|
view = UserUpdate.as_view()
|
||||||
|
|
|
@ -6,8 +6,9 @@ from django.utils import six
|
||||||
from rest_framework import serializers
|
from rest_framework import serializers
|
||||||
from tests.models import (
|
from tests.models import (
|
||||||
ForeignKeySource, ForeignKeyTarget, ManyToManySource, ManyToManyTarget,
|
ForeignKeySource, ForeignKeyTarget, ManyToManySource, ManyToManyTarget,
|
||||||
NullableForeignKeySource, NullableOneToOneSource, NullableUUIDForeignKeySource,
|
NullableForeignKeySource, NullableOneToOneSource,
|
||||||
OneToOnePKSource, OneToOneTarget, UUIDForeignKeyTarget
|
NullableUUIDForeignKeySource, OneToOnePKSource, OneToOneTarget,
|
||||||
|
UUIDForeignKeyTarget
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,6 @@ from django.utils import six
|
||||||
|
|
||||||
from rest_framework import status
|
from rest_framework import status
|
||||||
from rest_framework.authentication import SessionAuthentication
|
from rest_framework.authentication import SessionAuthentication
|
||||||
from rest_framework.compat import is_anonymous
|
|
||||||
from rest_framework.parsers import BaseParser, FormParser, MultiPartParser
|
from rest_framework.parsers import BaseParser, FormParser, MultiPartParser
|
||||||
from rest_framework.request import Request
|
from rest_framework.request import Request
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
|
@ -201,9 +200,9 @@ class TestUserSetter(TestCase):
|
||||||
|
|
||||||
def test_user_can_logout(self):
|
def test_user_can_logout(self):
|
||||||
self.request.user = self.user
|
self.request.user = self.user
|
||||||
self.assertFalse(is_anonymous(self.request.user))
|
self.assertFalse(self.request.user.is_anonymous)
|
||||||
logout(self.request)
|
logout(self.request)
|
||||||
self.assertTrue(is_anonymous(self.request.user))
|
self.assertTrue(self.request.user.is_anonymous)
|
||||||
|
|
||||||
def test_logged_in_user_is_set_on_wrapped_request(self):
|
def test_logged_in_user_is_set_on_wrapped_request(self):
|
||||||
login(self.request, self.user)
|
login(self.request, self.user)
|
||||||
|
|
|
@ -10,7 +10,7 @@ from django.test import override_settings
|
||||||
from django.utils.decorators import method_decorator
|
from django.utils.decorators import method_decorator
|
||||||
from django.views.decorators.csrf import csrf_protect, ensure_csrf_cookie
|
from django.views.decorators.csrf import csrf_protect, ensure_csrf_cookie
|
||||||
|
|
||||||
from rest_framework.compat import is_authenticated, requests
|
from rest_framework.compat import requests
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.test import APITestCase, RequestsClient
|
from rest_framework.test import APITestCase, RequestsClient
|
||||||
from rest_framework.views import APIView
|
from rest_framework.views import APIView
|
||||||
|
@ -72,7 +72,7 @@ class SessionView(APIView):
|
||||||
class AuthView(APIView):
|
class AuthView(APIView):
|
||||||
@method_decorator(ensure_csrf_cookie)
|
@method_decorator(ensure_csrf_cookie)
|
||||||
def get(self, request):
|
def get(self, request):
|
||||||
if is_authenticated(request.user):
|
if request.user.is_authenticated:
|
||||||
username = request.user.username
|
username = request.user.username
|
||||||
else:
|
else:
|
||||||
username = None
|
username = None
|
||||||
|
|
|
@ -3,13 +3,12 @@ from __future__ import unicode_literals
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from django.conf.urls import url
|
from django.conf.urls import include, url
|
||||||
from django.core.exceptions import ImproperlyConfigured
|
from django.core.exceptions import ImproperlyConfigured
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.test import TestCase, override_settings
|
from django.test import TestCase, override_settings
|
||||||
|
|
||||||
from rest_framework import permissions, serializers, viewsets
|
from rest_framework import permissions, serializers, viewsets
|
||||||
from rest_framework.compat import include
|
|
||||||
from rest_framework.decorators import detail_route, list_route
|
from rest_framework.decorators import detail_route, list_route
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
from rest_framework.routers import DefaultRouter, SimpleRouter
|
from rest_framework.routers import DefaultRouter, SimpleRouter
|
||||||
|
|
|
@ -1,9 +1,8 @@
|
||||||
import pytest
|
import pytest
|
||||||
from django.conf.urls import url
|
from django.conf.urls import include, url
|
||||||
from django.test import override_settings
|
from django.test import override_settings
|
||||||
|
|
||||||
from rest_framework import serializers, status, versioning
|
from rest_framework import serializers, status, versioning
|
||||||
from rest_framework.compat import include
|
|
||||||
from rest_framework.decorators import APIView
|
from rest_framework.decorators import APIView
|
||||||
from rest_framework.relations import PKOnlyObject
|
from rest_framework.relations import PKOnlyObject
|
||||||
from rest_framework.response import Response
|
from rest_framework.response import Response
|
||||||
|
|
|
@ -4,6 +4,7 @@ URLConf for test suite.
|
||||||
We need only the docs urls for DocumentationRenderer tests.
|
We need only the docs urls for DocumentationRenderer tests.
|
||||||
"""
|
"""
|
||||||
from django.conf.urls import url
|
from django.conf.urls import url
|
||||||
|
|
||||||
from rest_framework.documentation import include_docs_urls
|
from rest_framework.documentation import include_docs_urls
|
||||||
|
|
||||||
urlpatterns = [
|
urlpatterns = [
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
from django.core.exceptions import ObjectDoesNotExist
|
from django.core.exceptions import ObjectDoesNotExist
|
||||||
|
|
||||||
from rest_framework.compat import NoReverseMatch
|
from rest_framework.compat import NoReverseMatch
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user