black/flake8/isort fix

This commit is contained in:
Rizwan Mansuri 2019-04-13 12:16:48 +01:00
parent f2eacd3660
commit f418b4e9ca
134 changed files with 10876 additions and 8988 deletions

View File

@ -4,7 +4,7 @@ flake8-tidy-imports==1.1.0
pycodestyle==2.3.1
# Sort and lint imports
isort==4.3.3
isort==4.3.17
# black
black==19.3b0

View File

@ -7,22 +7,22 @@ ______ _____ _____ _____ __
\_| \_\____/\____/ \_/ |_| |_| \__,_|_| |_| |_|\___| \_/\_/ \___/|_| |_|\_|
"""
__title__ = 'Django REST framework'
__version__ = '3.9.2'
__author__ = 'Tom Christie'
__license__ = 'BSD 2-Clause'
__copyright__ = 'Copyright 2011-2019 Encode OSS Ltd'
__title__ = "Django REST framework"
__version__ = "3.9.2"
__author__ = "Tom Christie"
__license__ = "BSD 2-Clause"
__copyright__ = "Copyright 2011-2019 Encode OSS Ltd"
# Version synonym
VERSION = __version__
# Header encoding (see RFC5987)
HTTP_HEADER_ENCODING = 'iso-8859-1'
HTTP_HEADER_ENCODING = "iso-8859-1"
# Default datetime input and output formats
ISO_8601 = 'iso-8601'
ISO_8601 = "iso-8601"
default_app_config = 'rest_framework.apps.RestFrameworkConfig'
default_app_config = "rest_framework.apps.RestFrameworkConfig"
class RemovedInDRF310Warning(DeprecationWarning):

View File

@ -2,7 +2,7 @@ from django.apps import AppConfig
class RestFrameworkConfig(AppConfig):
name = 'rest_framework'
name = "rest_framework"
verbose_name = "Django REST framework"
def ready(self):

View File

@ -20,7 +20,7 @@ def get_authorization_header(request):
Hide some test client ickyness where the header can be unicode.
"""
auth = request.META.get('HTTP_AUTHORIZATION', b'')
auth = request.META.get("HTTP_AUTHORIZATION", b"")
if isinstance(auth, text_type):
# Work around django test client oddness
auth = auth.encode(HTTP_HEADER_ENCODING)
@ -57,7 +57,8 @@ class BasicAuthentication(BaseAuthentication):
"""
HTTP Basic authentication against username/password.
"""
www_authenticate_realm = 'api'
www_authenticate_realm = "api"
def authenticate(self, request):
"""
@ -66,20 +67,24 @@ class BasicAuthentication(BaseAuthentication):
"""
auth = get_authorization_header(request).split()
if not auth or auth[0].lower() != b'basic':
if not auth or auth[0].lower() != b"basic":
return None
if len(auth) == 1:
msg = _('Invalid basic header. No credentials provided.')
msg = _("Invalid basic header. No credentials provided.")
raise exceptions.AuthenticationFailed(msg)
elif len(auth) > 2:
msg = _('Invalid basic header. Credentials string should not contain spaces.')
msg = _(
"Invalid basic header. Credentials string should not contain spaces."
)
raise exceptions.AuthenticationFailed(msg)
try:
auth_parts = base64.b64decode(auth[1]).decode(HTTP_HEADER_ENCODING).partition(':')
auth_parts = (
base64.b64decode(auth[1]).decode(HTTP_HEADER_ENCODING).partition(":")
)
except (TypeError, UnicodeDecodeError, binascii.Error):
msg = _('Invalid basic header. Credentials not correctly base64 encoded.')
msg = _("Invalid basic header. Credentials not correctly base64 encoded.")
raise exceptions.AuthenticationFailed(msg)
userid, password = auth_parts[0], auth_parts[2]
@ -90,17 +95,14 @@ class BasicAuthentication(BaseAuthentication):
Authenticate the userid and password against username and password
with optional request for context.
"""
credentials = {
get_user_model().USERNAME_FIELD: userid,
'password': password
}
credentials = {get_user_model().USERNAME_FIELD: userid, "password": password}
user = authenticate(request=request, **credentials)
if user is None:
raise exceptions.AuthenticationFailed(_('Invalid username/password.'))
raise exceptions.AuthenticationFailed(_("Invalid username/password."))
if not user.is_active:
raise exceptions.AuthenticationFailed(_('User inactive or deleted.'))
raise exceptions.AuthenticationFailed(_("User inactive or deleted."))
return (user, None)
@ -120,7 +122,7 @@ class SessionAuthentication(BaseAuthentication):
"""
# Get the session-based user from the underlying HttpRequest object
user = getattr(request._request, 'user', None)
user = getattr(request._request, "user", None)
# Unauthenticated, CSRF validation not required
if not user or not user.is_active:
@ -141,7 +143,7 @@ class SessionAuthentication(BaseAuthentication):
reason = check.process_view(request, None, (), {})
if reason:
# CSRF failed, bail with explicit error message
raise exceptions.PermissionDenied('CSRF Failed: %s' % reason)
raise exceptions.PermissionDenied("CSRF Failed: %s" % reason)
class TokenAuthentication(BaseAuthentication):
@ -154,13 +156,14 @@ class TokenAuthentication(BaseAuthentication):
Authorization: Token 401f7ac837da42b97f613d789819ff93537bee6a
"""
keyword = 'Token'
keyword = "Token"
model = None
def get_model(self):
if self.model is not None:
return self.model
from rest_framework.authtoken.models import Token
return Token
"""
@ -177,16 +180,18 @@ class TokenAuthentication(BaseAuthentication):
return None
if len(auth) == 1:
msg = _('Invalid token header. No credentials provided.')
msg = _("Invalid token header. No credentials provided.")
raise exceptions.AuthenticationFailed(msg)
elif len(auth) > 2:
msg = _('Invalid token header. Token string should not contain spaces.')
msg = _("Invalid token header. Token string should not contain spaces.")
raise exceptions.AuthenticationFailed(msg)
try:
token = auth[1].decode()
except UnicodeError:
msg = _('Invalid token header. Token string should not contain invalid characters.')
msg = _(
"Invalid token header. Token string should not contain invalid characters."
)
raise exceptions.AuthenticationFailed(msg)
return self.authenticate_credentials(token)
@ -194,12 +199,12 @@ class TokenAuthentication(BaseAuthentication):
def authenticate_credentials(self, key):
model = self.get_model()
try:
token = model.objects.select_related('user').get(key=key)
token = model.objects.select_related("user").get(key=key)
except model.DoesNotExist:
raise exceptions.AuthenticationFailed(_('Invalid token.'))
raise exceptions.AuthenticationFailed(_("Invalid token."))
if not token.user.is_active:
raise exceptions.AuthenticationFailed(_('User inactive or deleted.'))
raise exceptions.AuthenticationFailed(_("User inactive or deleted."))
return (token.user, token)

View File

@ -1 +1 @@
default_app_config = 'rest_framework.authtoken.apps.AuthTokenConfig'
default_app_config = "rest_framework.authtoken.apps.AuthTokenConfig"

View File

@ -4,9 +4,9 @@ from rest_framework.authtoken.models import Token
class TokenAdmin(admin.ModelAdmin):
list_display = ('key', 'user', 'created')
fields = ('user',)
ordering = ('-created',)
list_display = ("key", "user", "created")
fields = ("user",)
ordering = ("-created",)
admin.site.register(Token, TokenAdmin)

View File

@ -3,5 +3,5 @@ from django.utils.translation import ugettext_lazy as _
class AuthTokenConfig(AppConfig):
name = 'rest_framework.authtoken'
name = "rest_framework.authtoken"
verbose_name = _("Auth Token")

View File

@ -3,11 +3,12 @@ from django.core.management.base import BaseCommand, CommandError
from rest_framework.authtoken.models import Token
UserModel = get_user_model()
class Command(BaseCommand):
help = 'Create DRF Token for a given user'
help = "Create DRF Token for a given user"
def create_user_token(self, username, reset_token):
user = UserModel._default_manager.get_by_natural_key(username)
@ -19,27 +20,27 @@ class Command(BaseCommand):
return token[0]
def add_arguments(self, parser):
parser.add_argument('username', type=str)
parser.add_argument("username", type=str)
parser.add_argument(
'-r',
'--reset',
action='store_true',
dest='reset_token',
"-r",
"--reset",
action="store_true",
dest="reset_token",
default=False,
help='Reset existing User token and create a new one',
help="Reset existing User token and create a new one",
)
def handle(self, *args, **options):
username = options['username']
reset_token = options['reset_token']
username = options["username"]
reset_token = options["reset_token"]
try:
token = self.create_user_token(username, reset_token)
except UserModel.DoesNotExist:
raise CommandError(
'Cannot create the Token: user {0} does not exist'.format(
username)
"Cannot create the Token: user {0} does not exist".format(username)
)
self.stdout.write(
'Generated token {0} for user {1}'.format(token.key, username))
"Generated token {0} for user {1}".format(token.key, username)
)

View File

@ -7,20 +7,27 @@ from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
dependencies = [migrations.swappable_dependency(settings.AUTH_USER_MODEL)]
operations = [
migrations.CreateModel(
name='Token',
name="Token",
fields=[
('key', models.CharField(primary_key=True, serialize=False, max_length=40)),
('created', models.DateTimeField(auto_now_add=True)),
('user', models.OneToOneField(to=settings.AUTH_USER_MODEL, related_name='auth_token', on_delete=models.CASCADE)),
(
"key",
models.CharField(primary_key=True, serialize=False, max_length=40),
),
("created", models.DateTimeField(auto_now_add=True)),
(
"user",
models.OneToOneField(
to=settings.AUTH_USER_MODEL,
related_name="auth_token",
on_delete=models.CASCADE,
),
),
],
options={
},
options={},
bases=(models.Model,),
),
)
]

View File

@ -7,28 +7,33 @@ from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
('authtoken', '0001_initial'),
]
dependencies = [("authtoken", "0001_initial")]
operations = [
migrations.AlterModelOptions(
name='token',
options={'verbose_name_plural': 'Tokens', 'verbose_name': 'Token'},
name="token",
options={"verbose_name_plural": "Tokens", "verbose_name": "Token"},
),
migrations.AlterField(
model_name='token',
name='created',
field=models.DateTimeField(verbose_name='Created', auto_now_add=True),
model_name="token",
name="created",
field=models.DateTimeField(verbose_name="Created", auto_now_add=True),
),
migrations.AlterField(
model_name='token',
name='key',
field=models.CharField(verbose_name='Key', max_length=40, primary_key=True, serialize=False),
model_name="token",
name="key",
field=models.CharField(
verbose_name="Key", max_length=40, primary_key=True, serialize=False
),
),
migrations.AlterField(
model_name='token',
name='user',
field=models.OneToOneField(to=settings.AUTH_USER_MODEL, verbose_name='User', related_name='auth_token', on_delete=models.CASCADE),
model_name="token",
name="user",
field=models.OneToOneField(
to=settings.AUTH_USER_MODEL,
verbose_name="User",
related_name="auth_token",
on_delete=models.CASCADE,
),
),
]

View File

@ -12,10 +12,13 @@ class Token(models.Model):
"""
The default authorization token model.
"""
key = models.CharField(_("Key"), max_length=40, primary_key=True)
user = models.OneToOneField(
settings.AUTH_USER_MODEL, related_name='auth_token',
on_delete=models.CASCADE, verbose_name=_("User")
settings.AUTH_USER_MODEL,
related_name="auth_token",
on_delete=models.CASCADE,
verbose_name=_("User"),
)
created = models.DateTimeField(_("Created"), auto_now_add=True)
@ -25,7 +28,7 @@ class Token(models.Model):
#
# Also see corresponding ticket:
# https://github.com/encode/django-rest-framework/issues/705
abstract = 'rest_framework.authtoken' not in settings.INSTALLED_APPS
abstract = "rest_framework.authtoken" not in settings.INSTALLED_APPS
verbose_name = _("Token")
verbose_name_plural = _("Tokens")

View File

@ -7,28 +7,29 @@ from rest_framework import serializers
class AuthTokenSerializer(serializers.Serializer):
username = serializers.CharField(label=_("Username"))
password = serializers.CharField(
label=_("Password"),
style={'input_type': 'password'},
trim_whitespace=False
label=_("Password"), style={"input_type": "password"}, trim_whitespace=False
)
def validate(self, attrs):
username = attrs.get('username')
password = attrs.get('password')
username = attrs.get("username")
password = attrs.get("password")
if username and password:
user = authenticate(request=self.context.get('request'),
username=username, password=password)
user = authenticate(
request=self.context.get("request"),
username=username,
password=password,
)
# The authenticate call simply returns None for is_active=False
# users. (Assuming the default ModelBackend authentication
# backend.)
if not user:
msg = _('Unable to log in with provided credentials.')
raise serializers.ValidationError(msg, code='authorization')
msg = _("Unable to log in with provided credentials.")
raise serializers.ValidationError(msg, code="authorization")
else:
msg = _('Must include "username" and "password".')
raise serializers.ValidationError(msg, code='authorization')
raise serializers.ValidationError(msg, code="authorization")
attrs['user'] = user
attrs["user"] = user
return attrs

View File

@ -10,7 +10,7 @@ from rest_framework.views import APIView
class ObtainAuthToken(APIView):
throttle_classes = ()
permission_classes = ()
parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser,)
parser_classes = (parsers.FormParser, parsers.MultiPartParser, parsers.JSONParser)
renderer_classes = (renderers.JSONRenderer,)
serializer_class = AuthTokenSerializer
if coreapi is not None and coreschema is not None:
@ -19,7 +19,7 @@ class ObtainAuthToken(APIView):
coreapi.Field(
name="username",
required=True,
location='form',
location="form",
schema=coreschema.String(
title="Username",
description="Valid username for authentication",
@ -28,7 +28,7 @@ class ObtainAuthToken(APIView):
coreapi.Field(
name="password",
required=True,
location='form',
location="form",
schema=coreschema.String(
title="Password",
description="Valid password for authentication",
@ -39,12 +39,13 @@ class ObtainAuthToken(APIView):
)
def post(self, request, *args, **kwargs):
serializer = self.serializer_class(data=request.data,
context={'request': request})
serializer = self.serializer_class(
data=request.data, context={"request": request}
)
serializer.is_valid(raise_exception=True)
user = serializer.validated_data['user']
user = serializer.validated_data["user"]
token, created = Token.objects.get_or_create(user=user)
return Response({'token': token.key})
return Response({"token": token.key})
obtain_auth_token = ObtainAuthToken.as_view()

View File

@ -6,16 +6,17 @@ def pagination_system_check(app_configs, **kwargs):
errors = []
# Use of default page size setting requires a default Paginator class
from rest_framework.settings import api_settings
if api_settings.PAGE_SIZE and not api_settings.DEFAULT_PAGINATION_CLASS:
errors.append(
Warning(
"You have specified a default PAGE_SIZE pagination rest_framework setting,"
"without specifying also a DEFAULT_PAGINATION_CLASS.",
hint="The default for DEFAULT_PAGINATION_CLASS is None. "
"In previous versions this was PageNumberPagination. "
"If you wish to define PAGE_SIZE globally whilst defining "
"pagination_class on a per-view basis you may silence this check.",
id="rest_framework.W001"
"In previous versions this was PageNumberPagination. "
"If you wish to define PAGE_SIZE globally whilst defining "
"pagination_class on a per-view basis you may silence this check.",
id="rest_framework.W001",
)
)
return errors

View File

@ -12,18 +12,16 @@ from django.core import validators
from django.utils import six
from django.views.generic import View
try:
# Python 3
from collections.abc import Mapping, MutableMapping # noqa
except ImportError:
# Python 2.7
from collections import Mapping, MutableMapping # noqa
try:
from django.urls import ( # noqa
URLPattern,
URLResolver,
)
# Python 3
from collections.abc import Mapping, MutableMapping # noqa
except ImportError:
# Python 2.7
from collections import Mapping, MutableMapping # noqa
try:
from django.urls import URLPattern, URLResolver # noqa
except ImportError:
# Will be removed in Django 2.0
from django.urls import ( # noqa
@ -47,7 +45,7 @@ def get_original_route(urlpattern):
Get the original route/regex that was typed in by the user into the path(), re_path() or url() directive. This
is in contrast with get_regex_pattern below, which for RoutePattern returns the raw regex generated from the path().
"""
if hasattr(urlpattern, 'pattern'):
if hasattr(urlpattern, "pattern"):
# Django 2.0
return str(urlpattern.pattern)
else:
@ -60,7 +58,7 @@ def get_regex_pattern(urlpattern):
Get the raw regex out of the urlpattern's RegexPattern or RoutePattern. This is always a regular expression,
unlike get_original_route above.
"""
if hasattr(urlpattern, 'pattern'):
if hasattr(urlpattern, "pattern"):
# Django 2.0
return urlpattern.pattern.regex.pattern
else:
@ -69,9 +67,10 @@ def get_regex_pattern(urlpattern):
def is_route_pattern(urlpattern):
if hasattr(urlpattern, 'pattern'):
if hasattr(urlpattern, "pattern"):
# Django 2.0
from django.urls.resolvers import RoutePattern
return isinstance(urlpattern.pattern, RoutePattern)
else:
# Django < 2.0
@ -82,6 +81,7 @@ def make_url_resolver(regex, urlpatterns):
try:
# Django 2.0
from django.urls.resolvers import RegexPattern
return URLResolver(RegexPattern(regex), urlpatterns)
except ImportError:
@ -93,7 +93,7 @@ def unicode_repr(instance):
# Get the repr of an instance, but ensure it is a unicode string
# on both python 3 (already the case) and 2 (not the case).
if six.PY2:
return repr(instance).decode('utf-8')
return repr(instance).decode("utf-8")
return repr(instance)
@ -102,21 +102,21 @@ def unicode_to_repr(value):
# the Python version. We wrap all our `__repr__` implementations with
# this and then use unicode throughout internally.
if six.PY2:
return value.encode('utf-8')
return value.encode("utf-8")
return value
def unicode_http_header(value):
# Coerce HTTP header value to unicode.
if isinstance(value, bytes):
return value.decode('iso-8859-1')
return value.decode("iso-8859-1")
return value
def distinct(queryset, base):
if settings.DATABASES[queryset.db]["ENGINE"] == "django.db.backends.oracle":
# distinct analogue for Oracle users
return base.filter(pk__in=set(queryset.values_list('pk', flat=True)))
return base.filter(pk__in=set(queryset.values_list("pk", flat=True)))
return queryset.distinct()
@ -172,27 +172,27 @@ def is_guardian_installed():
# Guardian 1.5.0, for Django 2.2 is NOT compatible with Python 2.7.
# Remove when dropping PY2.
return False
return 'guardian' in settings.INSTALLED_APPS
return "guardian" in settings.INSTALLED_APPS
# PATCH method is not implemented by Django
if 'patch' not in View.http_method_names:
View.http_method_names = View.http_method_names + ['patch']
if "patch" not in View.http_method_names:
View.http_method_names = View.http_method_names + ["patch"]
# Markdown is optional
try:
import markdown
if markdown.version <= '2.2':
HEADERID_EXT_PATH = 'headerid'
LEVEL_PARAM = 'level'
elif markdown.version < '2.6':
HEADERID_EXT_PATH = 'markdown.extensions.headerid'
LEVEL_PARAM = 'level'
if markdown.version <= "2.2":
HEADERID_EXT_PATH = "headerid"
LEVEL_PARAM = "level"
elif markdown.version < "2.6":
HEADERID_EXT_PATH = "markdown.extensions.headerid"
LEVEL_PARAM = "level"
else:
HEADERID_EXT_PATH = 'markdown.extensions.toc'
LEVEL_PARAM = 'baselevel'
HEADERID_EXT_PATH = "markdown.extensions.toc"
LEVEL_PARAM = "baselevel"
def apply_markdown(text):
"""
@ -200,16 +200,14 @@ try:
of '#' style headers to <h2>.
"""
extensions = [HEADERID_EXT_PATH]
extension_configs = {
HEADERID_EXT_PATH: {
LEVEL_PARAM: '2'
}
}
extension_configs = {HEADERID_EXT_PATH: {LEVEL_PARAM: "2"}}
md = markdown.Markdown(
extensions=extensions, extension_configs=extension_configs
)
md_filter_add_syntax_highlight(md)
return md.convert(text)
except ImportError:
apply_markdown = None
markdown = None
@ -227,7 +225,8 @@ try:
def pygments_css(style):
formatter = HtmlFormatter(style=style)
return formatter.get_style_defs('.highlight')
return formatter.get_style_defs(".highlight")
except ImportError:
pygments = None
@ -238,6 +237,7 @@ except ImportError:
def pygments_css(style):
return None
if markdown is not None and pygments is not None:
# starting from this blogpost and modified to support current markdown extensions API
# https://zerokspot.com/weblog/2008/06/18/syntax-highlighting-in-markdown-with-pygments/
@ -246,8 +246,7 @@ if markdown is not None and pygments is not None:
import re
class CodeBlockPreprocessor(Preprocessor):
pattern = re.compile(
r'^\s*``` *([^\n]+)\n(.+?)^\s*```', re.M | re.S)
pattern = re.compile(r"^\s*``` *([^\n]+)\n(.+?)^\s*```", re.M | re.S)
formatter = HtmlFormatter()
@ -257,17 +256,25 @@ if markdown is not None and pygments is not None:
lexer = get_lexer_by_name(m.group(1))
except (ValueError, NameError):
lexer = TextLexer()
code = m.group(2).replace('\t', ' ')
code = m.group(2).replace("\t", " ")
code = pygments.highlight(code, lexer, self.formatter)
code = code.replace('\n\n', '\n&nbsp;\n').replace('\n', '<br />').replace('\\@', '@')
return '\n\n%s\n\n' % code
code = (
code.replace("\n\n", "\n&nbsp;\n")
.replace("\n", "<br />")
.replace("\\@", "@")
)
return "\n\n%s\n\n" % code
ret = self.pattern.sub(repl, "\n".join(lines))
return ret.split("\n")
def md_filter_add_syntax_highlight(md):
md.preprocessors.add('highlight', CodeBlockPreprocessor(), "_begin")
md.preprocessors.add("highlight", CodeBlockPreprocessor(), "_begin")
return True
else:
def md_filter_add_syntax_highlight(md):
return False
@ -276,7 +283,8 @@ else:
try:
from django.urls import include, path, re_path, register_converter # noqa
except ImportError:
from django.conf.urls import include, url # noqa
from django.conf.urls import include, url # noqa
path = None
register_converter = None
re_path = url
@ -285,13 +293,13 @@ except ImportError:
# `separators` argument to `json.dumps()` differs between 2.x and 3.x
# See: https://bugs.python.org/issue22767
if six.PY3:
SHORT_SEPARATORS = (',', ':')
LONG_SEPARATORS = (', ', ': ')
INDENT_SEPARATORS = (',', ': ')
SHORT_SEPARATORS = (",", ":")
LONG_SEPARATORS = (", ", ": ")
INDENT_SEPARATORS = (",", ": ")
else:
SHORT_SEPARATORS = (b',', b':')
LONG_SEPARATORS = (b', ', b': ')
INDENT_SEPARATORS = (b',', b': ')
SHORT_SEPARATORS = (b",", b":")
LONG_SEPARATORS = (b", ", b": ")
INDENT_SEPARATORS = (b",", b": ")
class CustomValidatorMessage(object):
@ -303,7 +311,7 @@ class CustomValidatorMessage(object):
"""
def __init__(self, *args, **kwargs):
self.message = kwargs.pop('message', self.message)
self.message = kwargs.pop("message", self.message)
super(CustomValidatorMessage, self).__init__(*args, **kwargs)

View File

@ -23,14 +23,14 @@ def api_view(http_method_names=None):
Decorator that converts a function-based view into an APIView subclass.
Takes a list of allowed methods for the view as an argument.
"""
http_method_names = ['GET'] if (http_method_names is None) else http_method_names
http_method_names = ["GET"] if (http_method_names is None) else http_method_names
def decorator(func):
WrappedAPIView = type(
six.PY3 and 'WrappedAPIView' or b'WrappedAPIView',
six.PY3 and "WrappedAPIView" or b"WrappedAPIView",
(APIView,),
{'__doc__': func.__doc__}
{"__doc__": func.__doc__},
)
# Note, the above allows us to set the docstring.
@ -41,15 +41,20 @@ def api_view(http_method_names=None):
# WrappedAPIView.__doc__ = func.doc <--- Not possible to do this
# api_view applied without (method_names)
assert not(isinstance(http_method_names, types.FunctionType)), \
'@api_view missing list of allowed HTTP methods'
assert not (
isinstance(http_method_names, types.FunctionType)
), "@api_view missing list of allowed HTTP methods"
# api_view applied with eg. string instead of list of strings
assert isinstance(http_method_names, (list, tuple)), \
'@api_view expected a list of strings, received %s' % type(http_method_names).__name__
assert isinstance(http_method_names, (list, tuple)), (
"@api_view expected a list of strings, received %s"
% type(http_method_names).__name__
)
allowed_methods = set(http_method_names) | {'options'}
WrappedAPIView.http_method_names = [method.lower() for method in allowed_methods]
allowed_methods = set(http_method_names) | {"options"}
WrappedAPIView.http_method_names = [
method.lower() for method in allowed_methods
]
def handler(self, *args, **kwargs):
return func(*args, **kwargs)
@ -60,23 +65,27 @@ def api_view(http_method_names=None):
WrappedAPIView.__name__ = func.__name__
WrappedAPIView.__module__ = func.__module__
WrappedAPIView.renderer_classes = getattr(func, 'renderer_classes',
APIView.renderer_classes)
WrappedAPIView.renderer_classes = getattr(
func, "renderer_classes", APIView.renderer_classes
)
WrappedAPIView.parser_classes = getattr(func, 'parser_classes',
APIView.parser_classes)
WrappedAPIView.parser_classes = getattr(
func, "parser_classes", APIView.parser_classes
)
WrappedAPIView.authentication_classes = getattr(func, 'authentication_classes',
APIView.authentication_classes)
WrappedAPIView.authentication_classes = getattr(
func, "authentication_classes", APIView.authentication_classes
)
WrappedAPIView.throttle_classes = getattr(func, 'throttle_classes',
APIView.throttle_classes)
WrappedAPIView.throttle_classes = getattr(
func, "throttle_classes", APIView.throttle_classes
)
WrappedAPIView.permission_classes = getattr(func, 'permission_classes',
APIView.permission_classes)
WrappedAPIView.permission_classes = getattr(
func, "permission_classes", APIView.permission_classes
)
WrappedAPIView.schema = getattr(func, 'schema',
APIView.schema)
WrappedAPIView.schema = getattr(func, "schema", APIView.schema)
return WrappedAPIView.as_view()
@ -87,6 +96,7 @@ def renderer_classes(renderer_classes):
def decorator(func):
func.renderer_classes = renderer_classes
return func
return decorator
@ -94,6 +104,7 @@ def parser_classes(parser_classes):
def decorator(func):
func.parser_classes = parser_classes
return func
return decorator
@ -101,6 +112,7 @@ def authentication_classes(authentication_classes):
def decorator(func):
func.authentication_classes = authentication_classes
return func
return decorator
@ -108,6 +120,7 @@ def throttle_classes(throttle_classes):
def decorator(func):
func.throttle_classes = throttle_classes
return func
return decorator
@ -115,6 +128,7 @@ def permission_classes(permission_classes):
def decorator(func):
func.permission_classes = permission_classes
return func
return decorator
@ -122,6 +136,7 @@ def schema(view_inspector):
def decorator(func):
func.schema = view_inspector
return func
return decorator
@ -132,15 +147,13 @@ def action(methods=None, detail=None, url_path=None, url_name=None, **kwargs):
Set the `detail` boolean to determine if this action should apply to
instance/detail requests or collection/list requests.
"""
methods = ['get'] if (methods is None) else methods
methods = ["get"] if (methods is None) else methods
methods = [method.lower() for method in methods]
assert detail is not None, (
"@action() missing required argument: 'detail'"
)
assert detail is not None, "@action() missing required argument: 'detail'"
# name and suffix are mutually exclusive
if 'name' in kwargs and 'suffix' in kwargs:
if "name" in kwargs and "suffix" in kwargs:
raise TypeError("`name` and `suffix` are mutually exclusive arguments.")
def decorator(func):
@ -148,15 +161,16 @@ def action(methods=None, detail=None, url_path=None, url_name=None, **kwargs):
func.detail = detail
func.url_path = url_path if url_path else func.__name__
func.url_name = url_name if url_name else func.__name__.replace('_', '-')
func.url_name = url_name if url_name else func.__name__.replace("_", "-")
func.kwargs = kwargs
# Set descriptive arguments for viewsets
if 'name' not in kwargs and 'suffix' not in kwargs:
func.kwargs['name'] = pretty_name(func.__name__)
func.kwargs['description'] = func.__doc__ or None
if "name" not in kwargs and "suffix" not in kwargs:
func.kwargs["name"] = pretty_name(func.__name__)
func.kwargs["description"] = func.__doc__ or None
return func
return decorator
@ -184,39 +198,42 @@ class MethodMapper(dict):
self[method] = self.action.__name__
def _map(self, method, func):
assert method not in self, (
"Method '%s' has already been mapped to '.%s'." % (method, self[method]))
assert method not in self, "Method '%s' has already been mapped to '.%s'." % (
method,
self[method],
)
assert func.__name__ != self.action.__name__, (
"Method mapping does not behave like the property decorator. You "
"cannot use the same method name for each mapping declaration.")
"cannot use the same method name for each mapping declaration."
)
self[method] = func.__name__
return func
def get(self, func):
return self._map('get', func)
return self._map("get", func)
def post(self, func):
return self._map('post', func)
return self._map("post", func)
def put(self, func):
return self._map('put', func)
return self._map("put", func)
def patch(self, func):
return self._map('patch', func)
return self._map("patch", func)
def delete(self, func):
return self._map('delete', func)
return self._map("delete", func)
def head(self, func):
return self._map('head', func)
return self._map("head", func)
def options(self, func):
return self._map('options', func)
return self._map("options", func)
def trace(self, func):
return self._map('trace', func)
return self._map("trace", func)
def detail_route(methods=None, **kwargs):
@ -226,14 +243,16 @@ def detail_route(methods=None, **kwargs):
warnings.warn(
"`detail_route` is deprecated and will be removed in 3.10 in favor of "
"`action`, which accepts a `detail` bool. Use `@action(detail=True)` instead.",
RemovedInDRF310Warning, stacklevel=2
RemovedInDRF310Warning,
stacklevel=2,
)
def decorator(func):
func = action(methods, detail=True, **kwargs)(func)
if 'url_name' not in kwargs:
func.url_name = func.url_path.replace('_', '-')
if "url_name" not in kwargs:
func.url_name = func.url_path.replace("_", "-")
return func
return decorator
@ -244,12 +263,14 @@ def list_route(methods=None, **kwargs):
warnings.warn(
"`list_route` is deprecated and will be removed in 3.10 in favor of "
"`action`, which accepts a `detail` bool. Use `@action(detail=False)` instead.",
RemovedInDRF310Warning, stacklevel=2
RemovedInDRF310Warning,
stacklevel=2,
)
def decorator(func):
func = action(methods, detail=False, **kwargs)(func)
if 'url_name' not in kwargs:
func.url_name = func.url_path.replace('_', '-')
if "url_name" not in kwargs:
func.url_name = func.url_path.replace("_", "-")
return func
return decorator

View File

@ -1,18 +1,25 @@
from django.conf.urls import include, url
from rest_framework.renderers import (
CoreJSONRenderer, DocumentationRenderer, SchemaJSRenderer
CoreJSONRenderer,
DocumentationRenderer,
SchemaJSRenderer,
)
from rest_framework.schemas import SchemaGenerator, get_schema_view
from rest_framework.settings import api_settings
def get_docs_view(
title=None, description=None, schema_url=None, public=True,
patterns=None, generator_class=SchemaGenerator,
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES,
renderer_classes=None):
title=None,
description=None,
schema_url=None,
public=True,
patterns=None,
generator_class=SchemaGenerator,
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES,
renderer_classes=None,
):
if renderer_classes is None:
renderer_classes = [DocumentationRenderer, CoreJSONRenderer]
@ -31,10 +38,15 @@ def get_docs_view(
def get_schemajs_view(
title=None, description=None, schema_url=None, public=True,
patterns=None, generator_class=SchemaGenerator,
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES):
title=None,
description=None,
schema_url=None,
public=True,
patterns=None,
generator_class=SchemaGenerator,
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES,
):
renderer_classes = [SchemaJSRenderer]
return get_schema_view(
@ -51,11 +63,16 @@ def get_schemajs_view(
def include_docs_urls(
title=None, description=None, schema_url=None, public=True,
patterns=None, generator_class=SchemaGenerator,
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES,
renderer_classes=None):
title=None,
description=None,
schema_url=None,
public=True,
patterns=None,
generator_class=SchemaGenerator,
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES,
renderer_classes=None,
):
docs_view = get_docs_view(
title=title,
description=description,
@ -78,7 +95,7 @@ def include_docs_urls(
permission_classes=permission_classes,
)
urls = [
url(r'^$', docs_view, name='docs-index'),
url(r'^schema.js$', schema_js_view, name='schema-js')
url(r"^$", docs_view, name="docs-index"),
url(r"^schema.js$", schema_js_view, name="schema-js"),
]
return include((urls, 'api-docs'), namespace='api-docs')
return include((urls, "api-docs"), namespace="api-docs")

View File

@ -11,8 +11,7 @@ import math
from django.http import JsonResponse
from django.utils import six
from django.utils.encoding import force_text
from django.utils.translation import ugettext_lazy as _
from django.utils.translation import ungettext
from django.utils.translation import ugettext_lazy as _, ungettext
from rest_framework import status
from rest_framework.compat import unicode_to_repr
@ -25,23 +24,20 @@ def _get_error_details(data, default_code=None):
lazy translation strings or strings into `ErrorDetail`.
"""
if isinstance(data, list):
ret = [
_get_error_details(item, default_code) for item in data
]
ret = [_get_error_details(item, default_code) for item in data]
if isinstance(data, ReturnList):
return ReturnList(ret, serializer=data.serializer)
return ret
elif isinstance(data, dict):
ret = {
key: _get_error_details(value, default_code)
for key, value in data.items()
key: _get_error_details(value, default_code) for key, value in data.items()
}
if isinstance(data, ReturnDict):
return ReturnDict(ret, serializer=data.serializer)
return ret
text = force_text(data)
code = getattr(data, 'code', default_code)
code = getattr(data, "code", default_code)
return ErrorDetail(text, code)
@ -58,16 +54,14 @@ def _get_full_details(detail):
return [_get_full_details(item) for item in detail]
elif isinstance(detail, dict):
return {key: _get_full_details(value) for key, value in detail.items()}
return {
'message': detail,
'code': detail.code
}
return {"message": detail, "code": detail.code}
class ErrorDetail(six.text_type):
"""
A string-like object that can additionally have a code.
"""
code = None
def __new__(cls, string, code=None):
@ -86,10 +80,9 @@ class ErrorDetail(six.text_type):
return not self.__eq__(other)
def __repr__(self):
return unicode_to_repr('ErrorDetail(string=%r, code=%r)' % (
six.text_type(self),
self.code,
))
return unicode_to_repr(
"ErrorDetail(string=%r, code=%r)" % (six.text_type(self), self.code)
)
def __hash__(self):
return hash(str(self))
@ -100,9 +93,10 @@ class APIException(Exception):
Base class for REST framework exceptions.
Subclasses should provide `.status_code` and `.default_detail` properties.
"""
status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
default_detail = _('A server error occurred.')
default_code = 'error'
default_detail = _("A server error occurred.")
default_code = "error"
def __init__(self, detail=None, code=None):
if detail is None:
@ -139,10 +133,11 @@ class APIException(Exception):
# from rest_framework import serializers
# raise serializers.ValidationError('Value was invalid')
class ValidationError(APIException):
status_code = status.HTTP_400_BAD_REQUEST
default_detail = _('Invalid input.')
default_code = 'invalid'
default_detail = _("Invalid input.")
default_code = "invalid"
def __init__(self, detail=None, code=None):
if detail is None:
@ -160,38 +155,38 @@ class ValidationError(APIException):
class ParseError(APIException):
status_code = status.HTTP_400_BAD_REQUEST
default_detail = _('Malformed request.')
default_code = 'parse_error'
default_detail = _("Malformed request.")
default_code = "parse_error"
class AuthenticationFailed(APIException):
status_code = status.HTTP_401_UNAUTHORIZED
default_detail = _('Incorrect authentication credentials.')
default_code = 'authentication_failed'
default_detail = _("Incorrect authentication credentials.")
default_code = "authentication_failed"
class NotAuthenticated(APIException):
status_code = status.HTTP_401_UNAUTHORIZED
default_detail = _('Authentication credentials were not provided.')
default_code = 'not_authenticated'
default_detail = _("Authentication credentials were not provided.")
default_code = "not_authenticated"
class PermissionDenied(APIException):
status_code = status.HTTP_403_FORBIDDEN
default_detail = _('You do not have permission to perform this action.')
default_code = 'permission_denied'
default_detail = _("You do not have permission to perform this action.")
default_code = "permission_denied"
class NotFound(APIException):
status_code = status.HTTP_404_NOT_FOUND
default_detail = _('Not found.')
default_code = 'not_found'
default_detail = _("Not found.")
default_code = "not_found"
class MethodNotAllowed(APIException):
status_code = status.HTTP_405_METHOD_NOT_ALLOWED
default_detail = _('Method "{method}" not allowed.')
default_code = 'method_not_allowed'
default_code = "method_not_allowed"
def __init__(self, method, detail=None, code=None):
if detail is None:
@ -201,8 +196,8 @@ class MethodNotAllowed(APIException):
class NotAcceptable(APIException):
status_code = status.HTTP_406_NOT_ACCEPTABLE
default_detail = _('Could not satisfy the request Accept header.')
default_code = 'not_acceptable'
default_detail = _("Could not satisfy the request Accept header.")
default_code = "not_acceptable"
def __init__(self, detail=None, code=None, available_renderers=None):
self.available_renderers = available_renderers
@ -212,7 +207,7 @@ class NotAcceptable(APIException):
class UnsupportedMediaType(APIException):
status_code = status.HTTP_415_UNSUPPORTED_MEDIA_TYPE
default_detail = _('Unsupported media type "{media_type}" in request.')
default_code = 'unsupported_media_type'
default_code = "unsupported_media_type"
def __init__(self, media_type, detail=None, code=None):
if detail is None:
@ -222,21 +217,28 @@ class UnsupportedMediaType(APIException):
class Throttled(APIException):
status_code = status.HTTP_429_TOO_MANY_REQUESTS
default_detail = _('Request was throttled.')
extra_detail_singular = 'Expected available in {wait} second.'
extra_detail_plural = 'Expected available in {wait} seconds.'
default_code = 'throttled'
default_detail = _("Request was throttled.")
extra_detail_singular = "Expected available in {wait} second."
extra_detail_plural = "Expected available in {wait} seconds."
default_code = "throttled"
def __init__(self, wait=None, detail=None, code=None):
if detail is None:
detail = force_text(self.default_detail)
if wait is not None:
wait = math.ceil(wait)
detail = ' '.join((
detail,
force_text(ungettext(self.extra_detail_singular.format(wait=wait),
self.extra_detail_plural.format(wait=wait),
wait))))
detail = " ".join(
(
detail,
force_text(
ungettext(
self.extra_detail_singular.format(wait=wait),
self.extra_detail_plural.format(wait=wait),
wait,
)
),
)
)
self.wait = wait
super(Throttled, self).__init__(detail, code)
@ -245,9 +247,7 @@ def server_error(request, *args, **kwargs):
"""
Generic 500 error handler.
"""
data = {
'error': 'Server Error (500)'
}
data = {"error": "Server Error (500)"}
return JsonResponse(data, status=status.HTTP_500_INTERNAL_SERVER_ERROR)
@ -255,7 +255,5 @@ def bad_request(request, exception, *args, **kwargs):
"""
Generic 400 error handler.
"""
data = {
'error': 'Bad Request (400)'
}
data = {"error": "Bad Request (400)"}
return JsonResponse(data, status=status.HTTP_400_BAD_REQUEST)

File diff suppressed because it is too large Load Diff

View File

@ -18,9 +18,7 @@ from django.utils.encoding import force_text
from django.utils.translation import ugettext_lazy as _
from rest_framework import RemovedInDRF310Warning
from rest_framework.compat import (
coreapi, coreschema, distinct, is_guardian_installed
)
from rest_framework.compat import coreapi, coreschema, distinct, is_guardian_installed
from rest_framework.settings import api_settings
@ -36,23 +34,22 @@ class BaseFilterBackend(object):
raise NotImplementedError(".filter_queryset() must be overridden.")
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
assert (
coreapi is not None
), "coreapi must be installed to use `get_schema_fields()`"
assert (
coreschema is not None
), "coreschema must be installed to use `get_schema_fields()`"
return []
class SearchFilter(BaseFilterBackend):
# The URL query parameter used for the search.
search_param = api_settings.SEARCH_PARAM
template = 'rest_framework/filters/search.html'
lookup_prefixes = {
'^': 'istartswith',
'=': 'iexact',
'@': 'search',
'$': 'iregex',
}
search_title = _('Search')
search_description = _('A search term.')
template = "rest_framework/filters/search.html"
lookup_prefixes = {"^": "istartswith", "=": "iexact", "@": "search", "$": "iregex"}
search_title = _("Search")
search_description = _("A search term.")
def get_search_fields(self, view, request):
"""
@ -60,22 +57,22 @@ class SearchFilter(BaseFilterBackend):
passed to this method. Sub-classes can override this method to
dynamically change the search fields based on request content.
"""
return getattr(view, 'search_fields', None)
return getattr(view, "search_fields", None)
def get_search_terms(self, request):
"""
Search terms are set by a ?search=... query parameter,
and may be comma and/or whitespace delimited.
"""
params = request.query_params.get(self.search_param, '')
return params.replace(',', ' ').split()
params = request.query_params.get(self.search_param, "")
return params.replace(",", " ").split()
def construct_search(self, field_name):
lookup = self.lookup_prefixes.get(field_name[0])
if lookup:
field_name = field_name[1:]
else:
lookup = 'icontains'
lookup = "icontains"
return LOOKUP_SEP.join([field_name, lookup])
def must_call_distinct(self, queryset, search_fields):
@ -87,12 +84,15 @@ class SearchFilter(BaseFilterBackend):
if search_field[0] in self.lookup_prefixes:
search_field = search_field[1:]
# Annotated fields do not need to be distinct
if isinstance(queryset, models.QuerySet) and search_field in queryset.query.annotations:
if (
isinstance(queryset, models.QuerySet)
and search_field in queryset.query.annotations
):
return False
parts = search_field.split(LOOKUP_SEP)
for part in parts:
field = opts.get_field(part)
if hasattr(field, 'get_path_info'):
if hasattr(field, "get_path_info"):
# This field is a relation, update opts to follow the relation
path_info = field.get_path_info()
opts = path_info[-1].to_opts
@ -117,8 +117,7 @@ class SearchFilter(BaseFilterBackend):
conditions = []
for search_term in search_terms:
queries = [
models.Q(**{orm_lookup: search_term})
for orm_lookup in orm_lookups
models.Q(**{orm_lookup: search_term}) for orm_lookup in orm_lookups
]
conditions.append(reduce(operator.or_, queries))
queryset = queryset.filter(reduce(operator.and_, conditions))
@ -132,30 +131,31 @@ class SearchFilter(BaseFilterBackend):
return queryset
def to_html(self, request, queryset, view):
if not getattr(view, 'search_fields', None):
return ''
if not getattr(view, "search_fields", None):
return ""
term = self.get_search_terms(request)
term = term[0] if term else ''
context = {
'param': self.search_param,
'term': term
}
term = term[0] if term else ""
context = {"param": self.search_param, "term": term}
template = loader.get_template(self.template)
return template.render(context)
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
assert (
coreapi is not None
), "coreapi must be installed to use `get_schema_fields()`"
assert (
coreschema is not None
), "coreschema must be installed to use `get_schema_fields()`"
return [
coreapi.Field(
name=self.search_param,
required=False,
location='query',
location="query",
schema=coreschema.String(
title=force_text(self.search_title),
description=force_text(self.search_description)
)
description=force_text(self.search_description),
),
)
]
@ -164,9 +164,9 @@ class OrderingFilter(BaseFilterBackend):
# The URL query parameter used for the ordering.
ordering_param = api_settings.ORDERING_PARAM
ordering_fields = None
ordering_title = _('Ordering')
ordering_description = _('Which field to use when ordering the results.')
template = 'rest_framework/filters/ordering.html'
ordering_title = _("Ordering")
ordering_description = _("Which field to use when ordering the results.")
template = "rest_framework/filters/ordering.html"
def get_ordering(self, request, queryset, view):
"""
@ -178,7 +178,7 @@ class OrderingFilter(BaseFilterBackend):
"""
params = request.query_params.get(self.ordering_param)
if params:
fields = [param.strip() for param in params.split(',')]
fields = [param.strip() for param in params.split(",")]
ordering = self.remove_invalid_fields(queryset, fields, view, request)
if ordering:
return ordering
@ -187,7 +187,7 @@ class OrderingFilter(BaseFilterBackend):
return self.get_default_ordering(view)
def get_default_ordering(self, view):
ordering = getattr(view, 'ordering', None)
ordering = getattr(view, "ordering", None)
if isinstance(ordering, six.string_types):
return (ordering,)
return ordering
@ -195,7 +195,7 @@ class OrderingFilter(BaseFilterBackend):
def get_default_valid_fields(self, queryset, view, context={}):
# If `ordering_fields` is not specified, then we determine a default
# based on the serializer class, if one exists on the view.
if hasattr(view, 'get_serializer_class'):
if hasattr(view, "get_serializer_class"):
try:
serializer_class = view.get_serializer_class()
except AssertionError:
@ -203,7 +203,7 @@ class OrderingFilter(BaseFilterBackend):
# no serializer_class was found
serializer_class = None
else:
serializer_class = getattr(view, 'serializer_class', None)
serializer_class = getattr(view, "serializer_class", None)
if serializer_class is None:
msg = (
@ -214,26 +214,26 @@ class OrderingFilter(BaseFilterBackend):
raise ImproperlyConfigured(msg % self.__class__.__name__)
return [
(field.source.replace('.', '__') or field_name, field.label)
(field.source.replace(".", "__") or field_name, field.label)
for field_name, field in serializer_class(context=context).fields.items()
if not getattr(field, 'write_only', False) and not field.source == '*'
if not getattr(field, "write_only", False) and not field.source == "*"
]
def get_valid_fields(self, queryset, view, context={}):
valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
valid_fields = getattr(view, "ordering_fields", self.ordering_fields)
if valid_fields is None:
# Default to allowing filtering on serializer fields
return self.get_default_valid_fields(queryset, view, context)
elif valid_fields == '__all__':
elif valid_fields == "__all__":
# View explicitly allows filtering on any model field
valid_fields = [
(field.name, field.verbose_name) for field in queryset.model._meta.fields
(field.name, field.verbose_name)
for field in queryset.model._meta.fields
]
valid_fields += [
(key, key.title().split('__'))
for key in queryset.query.annotations
(key, key.title().split("__")) for key in queryset.query.annotations
]
else:
valid_fields = [
@ -244,8 +244,15 @@ class OrderingFilter(BaseFilterBackend):
return valid_fields
def remove_invalid_fields(self, queryset, fields, view, request):
valid_fields = [item[0] for item in self.get_valid_fields(queryset, view, {'request': request})]
return [term for term in fields if term.lstrip('-') in valid_fields and ORDER_PATTERN.match(term)]
valid_fields = [
item[0]
for item in self.get_valid_fields(queryset, view, {"request": request})
]
return [
term
for term in fields
if term.lstrip("-") in valid_fields and ORDER_PATTERN.match(term)
]
def filter_queryset(self, request, queryset, view):
ordering = self.get_ordering(request, queryset, view)
@ -259,15 +266,11 @@ class OrderingFilter(BaseFilterBackend):
current = self.get_ordering(request, queryset, view)
current = None if not current else current[0]
options = []
context = {
'request': request,
'current': current,
'param': self.ordering_param,
}
context = {"request": request, "current": current, "param": self.ordering_param}
for key, label in self.get_valid_fields(queryset, view, context):
options.append((key, '%s - %s' % (label, _('ascending'))))
options.append(('-' + key, '%s - %s' % (label, _('descending'))))
context['options'] = options
options.append((key, "%s - %s" % (label, _("ascending"))))
options.append(("-" + key, "%s - %s" % (label, _("descending"))))
context["options"] = options
return context
def to_html(self, request, queryset, view):
@ -276,17 +279,21 @@ class OrderingFilter(BaseFilterBackend):
return template.render(context)
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
assert (
coreapi is not None
), "coreapi must be installed to use `get_schema_fields()`"
assert (
coreschema is not None
), "coreschema must be installed to use `get_schema_fields()`"
return [
coreapi.Field(
name=self.ordering_param,
required=False,
location='query',
location="query",
schema=coreschema.String(
title=force_text(self.ordering_title),
description=force_text(self.ordering_description)
)
description=force_text(self.ordering_description),
),
)
]
@ -296,15 +303,19 @@ class DjangoObjectPermissionsFilter(BaseFilterBackend):
A filter backend that limits results to those where the requesting user
has read object level permissions.
"""
def __init__(self):
warnings.warn(
"`DjangoObjectPermissionsFilter` has been deprecated and moved to "
"the 3rd-party django-rest-framework-guardian package.",
RemovedInDRF310Warning, stacklevel=2
RemovedInDRF310Warning,
stacklevel=2,
)
assert is_guardian_installed(), 'Using DjangoObjectPermissionsFilter, but django-guardian is not installed'
assert (
is_guardian_installed()
), "Using DjangoObjectPermissionsFilter, but django-guardian is not installed"
perm_format = '%(app_label)s.view_%(model_name)s'
perm_format = "%(app_label)s.view_%(model_name)s"
def filter_queryset(self, request, queryset, view):
# We want to defer this import until run-time, rather than import-time.
@ -317,13 +328,13 @@ class DjangoObjectPermissionsFilter(BaseFilterBackend):
user = request.user
model_cls = queryset.model
kwargs = {
'app_label': model_cls._meta.app_label,
'model_name': model_cls._meta.model_name
"app_label": model_cls._meta.app_label,
"model_name": model_cls._meta.model_name,
}
permission = self.perm_format % kwargs
if tuple(guardian_version) >= (1, 3):
# Maintain behavior compatibility with versions prior to 1.3
extra = {'accept_global_perms': False}
extra = {"accept_global_perms": False}
else:
extra = {}
return get_objects_for_user(user, permission, queryset, **extra)

View File

@ -27,6 +27,7 @@ class GenericAPIView(views.APIView):
"""
Base class for all other generic views.
"""
# You'll need to either set these attributes,
# or override `get_queryset()`/`get_serializer_class()`.
# If you are overriding a view method, it is important that you call
@ -38,7 +39,7 @@ class GenericAPIView(views.APIView):
# If you want to use object lookups other than pk, set 'lookup_field'.
# For more complex lookup requirements override `get_object()`.
lookup_field = 'pk'
lookup_field = "pk"
lookup_url_kwarg = None
# The filter backend classes to use for queryset filtering
@ -64,8 +65,7 @@ class GenericAPIView(views.APIView):
"""
assert self.queryset is not None, (
"'%s' should either include a `queryset` attribute, "
"or override the `get_queryset()` method."
% self.__class__.__name__
"or override the `get_queryset()` method." % self.__class__.__name__
)
queryset = self.queryset
@ -88,10 +88,10 @@ class GenericAPIView(views.APIView):
lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
assert lookup_url_kwarg in self.kwargs, (
'Expected view %s to be called with a URL keyword argument '
"Expected view %s to be called with a URL keyword argument "
'named "%s". Fix your URL conf, or set the `.lookup_field` '
'attribute on the view correctly.' %
(self.__class__.__name__, lookup_url_kwarg)
"attribute on the view correctly."
% (self.__class__.__name__, lookup_url_kwarg)
)
filter_kwargs = {self.lookup_field: self.kwargs[lookup_url_kwarg]}
@ -108,7 +108,7 @@ class GenericAPIView(views.APIView):
deserializing input, and for serializing output.
"""
serializer_class = self.get_serializer_class()
kwargs['context'] = self.get_serializer_context()
kwargs["context"] = self.get_serializer_context()
return serializer_class(*args, **kwargs)
def get_serializer_class(self):
@ -123,8 +123,7 @@ class GenericAPIView(views.APIView):
"""
assert self.serializer_class is not None, (
"'%s' should either include a `serializer_class` attribute, "
"or override the `get_serializer_class()` method."
% self.__class__.__name__
"or override the `get_serializer_class()` method." % self.__class__.__name__
)
return self.serializer_class
@ -133,11 +132,7 @@ class GenericAPIView(views.APIView):
"""
Extra context provided to the serializer class.
"""
return {
'request': self.request,
'format': self.format_kwarg,
'view': self
}
return {"request": self.request, "format": self.format_kwarg, "view": self}
def filter_queryset(self, queryset):
"""
@ -157,7 +152,7 @@ class GenericAPIView(views.APIView):
"""
The paginator instance associated with the view, or `None`.
"""
if not hasattr(self, '_paginator'):
if not hasattr(self, "_paginator"):
if self.pagination_class is None:
self._paginator = None
else:
@ -183,47 +178,48 @@ class GenericAPIView(views.APIView):
# Concrete view classes that provide method handlers
# by composing the mixin classes with the base view.
class CreateAPIView(mixins.CreateModelMixin,
GenericAPIView):
class CreateAPIView(mixins.CreateModelMixin, GenericAPIView):
"""
Concrete view for creating a model instance.
"""
def post(self, request, *args, **kwargs):
return self.create(request, *args, **kwargs)
class ListAPIView(mixins.ListModelMixin,
GenericAPIView):
class ListAPIView(mixins.ListModelMixin, GenericAPIView):
"""
Concrete view for listing a queryset.
"""
def get(self, request, *args, **kwargs):
return self.list(request, *args, **kwargs)
class RetrieveAPIView(mixins.RetrieveModelMixin,
GenericAPIView):
class RetrieveAPIView(mixins.RetrieveModelMixin, GenericAPIView):
"""
Concrete view for retrieving a model instance.
"""
def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs)
class DestroyAPIView(mixins.DestroyModelMixin,
GenericAPIView):
class DestroyAPIView(mixins.DestroyModelMixin, GenericAPIView):
"""
Concrete view for deleting a model instance.
"""
def delete(self, request, *args, **kwargs):
return self.destroy(request, *args, **kwargs)
class UpdateAPIView(mixins.UpdateModelMixin,
GenericAPIView):
class UpdateAPIView(mixins.UpdateModelMixin, GenericAPIView):
"""
Concrete view for updating a model instance.
"""
def put(self, request, *args, **kwargs):
return self.update(request, *args, **kwargs)
@ -231,12 +227,11 @@ class UpdateAPIView(mixins.UpdateModelMixin,
return self.partial_update(request, *args, **kwargs)
class ListCreateAPIView(mixins.ListModelMixin,
mixins.CreateModelMixin,
GenericAPIView):
class ListCreateAPIView(mixins.ListModelMixin, mixins.CreateModelMixin, GenericAPIView):
"""
Concrete view for listing a queryset or creating a model instance.
"""
def get(self, request, *args, **kwargs):
return self.list(request, *args, **kwargs)
@ -244,12 +239,13 @@ class ListCreateAPIView(mixins.ListModelMixin,
return self.create(request, *args, **kwargs)
class RetrieveUpdateAPIView(mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
GenericAPIView):
class RetrieveUpdateAPIView(
mixins.RetrieveModelMixin, mixins.UpdateModelMixin, GenericAPIView
):
"""
Concrete view for retrieving, updating a model instance.
"""
def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs)
@ -260,12 +256,13 @@ class RetrieveUpdateAPIView(mixins.RetrieveModelMixin,
return self.partial_update(request, *args, **kwargs)
class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,
mixins.DestroyModelMixin,
GenericAPIView):
class RetrieveDestroyAPIView(
mixins.RetrieveModelMixin, mixins.DestroyModelMixin, GenericAPIView
):
"""
Concrete view for retrieving or deleting a model instance.
"""
def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs)
@ -273,13 +270,16 @@ class RetrieveDestroyAPIView(mixins.RetrieveModelMixin,
return self.destroy(request, *args, **kwargs)
class RetrieveUpdateDestroyAPIView(mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
mixins.DestroyModelMixin,
GenericAPIView):
class RetrieveUpdateDestroyAPIView(
mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
mixins.DestroyModelMixin,
GenericAPIView,
):
"""
Concrete view for retrieving, updating or deleting a model instance.
"""
def get(self, request, *args, **kwargs):
return self.retrieve(request, *args, **kwargs)

View File

@ -2,7 +2,9 @@ from django.core.management.base import BaseCommand
from rest_framework.compat import coreapi
from rest_framework.renderers import (
CoreJSONRenderer, JSONOpenAPIRenderer, OpenAPIRenderer
CoreJSONRenderer,
JSONOpenAPIRenderer,
OpenAPIRenderer,
)
from rest_framework.schemas.generators import SchemaGenerator
@ -11,31 +13,37 @@ class Command(BaseCommand):
help = "Generates configured API schema for project."
def add_arguments(self, parser):
parser.add_argument('--title', dest="title", default=None, type=str)
parser.add_argument('--url', dest="url", default=None, type=str)
parser.add_argument('--description', dest="description", default=None, type=str)
parser.add_argument('--format', dest="format", choices=['openapi', 'openapi-json', 'corejson'], default='openapi', type=str)
parser.add_argument("--title", dest="title", default=None, type=str)
parser.add_argument("--url", dest="url", default=None, type=str)
parser.add_argument("--description", dest="description", default=None, type=str)
parser.add_argument(
"--format",
dest="format",
choices=["openapi", "openapi-json", "corejson"],
default="openapi",
type=str,
)
def handle(self, *args, **options):
assert coreapi is not None, 'coreapi must be installed.'
assert coreapi is not None, "coreapi must be installed."
generator = SchemaGenerator(
url=options['url'],
title=options['title'],
description=options['description']
url=options["url"],
title=options["title"],
description=options["description"],
)
schema = generator.get_schema(request=None, public=True)
renderer = self.get_renderer(options['format'])
renderer = self.get_renderer(options["format"])
output = renderer.render(schema, renderer_context={})
self.stdout.write(output.decode('utf-8'))
self.stdout.write(output.decode("utf-8"))
def get_renderer(self, format):
renderer_cls = {
'corejson': CoreJSONRenderer,
'openapi': OpenAPIRenderer,
'openapi-json': JSONOpenAPIRenderer,
"corejson": CoreJSONRenderer,
"openapi": OpenAPIRenderer,
"openapi-json": JSONOpenAPIRenderer,
}[format]
return renderer_cls()

View File

@ -35,41 +35,46 @@ class SimpleMetadata(BaseMetadata):
There are not any formalized standards for `OPTIONS` responses
for us to base this on.
"""
label_lookup = ClassLookupDict({
serializers.Field: 'field',
serializers.BooleanField: 'boolean',
serializers.NullBooleanField: 'boolean',
serializers.CharField: 'string',
serializers.UUIDField: 'string',
serializers.URLField: 'url',
serializers.EmailField: 'email',
serializers.RegexField: 'regex',
serializers.SlugField: 'slug',
serializers.IntegerField: 'integer',
serializers.FloatField: 'float',
serializers.DecimalField: 'decimal',
serializers.DateField: 'date',
serializers.DateTimeField: 'datetime',
serializers.TimeField: 'time',
serializers.ChoiceField: 'choice',
serializers.MultipleChoiceField: 'multiple choice',
serializers.FileField: 'file upload',
serializers.ImageField: 'image upload',
serializers.ListField: 'list',
serializers.DictField: 'nested object',
serializers.Serializer: 'nested object',
})
label_lookup = ClassLookupDict(
{
serializers.Field: "field",
serializers.BooleanField: "boolean",
serializers.NullBooleanField: "boolean",
serializers.CharField: "string",
serializers.UUIDField: "string",
serializers.URLField: "url",
serializers.EmailField: "email",
serializers.RegexField: "regex",
serializers.SlugField: "slug",
serializers.IntegerField: "integer",
serializers.FloatField: "float",
serializers.DecimalField: "decimal",
serializers.DateField: "date",
serializers.DateTimeField: "datetime",
serializers.TimeField: "time",
serializers.ChoiceField: "choice",
serializers.MultipleChoiceField: "multiple choice",
serializers.FileField: "file upload",
serializers.ImageField: "image upload",
serializers.ListField: "list",
serializers.DictField: "nested object",
serializers.Serializer: "nested object",
}
)
def determine_metadata(self, request, view):
metadata = OrderedDict()
metadata['name'] = view.get_view_name()
metadata['description'] = view.get_view_description()
metadata['renders'] = [renderer.media_type for renderer in view.renderer_classes]
metadata['parses'] = [parser.media_type for parser in view.parser_classes]
if hasattr(view, 'get_serializer'):
metadata["name"] = view.get_view_name()
metadata["description"] = view.get_view_description()
metadata["renders"] = [
renderer.media_type for renderer in view.renderer_classes
]
metadata["parses"] = [parser.media_type for parser in view.parser_classes]
if hasattr(view, "get_serializer"):
actions = self.determine_actions(request, view)
if actions:
metadata['actions'] = actions
metadata["actions"] = actions
return metadata
def determine_actions(self, request, view):
@ -78,14 +83,14 @@ class SimpleMetadata(BaseMetadata):
the fields that are accepted for 'PUT' and 'POST' methods.
"""
actions = {}
for method in {'PUT', 'POST'} & set(view.allowed_methods):
for method in {"PUT", "POST"} & set(view.allowed_methods):
view.request = clone_request(request, method)
try:
# Test global permissions
if hasattr(view, 'check_permissions'):
if hasattr(view, "check_permissions"):
view.check_permissions(view.request)
# Test object permissions
if method == 'PUT' and hasattr(view, 'get_object'):
if method == "PUT" and hasattr(view, "get_object"):
view.get_object()
except (exceptions.APIException, PermissionDenied, Http404):
pass
@ -104,15 +109,17 @@ class SimpleMetadata(BaseMetadata):
Given an instance of a serializer, return a dictionary of metadata
about its fields.
"""
if hasattr(serializer, 'child'):
if hasattr(serializer, "child"):
# If this is a `ListSerializer` then we want to examine the
# underlying child serializer instance instead.
serializer = serializer.child
return OrderedDict([
(field_name, self.get_field_info(field))
for field_name, field in serializer.fields.items()
if not isinstance(field, serializers.HiddenField)
])
return OrderedDict(
[
(field_name, self.get_field_info(field))
for field_name, field in serializer.fields.items()
if not isinstance(field, serializers.HiddenField)
]
)
def get_field_info(self, field):
"""
@ -120,32 +127,40 @@ class SimpleMetadata(BaseMetadata):
of metadata about it.
"""
field_info = OrderedDict()
field_info['type'] = self.label_lookup[field]
field_info['required'] = getattr(field, 'required', False)
field_info["type"] = self.label_lookup[field]
field_info["required"] = getattr(field, "required", False)
attrs = [
'read_only', 'label', 'help_text',
'min_length', 'max_length',
'min_value', 'max_value'
"read_only",
"label",
"help_text",
"min_length",
"max_length",
"min_value",
"max_value",
]
for attr in attrs:
value = getattr(field, attr, None)
if value is not None and value != '':
if value is not None and value != "":
field_info[attr] = force_text(value, strings_only=True)
if getattr(field, 'child', None):
field_info['child'] = self.get_field_info(field.child)
elif getattr(field, 'fields', None):
field_info['children'] = self.get_serializer_info(field)
if getattr(field, "child", None):
field_info["child"] = self.get_field_info(field.child)
elif getattr(field, "fields", None):
field_info["children"] = self.get_serializer_info(field)
if (not field_info.get('read_only') and
not isinstance(field, (serializers.RelatedField, serializers.ManyRelatedField)) and
hasattr(field, 'choices')):
field_info['choices'] = [
if (
not field_info.get("read_only")
and not isinstance(
field, (serializers.RelatedField, serializers.ManyRelatedField)
)
and hasattr(field, "choices")
):
field_info["choices"] = [
{
'value': choice_value,
'display_name': force_text(choice_name, strings_only=True)
"value": choice_value,
"display_name": force_text(choice_name, strings_only=True),
}
for choice_value, choice_name in field.choices.items()
]

View File

@ -15,19 +15,22 @@ class CreateModelMixin(object):
"""
Create a model instance.
"""
def create(self, request, *args, **kwargs):
serializer = self.get_serializer(data=request.data)
serializer.is_valid(raise_exception=True)
self.perform_create(serializer)
headers = self.get_success_headers(serializer.data)
return Response(serializer.data, status=status.HTTP_201_CREATED, headers=headers)
return Response(
serializer.data, status=status.HTTP_201_CREATED, headers=headers
)
def perform_create(self, serializer):
serializer.save()
def get_success_headers(self, data):
try:
return {'Location': str(data[api_settings.URL_FIELD_NAME])}
return {"Location": str(data[api_settings.URL_FIELD_NAME])}
except (TypeError, KeyError):
return {}
@ -36,6 +39,7 @@ class ListModelMixin(object):
"""
List a queryset.
"""
def list(self, request, *args, **kwargs):
queryset = self.filter_queryset(self.get_queryset())
@ -52,6 +56,7 @@ class RetrieveModelMixin(object):
"""
Retrieve a model instance.
"""
def retrieve(self, request, *args, **kwargs):
instance = self.get_object()
serializer = self.get_serializer(instance)
@ -62,14 +67,15 @@ class UpdateModelMixin(object):
"""
Update a model instance.
"""
def update(self, request, *args, **kwargs):
partial = kwargs.pop('partial', False)
partial = kwargs.pop("partial", False)
instance = self.get_object()
serializer = self.get_serializer(instance, data=request.data, partial=partial)
serializer.is_valid(raise_exception=True)
self.perform_update(serializer)
if getattr(instance, '_prefetched_objects_cache', None):
if getattr(instance, "_prefetched_objects_cache", None):
# If 'prefetch_related' has been applied to a queryset, we need to
# forcibly invalidate the prefetch cache on the instance.
instance._prefetched_objects_cache = {}
@ -80,7 +86,7 @@ class UpdateModelMixin(object):
serializer.save()
def partial_update(self, request, *args, **kwargs):
kwargs['partial'] = True
kwargs["partial"] = True
return self.update(request, *args, **kwargs)
@ -88,6 +94,7 @@ class DestroyModelMixin(object):
"""
Destroy a model instance.
"""
def destroy(self, request, *args, **kwargs):
instance = self.get_object()
self.perform_destroy(instance)

View File

@ -9,16 +9,18 @@ from django.http import Http404
from rest_framework import HTTP_HEADER_ENCODING, exceptions
from rest_framework.settings import api_settings
from rest_framework.utils.mediatypes import (
_MediaType, media_type_matches, order_by_precedence
_MediaType,
media_type_matches,
order_by_precedence,
)
class BaseContentNegotiation(object):
def select_parser(self, request, parsers):
raise NotImplementedError('.select_parser() must be implemented')
raise NotImplementedError(".select_parser() must be implemented")
def select_renderer(self, request, renderers, format_suffix=None):
raise NotImplementedError('.select_renderer() must be implemented')
raise NotImplementedError(".select_renderer() must be implemented")
class DefaultContentNegotiation(BaseContentNegotiation):
@ -59,16 +61,20 @@ class DefaultContentNegotiation(BaseContentNegotiation):
# Return the most specific media type as accepted.
media_type_wrapper = _MediaType(media_type)
if (
_MediaType(renderer.media_type).precedence >
media_type_wrapper.precedence
_MediaType(renderer.media_type).precedence
> media_type_wrapper.precedence
):
# Eg client requests '*/*'
# Accepted media type is 'application/json'
full_media_type = ';'.join(
(renderer.media_type,) +
tuple('{0}={1}'.format(
key, value.decode(HTTP_HEADER_ENCODING))
for key, value in media_type_wrapper.params.items()))
full_media_type = ";".join(
(renderer.media_type,)
+ tuple(
"{0}={1}".format(
key, value.decode(HTTP_HEADER_ENCODING)
)
for key, value in media_type_wrapper.params.items()
)
)
return renderer, full_media_type
else:
# Eg client requests 'application/json; indent=8'
@ -82,8 +88,7 @@ class DefaultContentNegotiation(BaseContentNegotiation):
If there is a '.json' style format suffix, filter the renderers
so that we only negotiation against those that accept that format.
"""
renderers = [renderer for renderer in renderers
if renderer.format == format]
renderers = [renderer for renderer in renderers if renderer.format == format]
if not renderers:
raise Http404
return renderers
@ -93,5 +98,5 @@ class DefaultContentNegotiation(BaseContentNegotiation):
Given the incoming request, return a tokenized list of media
type strings.
"""
header = request.META.get('HTTP_ACCEPT', '*/*')
return [token.strip() for token in header.split(',')]
header = request.META.get("HTTP_ACCEPT", "*/*")
return [token.strip() for token in header.split(",")]

View File

@ -8,8 +8,7 @@ from __future__ import unicode_literals
from base64 import b64decode, b64encode
from collections import OrderedDict, namedtuple
from django.core.paginator import InvalidPage
from django.core.paginator import Paginator as DjangoPaginator
from django.core.paginator import InvalidPage, Paginator as DjangoPaginator
from django.template import loader
from django.utils import six
from django.utils.encoding import force_text
@ -83,10 +82,7 @@ def _get_displayed_page_numbers(current, final):
included.add(final - 2)
# Now sort the page numbers and drop anything outside the limits.
included = [
idx for idx in sorted(list(included))
if 0 < idx <= final
]
included = [idx for idx in sorted(list(included)) if 0 < idx <= final]
# Finally insert any `...` breaks
if current > 4:
@ -110,7 +106,7 @@ def _get_page_links(page_numbers, current, url_func):
url=url_func(page_number),
number=page_number,
is_active=(page_number == current),
is_break=False
is_break=False,
)
page_links.append(page_link)
return page_links
@ -121,14 +117,15 @@ def _reverse_ordering(ordering_tuple):
Given an order_by tuple such as `('-created', 'uuid')` reverse the
ordering and return a new tuple, eg. `('created', '-uuid')`.
"""
def invert(x):
return x[1:] if x.startswith('-') else '-' + x
return x[1:] if x.startswith("-") else "-" + x
return tuple([invert(item) for item in ordering_tuple])
Cursor = namedtuple('Cursor', ['offset', 'reverse', 'position'])
PageLink = namedtuple('PageLink', ['url', 'number', 'is_active', 'is_break'])
Cursor = namedtuple("Cursor", ["offset", "reverse", "position"])
PageLink = namedtuple("PageLink", ["url", "number", "is_active", "is_break"])
PAGE_BREAK = PageLink(url=None, number=None, is_active=False, is_break=True)
@ -137,19 +134,23 @@ class BasePagination(object):
display_page_controls = False
def paginate_queryset(self, queryset, request, view=None): # pragma: no cover
raise NotImplementedError('paginate_queryset() must be implemented.')
raise NotImplementedError("paginate_queryset() must be implemented.")
def get_paginated_response(self, data): # pragma: no cover
raise NotImplementedError('get_paginated_response() must be implemented.')
raise NotImplementedError("get_paginated_response() must be implemented.")
def to_html(self): # pragma: no cover
raise NotImplementedError('to_html() must be implemented to display page controls.')
raise NotImplementedError(
"to_html() must be implemented to display page controls."
)
def get_results(self, data):
return data['results']
return data["results"]
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
assert (
coreapi is not None
), "coreapi must be installed to use `get_schema_fields()`"
return []
@ -161,6 +162,7 @@ class PageNumberPagination(BasePagination):
http://api.example.org/accounts/?page=4
http://api.example.org/accounts/?page=4&page_size=100
"""
# The default page size.
# Defaults to `None`, meaning pagination is disabled.
page_size = api_settings.PAGE_SIZE
@ -168,23 +170,23 @@ class PageNumberPagination(BasePagination):
django_paginator_class = DjangoPaginator
# Client can control the page using this query parameter.
page_query_param = 'page'
page_query_description = _('A page number within the paginated result set.')
page_query_param = "page"
page_query_description = _("A page number within the paginated result set.")
# Client can control the page size using this query parameter.
# Default is 'None'. Set to eg 'page_size' to enable usage.
page_size_query_param = None
page_size_query_description = _('Number of results to return per page.')
page_size_query_description = _("Number of results to return per page.")
# Set to an integer to limit the maximum page size the client may request.
# Only relevant if 'page_size_query_param' has also been set.
max_page_size = None
last_page_strings = ('last',)
last_page_strings = ("last",)
template = 'rest_framework/pagination/numbers.html'
template = "rest_framework/pagination/numbers.html"
invalid_page_message = _('Invalid page.')
invalid_page_message = _("Invalid page.")
def paginate_queryset(self, queryset, request, view=None):
"""
@ -216,12 +218,16 @@ class PageNumberPagination(BasePagination):
return list(self.page)
def get_paginated_response(self, data):
return Response(OrderedDict([
('count', self.page.paginator.count),
('next', self.get_next_link()),
('previous', self.get_previous_link()),
('results', data)
]))
return Response(
OrderedDict(
[
("count", self.page.paginator.count),
("next", self.get_next_link()),
("previous", self.get_previous_link()),
("results", data),
]
)
)
def get_page_size(self, request):
if self.page_size_query_param:
@ -229,7 +235,7 @@ class PageNumberPagination(BasePagination):
return _positive_int(
request.query_params[self.page_size_query_param],
strict=True,
cutoff=self.max_page_size
cutoff=self.max_page_size,
)
except (KeyError, ValueError):
pass
@ -267,9 +273,9 @@ class PageNumberPagination(BasePagination):
page_links = _get_page_links(page_numbers, current, page_number_to_url)
return {
'previous_url': self.get_previous_link(),
'next_url': self.get_next_link(),
'page_links': page_links
"previous_url": self.get_previous_link(),
"next_url": self.get_next_link(),
"page_links": page_links,
}
def to_html(self):
@ -278,17 +284,20 @@ class PageNumberPagination(BasePagination):
return template.render(context)
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
assert (
coreapi is not None
), "coreapi must be installed to use `get_schema_fields()`"
assert (
coreschema is not None
), "coreschema must be installed to use `get_schema_fields()`"
fields = [
coreapi.Field(
name=self.page_query_param,
required=False,
location='query',
location="query",
schema=coreschema.Integer(
title='Page',
description=force_text(self.page_query_description)
)
title="Page", description=force_text(self.page_query_description)
),
)
]
if self.page_size_query_param is not None:
@ -296,11 +305,11 @@ class PageNumberPagination(BasePagination):
coreapi.Field(
name=self.page_size_query_param,
required=False,
location='query',
location="query",
schema=coreschema.Integer(
title='Page size',
description=force_text(self.page_size_query_description)
)
title="Page size",
description=force_text(self.page_size_query_description),
),
)
)
return fields
@ -313,13 +322,14 @@ class LimitOffsetPagination(BasePagination):
http://api.example.org/accounts/?limit=100
http://api.example.org/accounts/?offset=400&limit=100
"""
default_limit = api_settings.PAGE_SIZE
limit_query_param = 'limit'
limit_query_description = _('Number of results to return per page.')
offset_query_param = 'offset'
offset_query_description = _('The initial index from which to return the results.')
limit_query_param = "limit"
limit_query_description = _("Number of results to return per page.")
offset_query_param = "offset"
offset_query_description = _("The initial index from which to return the results.")
max_limit = None
template = 'rest_framework/pagination/numbers.html'
template = "rest_framework/pagination/numbers.html"
def paginate_queryset(self, queryset, request, view=None):
self.count = self.get_count(queryset)
@ -334,15 +344,19 @@ class LimitOffsetPagination(BasePagination):
if self.count == 0 or self.offset > self.count:
return []
return list(queryset[self.offset:self.offset + self.limit])
return list(queryset[self.offset : self.offset + self.limit])
def get_paginated_response(self, data):
return Response(OrderedDict([
('count', self.count),
('next', self.get_next_link()),
('previous', self.get_previous_link()),
('results', data)
]))
return Response(
OrderedDict(
[
("count", self.count),
("next", self.get_next_link()),
("previous", self.get_previous_link()),
("results", data),
]
)
)
def get_limit(self, request):
if self.limit_query_param:
@ -350,7 +364,7 @@ class LimitOffsetPagination(BasePagination):
return _positive_int(
request.query_params[self.limit_query_param],
strict=True,
cutoff=self.max_limit
cutoff=self.max_limit,
)
except (KeyError, ValueError):
pass
@ -359,9 +373,7 @@ class LimitOffsetPagination(BasePagination):
def get_offset(self, request):
try:
return _positive_int(
request.query_params[self.offset_query_param],
)
return _positive_int(request.query_params[self.offset_query_param])
except (KeyError, ValueError):
return 0
@ -399,10 +411,9 @@ class LimitOffsetPagination(BasePagination):
# plus the number of pages up to the current offset.
# When offset is not strictly divisible by the limit then we may
# end up introducing an extra page as an artifact.
final = (
_divide_with_ceil(self.count - self.offset, self.limit) +
_divide_with_ceil(self.offset, self.limit)
)
final = _divide_with_ceil(
self.count - self.offset, self.limit
) + _divide_with_ceil(self.offset, self.limit)
if final < 1:
final = 1
@ -424,9 +435,9 @@ class LimitOffsetPagination(BasePagination):
page_links = _get_page_links(page_numbers, current, page_number_to_url)
return {
'previous_url': self.get_previous_link(),
'next_url': self.get_next_link(),
'page_links': page_links
"previous_url": self.get_previous_link(),
"next_url": self.get_next_link(),
"page_links": page_links,
}
def to_html(self):
@ -435,27 +446,30 @@ class LimitOffsetPagination(BasePagination):
return template.render(context)
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
assert (
coreapi is not None
), "coreapi must be installed to use `get_schema_fields()`"
assert (
coreschema is not None
), "coreschema must be installed to use `get_schema_fields()`"
return [
coreapi.Field(
name=self.limit_query_param,
required=False,
location='query',
location="query",
schema=coreschema.Integer(
title='Limit',
description=force_text(self.limit_query_description)
)
title="Limit", description=force_text(self.limit_query_description)
),
),
coreapi.Field(
name=self.offset_query_param,
required=False,
location='query',
location="query",
schema=coreschema.Integer(
title='Offset',
description=force_text(self.offset_query_description)
)
)
title="Offset",
description=force_text(self.offset_query_description),
),
),
]
def get_count(self, queryset):
@ -474,17 +488,18 @@ class CursorPagination(BasePagination):
For an overview of the position/offset style we use, see this post:
https://cra.mr/2011/03/08/building-cursors-for-the-disqus-api
"""
cursor_query_param = 'cursor'
cursor_query_description = _('The pagination cursor value.')
cursor_query_param = "cursor"
cursor_query_description = _("The pagination cursor value.")
page_size = api_settings.PAGE_SIZE
invalid_cursor_message = _('Invalid cursor')
ordering = '-created'
template = 'rest_framework/pagination/previous_and_next.html'
invalid_cursor_message = _("Invalid cursor")
ordering = "-created"
template = "rest_framework/pagination/previous_and_next.html"
# Client can control the page size using this query parameter.
# Default is 'None'. Set to eg 'page_size' to enable usage.
page_size_query_param = None
page_size_query_description = _('Number of results to return per page.')
page_size_query_description = _("Number of results to return per page.")
# Set to an integer to limit the maximum page size the client may request.
# Only relevant if 'page_size_query_param' has also been set.
@ -519,27 +534,29 @@ class CursorPagination(BasePagination):
# If we have a cursor with a fixed position then filter by that.
if current_position is not None:
order = self.ordering[0]
is_reversed = order.startswith('-')
order_attr = order.lstrip('-')
is_reversed = order.startswith("-")
order_attr = order.lstrip("-")
# Test for: (cursor reversed) XOR (queryset reversed)
if self.cursor.reverse != is_reversed:
kwargs = {order_attr + '__lt': current_position}
kwargs = {order_attr + "__lt": current_position}
else:
kwargs = {order_attr + '__gt': current_position}
kwargs = {order_attr + "__gt": current_position}
queryset = queryset.filter(**kwargs)
# If we have an offset cursor then offset the entire page by that amount.
# We also always fetch an extra item in order to determine if there is a
# page following on from this one.
results = list(queryset[offset:offset + self.page_size + 1])
self.page = list(results[:self.page_size])
results = list(queryset[offset : offset + self.page_size + 1])
self.page = list(results[: self.page_size])
# Determine the position of the final item following the page.
if len(results) > len(self.page):
has_following_position = True
following_position = self._get_position_from_instance(results[-1], self.ordering)
following_position = self._get_position_from_instance(
results[-1], self.ordering
)
else:
has_following_position = False
following_position = None
@ -578,7 +595,7 @@ class CursorPagination(BasePagination):
return _positive_int(
request.query_params[self.page_size_query_param],
strict=True,
cutoff=self.max_page_size
cutoff=self.max_page_size,
)
except (KeyError, ValueError):
pass
@ -686,8 +703,9 @@ class CursorPagination(BasePagination):
Return a tuple of strings, that may be used in an `order_by` method.
"""
ordering_filters = [
filter_cls for filter_cls in getattr(view, 'filter_backends', [])
if hasattr(filter_cls, 'get_ordering')
filter_cls
for filter_cls in getattr(view, "filter_backends", [])
if hasattr(filter_cls, "get_ordering")
]
if ordering_filters:
@ -697,29 +715,27 @@ class CursorPagination(BasePagination):
filter_instance = filter_cls()
ordering = filter_instance.get_ordering(request, queryset, view)
assert ordering is not None, (
'Using cursor pagination, but filter class {filter_cls} '
'returned a `None` ordering.'.format(
filter_cls=filter_cls.__name__
)
"Using cursor pagination, but filter class {filter_cls} "
"returned a `None` ordering.".format(filter_cls=filter_cls.__name__)
)
else:
# The default case is to check for an `ordering` attribute
# on this pagination instance.
ordering = self.ordering
assert ordering is not None, (
'Using cursor pagination, but no ordering attribute was declared '
'on the pagination class.'
"Using cursor pagination, but no ordering attribute was declared "
"on the pagination class."
)
assert '__' not in ordering, (
'Cursor pagination does not support double underscore lookups '
'for orderings. Orderings should be an unchanging, unique or '
assert "__" not in ordering, (
"Cursor pagination does not support double underscore lookups "
"for orderings. Orderings should be an unchanging, unique or "
'nearly-unique field on the model, such as "-created" or "pk".'
)
assert isinstance(ordering, (six.string_types, list, tuple)), (
'Invalid ordering. Expected string or tuple, but got {type}'.format(
type=type(ordering).__name__
)
assert isinstance(
ordering, (six.string_types, list, tuple)
), "Invalid ordering. Expected string or tuple, but got {type}".format(
type=type(ordering).__name__
)
if isinstance(ordering, six.string_types):
@ -736,16 +752,16 @@ class CursorPagination(BasePagination):
return None
try:
querystring = b64decode(encoded.encode('ascii')).decode('ascii')
querystring = b64decode(encoded.encode("ascii")).decode("ascii")
tokens = urlparse.parse_qs(querystring, keep_blank_values=True)
offset = tokens.get('o', ['0'])[0]
offset = tokens.get("o", ["0"])[0]
offset = _positive_int(offset, cutoff=self.offset_cutoff)
reverse = tokens.get('r', ['0'])[0]
reverse = tokens.get("r", ["0"])[0]
reverse = bool(int(reverse))
position = tokens.get('p', [None])[0]
position = tokens.get("p", [None])[0]
except (TypeError, ValueError):
raise NotFound(self.invalid_cursor_message)
@ -757,18 +773,18 @@ class CursorPagination(BasePagination):
"""
tokens = {}
if cursor.offset != 0:
tokens['o'] = str(cursor.offset)
tokens["o"] = str(cursor.offset)
if cursor.reverse:
tokens['r'] = '1'
tokens["r"] = "1"
if cursor.position is not None:
tokens['p'] = cursor.position
tokens["p"] = cursor.position
querystring = urlparse.urlencode(tokens, doseq=True)
encoded = b64encode(querystring.encode('ascii')).decode('ascii')
encoded = b64encode(querystring.encode("ascii")).decode("ascii")
return replace_query_param(self.base_url, self.cursor_query_param, encoded)
def _get_position_from_instance(self, instance, ordering):
field_name = ordering[0].lstrip('-')
field_name = ordering[0].lstrip("-")
if isinstance(instance, dict):
attr = instance[field_name]
else:
@ -776,16 +792,20 @@ class CursorPagination(BasePagination):
return six.text_type(attr)
def get_paginated_response(self, data):
return Response(OrderedDict([
('next', self.get_next_link()),
('previous', self.get_previous_link()),
('results', data)
]))
return Response(
OrderedDict(
[
("next", self.get_next_link()),
("previous", self.get_previous_link()),
("results", data),
]
)
)
def get_html_context(self):
return {
'previous_url': self.get_previous_link(),
'next_url': self.get_next_link()
"previous_url": self.get_previous_link(),
"next_url": self.get_next_link(),
}
def to_html(self):
@ -794,17 +814,21 @@ class CursorPagination(BasePagination):
return template.render(context)
def get_schema_fields(self, view):
assert coreapi is not None, 'coreapi must be installed to use `get_schema_fields()`'
assert coreschema is not None, 'coreschema must be installed to use `get_schema_fields()`'
assert (
coreapi is not None
), "coreapi must be installed to use `get_schema_fields()`"
assert (
coreschema is not None
), "coreschema must be installed to use `get_schema_fields()`"
fields = [
coreapi.Field(
name=self.cursor_query_param,
required=False,
location='query',
location="query",
schema=coreschema.String(
title='Cursor',
description=force_text(self.cursor_query_description)
)
title="Cursor",
description=force_text(self.cursor_query_description),
),
)
]
if self.page_size_query_param is not None:
@ -812,11 +836,11 @@ class CursorPagination(BasePagination):
coreapi.Field(
name=self.page_size_query_param,
required=False,
location='query',
location="query",
schema=coreschema.Integer(
title='Page size',
description=force_text(self.page_size_query_description)
)
title="Page size",
description=force_text(self.page_size_query_description),
),
)
)
return fields

View File

@ -11,10 +11,12 @@ import codecs
from django.conf import settings
from django.core.files.uploadhandler import StopFutureHandlers
from django.http import QueryDict
from django.http.multipartparser import ChunkIter
from django.http.multipartparser import \
MultiPartParser as DjangoMultiPartParser
from django.http.multipartparser import MultiPartParserError, parse_header
from django.http.multipartparser import (
ChunkIter,
MultiPartParser as DjangoMultiPartParser,
MultiPartParserError,
parse_header,
)
from django.utils import six
from django.utils.encoding import force_text
from django.utils.six.moves.urllib import parse as urlparse
@ -36,6 +38,7 @@ class BaseParser(object):
All parsers should extend `BaseParser`, specifying a `media_type`
attribute, and overriding the `.parse()` method.
"""
media_type = None
def parse(self, stream, media_type=None, parser_context=None):
@ -51,7 +54,8 @@ class JSONParser(BaseParser):
"""
Parses JSON-serialized data.
"""
media_type = 'application/json'
media_type = "application/json"
renderer_class = renderers.JSONRenderer
strict = api_settings.STRICT_JSON
@ -60,21 +64,22 @@ class JSONParser(BaseParser):
Parses the incoming bytestream as JSON and returns the resulting data.
"""
parser_context = parser_context or {}
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
encoding = parser_context.get("encoding", settings.DEFAULT_CHARSET)
try:
decoded_stream = codecs.getreader(encoding)(stream)
parse_constant = json.strict_constant if self.strict else None
return json.load(decoded_stream, parse_constant=parse_constant)
except ValueError as exc:
raise ParseError('JSON parse error - %s' % six.text_type(exc))
raise ParseError("JSON parse error - %s" % six.text_type(exc))
class FormParser(BaseParser):
"""
Parser for form data.
"""
media_type = 'application/x-www-form-urlencoded'
media_type = "application/x-www-form-urlencoded"
def parse(self, stream, media_type=None, parser_context=None):
"""
@ -82,7 +87,7 @@ class FormParser(BaseParser):
and returns the resulting QueryDict.
"""
parser_context = parser_context or {}
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
encoding = parser_context.get("encoding", settings.DEFAULT_CHARSET)
data = QueryDict(stream.read(), encoding=encoding)
return data
@ -91,7 +96,8 @@ class MultiPartParser(BaseParser):
"""
Parser for multipart form data, which may include file data.
"""
media_type = 'multipart/form-data'
media_type = "multipart/form-data"
def parse(self, stream, media_type=None, parser_context=None):
"""
@ -102,10 +108,10 @@ class MultiPartParser(BaseParser):
`.files` will be a `QueryDict` containing all the form files.
"""
parser_context = parser_context or {}
request = parser_context['request']
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
request = parser_context["request"]
encoding = parser_context.get("encoding", settings.DEFAULT_CHARSET)
meta = request.META.copy()
meta['CONTENT_TYPE'] = media_type
meta["CONTENT_TYPE"] = media_type
upload_handlers = request.upload_handlers
try:
@ -113,17 +119,18 @@ class MultiPartParser(BaseParser):
data, files = parser.parse()
return DataAndFiles(data, files)
except MultiPartParserError as exc:
raise ParseError('Multipart form parse error - %s' % six.text_type(exc))
raise ParseError("Multipart form parse error - %s" % six.text_type(exc))
class FileUploadParser(BaseParser):
"""
Parser for file upload data.
"""
media_type = '*/*'
media_type = "*/*"
errors = {
'unhandled': 'FileUpload parse error - none of upload handlers can handle the stream',
'no_filename': 'Missing filename. Request should include a Content-Disposition header with a filename parameter.',
"unhandled": "FileUpload parse error - none of upload handlers can handle the stream",
"no_filename": "Missing filename. Request should include a Content-Disposition header with a filename parameter.",
}
def parse(self, stream, media_type=None, parser_context=None):
@ -135,34 +142,32 @@ class FileUploadParser(BaseParser):
`.files` will be a `QueryDict` containing one 'file' element.
"""
parser_context = parser_context or {}
request = parser_context['request']
encoding = parser_context.get('encoding', settings.DEFAULT_CHARSET)
request = parser_context["request"]
encoding = parser_context.get("encoding", settings.DEFAULT_CHARSET)
meta = request.META
upload_handlers = request.upload_handlers
filename = self.get_filename(stream, media_type, parser_context)
if not filename:
raise ParseError(self.errors['no_filename'])
raise ParseError(self.errors["no_filename"])
# Note that this code is extracted from Django's handling of
# file uploads in MultiPartParser.
content_type = meta.get('HTTP_CONTENT_TYPE',
meta.get('CONTENT_TYPE', ''))
content_type = meta.get("HTTP_CONTENT_TYPE", meta.get("CONTENT_TYPE", ""))
try:
content_length = int(meta.get('HTTP_CONTENT_LENGTH',
meta.get('CONTENT_LENGTH', 0)))
content_length = int(
meta.get("HTTP_CONTENT_LENGTH", meta.get("CONTENT_LENGTH", 0))
)
except (ValueError, TypeError):
content_length = None
# See if the handler will want to take care of the parsing.
for handler in upload_handlers:
result = handler.handle_raw_input(stream,
meta,
content_length,
None,
encoding)
result = handler.handle_raw_input(
stream, meta, content_length, None, encoding
)
if result is not None:
return DataAndFiles({}, {'file': result[1]})
return DataAndFiles({}, {"file": result[1]})
# This is the standard case.
possible_sizes = [x.chunk_size for x in upload_handlers if x.chunk_size]
@ -172,10 +177,9 @@ class FileUploadParser(BaseParser):
for index, handler in enumerate(upload_handlers):
try:
handler.new_file(None, filename, content_type,
content_length, encoding)
handler.new_file(None, filename, content_type, content_length, encoding)
except StopFutureHandlers:
upload_handlers = upload_handlers[:index + 1]
upload_handlers = upload_handlers[: index + 1]
break
for chunk in chunks:
@ -189,9 +193,9 @@ class FileUploadParser(BaseParser):
for index, handler in enumerate(upload_handlers):
file_obj = handler.file_complete(counters[index])
if file_obj is not None:
return DataAndFiles({}, {'file': file_obj})
return DataAndFiles({}, {"file": file_obj})
raise ParseError(self.errors['unhandled'])
raise ParseError(self.errors["unhandled"])
def get_filename(self, stream, media_type, parser_context):
"""
@ -199,17 +203,17 @@ class FileUploadParser(BaseParser):
Then tries to parse Content-Disposition header.
"""
try:
return parser_context['kwargs']['filename']
return parser_context["kwargs"]["filename"]
except KeyError:
pass
try:
meta = parser_context['request'].META
disposition = parse_header(meta['HTTP_CONTENT_DISPOSITION'].encode('utf-8'))
meta = parser_context["request"].META
disposition = parse_header(meta["HTTP_CONTENT_DISPOSITION"].encode("utf-8"))
filename_parm = disposition[1]
if 'filename*' in filename_parm:
if "filename*" in filename_parm:
return self.get_encoded_filename(filename_parm)
return force_text(filename_parm['filename'])
return force_text(filename_parm["filename"])
except (AttributeError, KeyError, ValueError):
pass
@ -218,10 +222,10 @@ class FileUploadParser(BaseParser):
Handle encoded filenames per RFC6266. See also:
https://tools.ietf.org/html/rfc2231#section-4
"""
encoded_filename = force_text(filename_parm['filename*'])
encoded_filename = force_text(filename_parm["filename*"])
try:
charset, lang, filename = encoded_filename.split('\'', 2)
charset, lang, filename = encoded_filename.split("'", 2)
filename = urlparse.unquote(filename)
except (ValueError, LookupError):
filename = force_text(filename_parm['filename'])
filename = force_text(filename_parm["filename"])
return filename

View File

@ -8,7 +8,8 @@ from django.utils import six
from rest_framework import exceptions
SAFE_METHODS = ('GET', 'HEAD', 'OPTIONS')
SAFE_METHODS = ("GET", "HEAD", "OPTIONS")
class OperationHolderMixin:
@ -56,16 +57,14 @@ class AND:
self.op2 = op2
def has_permission(self, request, view):
return (
self.op1.has_permission(request, view) and
self.op2.has_permission(request, view)
return self.op1.has_permission(request, view) and self.op2.has_permission(
request, view
)
def has_object_permission(self, request, view, obj):
return (
self.op1.has_object_permission(request, view, obj) and
self.op2.has_object_permission(request, view, obj)
)
return self.op1.has_object_permission(
request, view, obj
) and self.op2.has_object_permission(request, view, obj)
class OR:
@ -74,16 +73,14 @@ class OR:
self.op2 = op2
def has_permission(self, request, view):
return (
self.op1.has_permission(request, view) or
self.op2.has_permission(request, view)
return self.op1.has_permission(request, view) or self.op2.has_permission(
request, view
)
def has_object_permission(self, request, view, obj):
return (
self.op1.has_object_permission(request, view, obj) or
self.op2.has_object_permission(request, view, obj)
)
return self.op1.has_object_permission(
request, view, obj
) or self.op2.has_object_permission(request, view, obj)
class NOT:
@ -157,9 +154,9 @@ class IsAuthenticatedOrReadOnly(BasePermission):
def has_permission(self, request, view):
return bool(
request.method in SAFE_METHODS or
request.user and
request.user.is_authenticated
request.method in SAFE_METHODS
or request.user
and request.user.is_authenticated
)
@ -179,13 +176,13 @@ class DjangoModelPermissions(BasePermission):
# Override this if you need to also provide 'view' permissions,
# or if you want to provide custom permission codes.
perms_map = {
'GET': [],
'OPTIONS': [],
'HEAD': [],
'POST': ['%(app_label)s.add_%(model_name)s'],
'PUT': ['%(app_label)s.change_%(model_name)s'],
'PATCH': ['%(app_label)s.change_%(model_name)s'],
'DELETE': ['%(app_label)s.delete_%(model_name)s'],
"GET": [],
"OPTIONS": [],
"HEAD": [],
"POST": ["%(app_label)s.add_%(model_name)s"],
"PUT": ["%(app_label)s.change_%(model_name)s"],
"PATCH": ["%(app_label)s.change_%(model_name)s"],
"DELETE": ["%(app_label)s.delete_%(model_name)s"],
}
authenticated_users_only = True
@ -196,8 +193,8 @@ class DjangoModelPermissions(BasePermission):
codes that the user is required to have.
"""
kwargs = {
'app_label': model_cls._meta.app_label,
'model_name': model_cls._meta.model_name
"app_label": model_cls._meta.app_label,
"model_name": model_cls._meta.model_name,
}
if method not in self.perms_map:
@ -206,16 +203,19 @@ class DjangoModelPermissions(BasePermission):
return [perm % kwargs for perm in self.perms_map[method]]
def _queryset(self, view):
assert hasattr(view, 'get_queryset') \
or getattr(view, 'queryset', None) is not None, (
'Cannot apply {} on a view that does not set '
'`.queryset` or have a `.get_queryset()` method.'
).format(self.__class__.__name__)
assert (
hasattr(view, "get_queryset") or getattr(view, "queryset", None) is not None
), (
"Cannot apply {} on a view that does not set "
"`.queryset` or have a `.get_queryset()` method."
).format(
self.__class__.__name__
)
if hasattr(view, 'get_queryset'):
if hasattr(view, "get_queryset"):
queryset = view.get_queryset()
assert queryset is not None, (
'{}.get_queryset() returned None'.format(view.__class__.__name__)
assert queryset is not None, "{}.get_queryset() returned None".format(
view.__class__.__name__
)
return queryset
return view.queryset
@ -223,11 +223,12 @@ class DjangoModelPermissions(BasePermission):
def has_permission(self, request, view):
# Workaround to ensure DjangoModelPermissions are not applied
# to the root view when using DefaultRouter.
if getattr(view, '_ignore_model_permissions', False):
if getattr(view, "_ignore_model_permissions", False):
return True
if not request.user or (
not request.user.is_authenticated and self.authenticated_users_only):
not request.user.is_authenticated and self.authenticated_users_only
):
return False
queryset = self._queryset(view)
@ -241,6 +242,7 @@ class DjangoModelPermissionsOrAnonReadOnly(DjangoModelPermissions):
Similar to DjangoModelPermissions, except that anonymous users are
allowed read-only access.
"""
authenticated_users_only = False
@ -255,20 +257,21 @@ class DjangoObjectPermissions(DjangoModelPermissions):
This permission can only be applied against view classes that
provide a `.queryset` attribute.
"""
perms_map = {
'GET': [],
'OPTIONS': [],
'HEAD': [],
'POST': ['%(app_label)s.add_%(model_name)s'],
'PUT': ['%(app_label)s.change_%(model_name)s'],
'PATCH': ['%(app_label)s.change_%(model_name)s'],
'DELETE': ['%(app_label)s.delete_%(model_name)s'],
"GET": [],
"OPTIONS": [],
"HEAD": [],
"POST": ["%(app_label)s.add_%(model_name)s"],
"PUT": ["%(app_label)s.change_%(model_name)s"],
"PATCH": ["%(app_label)s.change_%(model_name)s"],
"DELETE": ["%(app_label)s.delete_%(model_name)s"],
}
def get_required_object_permissions(self, method, model_cls):
kwargs = {
'app_label': model_cls._meta.app_label,
'model_name': model_cls._meta.model_name
"app_label": model_cls._meta.app_label,
"model_name": model_cls._meta.model_name,
}
if method not in self.perms_map:
@ -294,7 +297,7 @@ class DjangoObjectPermissions(DjangoModelPermissions):
# to make another lookup.
raise Http404
read_perms = self.get_required_object_permissions('GET', model_cls)
read_perms = self.get_required_object_permissions("GET", model_cls)
if not user.has_perms(read_perms, obj):
raise Http404

View File

@ -9,14 +9,16 @@ from django.db.models import Manager
from django.db.models.query import QuerySet
from django.urls import NoReverseMatch, Resolver404, get_script_prefix, resolve
from django.utils import six
from django.utils.encoding import (
python_2_unicode_compatible, smart_text, uri_to_iri
)
from django.utils.encoding import python_2_unicode_compatible, smart_text, uri_to_iri
from django.utils.six.moves.urllib import parse as urlparse
from django.utils.translation import ugettext_lazy as _
from rest_framework.fields import (
Field, empty, get_attribute, is_simple_callable, iter_options
Field,
empty,
get_attribute,
is_simple_callable,
iter_options,
)
from rest_framework.reverse import reverse
from rest_framework.settings import api_settings
@ -28,7 +30,7 @@ def method_overridden(method_name, klass, instance):
Determine if a method has been overridden.
"""
method = getattr(klass, method_name)
default_method = getattr(method, '__func__', method) # Python 3 compat
default_method = getattr(method, "__func__", method) # Python 3 compat
return default_method is not getattr(instance, method_name).__func__
@ -52,13 +54,14 @@ class Hyperlink(six.text_type):
We use this for hyperlinked URLs that may render as a named link
in some contexts, or render as a plain URL in others.
"""
def __new__(self, url, obj):
ret = six.text_type.__new__(self, url)
ret.obj = obj
return ret
def __getnewargs__(self):
return(str(self), self.name,)
return (str(self), self.name)
@property
def name(self):
@ -77,6 +80,7 @@ class PKOnlyObject(object):
instance, but still want to return an object with a .pk attribute,
in order to keep the same interface as a regular model instance.
"""
def __init__(self, pk):
self.pk = pk
@ -87,9 +91,19 @@ class PKOnlyObject(object):
# We assume that 'validators' are intended for the child serializer,
# rather than the parent serializer.
MANY_RELATION_KWARGS = (
'read_only', 'write_only', 'required', 'default', 'initial', 'source',
'label', 'help_text', 'style', 'error_messages', 'allow_empty',
'html_cutoff', 'html_cutoff_text'
"read_only",
"write_only",
"required",
"default",
"initial",
"source",
"label",
"help_text",
"style",
"error_messages",
"allow_empty",
"html_cutoff",
"html_cutoff_text",
)
@ -99,34 +113,34 @@ class RelatedField(Field):
html_cutoff_text = None
def __init__(self, **kwargs):
self.queryset = kwargs.pop('queryset', self.queryset)
self.queryset = kwargs.pop("queryset", self.queryset)
cutoff_from_settings = api_settings.HTML_SELECT_CUTOFF
if cutoff_from_settings is not None:
cutoff_from_settings = int(cutoff_from_settings)
self.html_cutoff = kwargs.pop('html_cutoff', cutoff_from_settings)
self.html_cutoff = kwargs.pop("html_cutoff", cutoff_from_settings)
self.html_cutoff_text = kwargs.pop(
'html_cutoff_text',
self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT)
"html_cutoff_text",
self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT),
)
if not method_overridden('get_queryset', RelatedField, self):
assert self.queryset is not None or kwargs.get('read_only', None), (
'Relational field must provide a `queryset` argument, '
'override `get_queryset`, or set read_only=`True`.'
if not method_overridden("get_queryset", RelatedField, self):
assert self.queryset is not None or kwargs.get("read_only", None), (
"Relational field must provide a `queryset` argument, "
"override `get_queryset`, or set read_only=`True`."
)
assert not (self.queryset is not None and kwargs.get('read_only', None)), (
'Relational fields should not provide a `queryset` argument, '
'when setting read_only=`True`.'
assert not (self.queryset is not None and kwargs.get("read_only", None)), (
"Relational fields should not provide a `queryset` argument, "
"when setting read_only=`True`."
)
kwargs.pop('many', None)
kwargs.pop('allow_empty', None)
kwargs.pop("many", None)
kwargs.pop("allow_empty", None)
super(RelatedField, self).__init__(**kwargs)
def __new__(cls, *args, **kwargs):
# We override this method in order to automagically create
# `ManyRelatedField` classes instead when `many=True` is set.
if kwargs.pop('many', False):
if kwargs.pop("many", False):
return cls.many_init(*args, **kwargs)
return super(RelatedField, cls).__new__(cls, *args, **kwargs)
@ -147,7 +161,7 @@ class RelatedField(Field):
kwargs['child'] = cls()
return CustomManyRelatedField(*args, **kwargs)
"""
list_kwargs = {'child_relation': cls(*args, **kwargs)}
list_kwargs = {"child_relation": cls(*args, **kwargs)}
for key in kwargs:
if key in MANY_RELATION_KWARGS:
list_kwargs[key] = kwargs[key]
@ -155,7 +169,7 @@ class RelatedField(Field):
def run_validation(self, data=empty):
# We force empty strings to None values for relational fields.
if data == '':
if data == "":
data = None
return super(RelatedField, self).run_validation(data)
@ -201,13 +215,12 @@ class RelatedField(Field):
if cutoff is not None:
queryset = queryset[:cutoff]
return OrderedDict([
(
self.to_representation(item),
self.display_value(item)
)
for item in queryset
])
return OrderedDict(
[
(self.to_representation(item), self.display_value(item))
for item in queryset
]
)
@property
def choices(self):
@ -221,7 +234,7 @@ class RelatedField(Field):
return iter_options(
self.get_choices(cutoff=self.html_cutoff),
cutoff=self.html_cutoff,
cutoff_text=self.html_cutoff_text
cutoff_text=self.html_cutoff_text,
)
def display_value(self, instance):
@ -235,7 +248,7 @@ class StringRelatedField(RelatedField):
"""
def __init__(self, **kwargs):
kwargs['read_only'] = True
kwargs["read_only"] = True
super(StringRelatedField, self).__init__(**kwargs)
def to_representation(self, value):
@ -244,13 +257,13 @@ class StringRelatedField(RelatedField):
class PrimaryKeyRelatedField(RelatedField):
default_error_messages = {
'required': _('This field is required.'),
'does_not_exist': _('Invalid pk "{pk_value}" - object does not exist.'),
'incorrect_type': _('Incorrect type. Expected pk value, received {data_type}.'),
"required": _("This field is required."),
"does_not_exist": _('Invalid pk "{pk_value}" - object does not exist.'),
"incorrect_type": _("Incorrect type. Expected pk value, received {data_type}."),
}
def __init__(self, **kwargs):
self.pk_field = kwargs.pop('pk_field', None)
self.pk_field = kwargs.pop("pk_field", None)
super(PrimaryKeyRelatedField, self).__init__(**kwargs)
def use_pk_only_optimization(self):
@ -262,9 +275,9 @@ class PrimaryKeyRelatedField(RelatedField):
try:
return self.get_queryset().get(pk=data)
except ObjectDoesNotExist:
self.fail('does_not_exist', pk_value=data)
self.fail("does_not_exist", pk_value=data)
except (TypeError, ValueError):
self.fail('incorrect_type', data_type=type(data).__name__)
self.fail("incorrect_type", data_type=type(data).__name__)
def to_representation(self, value):
if self.pk_field is not None:
@ -273,24 +286,26 @@ class PrimaryKeyRelatedField(RelatedField):
class HyperlinkedRelatedField(RelatedField):
lookup_field = 'pk'
lookup_field = "pk"
view_name = None
default_error_messages = {
'required': _('This field is required.'),
'no_match': _('Invalid hyperlink - No URL match.'),
'incorrect_match': _('Invalid hyperlink - Incorrect URL match.'),
'does_not_exist': _('Invalid hyperlink - Object does not exist.'),
'incorrect_type': _('Incorrect type. Expected URL string, received {data_type}.'),
"required": _("This field is required."),
"no_match": _("Invalid hyperlink - No URL match."),
"incorrect_match": _("Invalid hyperlink - Incorrect URL match."),
"does_not_exist": _("Invalid hyperlink - Object does not exist."),
"incorrect_type": _(
"Incorrect type. Expected URL string, received {data_type}."
),
}
def __init__(self, view_name=None, **kwargs):
if view_name is not None:
self.view_name = view_name
assert self.view_name is not None, 'The `view_name` argument is required.'
self.lookup_field = kwargs.pop('lookup_field', self.lookup_field)
self.lookup_url_kwarg = kwargs.pop('lookup_url_kwarg', self.lookup_field)
self.format = kwargs.pop('format', None)
assert self.view_name is not None, "The `view_name` argument is required."
self.lookup_field = kwargs.pop("lookup_field", self.lookup_field)
self.lookup_url_kwarg = kwargs.pop("lookup_url_kwarg", self.lookup_field)
self.format = kwargs.pop("format", None)
# We include this simply for dependency injection in tests.
# We can't add it as a class attributes or it would expect an
@ -300,7 +315,7 @@ class HyperlinkedRelatedField(RelatedField):
super(HyperlinkedRelatedField, self).__init__(**kwargs)
def use_pk_only_optimization(self):
return self.lookup_field == 'pk'
return self.lookup_field == "pk"
def get_object(self, view_name, view_args, view_kwargs):
"""
@ -330,7 +345,7 @@ class HyperlinkedRelatedField(RelatedField):
attributes are not configured to correctly match the URL conf.
"""
# Unsaved objects will not yet have a valid URL.
if hasattr(obj, 'pk') and obj.pk in (None, ''):
if hasattr(obj, "pk") and obj.pk in (None, ""):
return None
lookup_value = getattr(obj, self.lookup_field)
@ -338,25 +353,25 @@ class HyperlinkedRelatedField(RelatedField):
return self.reverse(view_name, kwargs=kwargs, request=request, format=format)
def to_internal_value(self, data):
request = self.context.get('request', None)
request = self.context.get("request", None)
try:
http_prefix = data.startswith(('http:', 'https:'))
http_prefix = data.startswith(("http:", "https:"))
except AttributeError:
self.fail('incorrect_type', data_type=type(data).__name__)
self.fail("incorrect_type", data_type=type(data).__name__)
if http_prefix:
# If needed convert absolute URLs to relative path
data = urlparse.urlparse(data).path
prefix = get_script_prefix()
if data.startswith(prefix):
data = '/' + data[len(prefix):]
data = "/" + data[len(prefix) :]
data = uri_to_iri(data)
try:
match = resolve(data)
except Resolver404:
self.fail('no_match')
self.fail("no_match")
try:
expected_viewname = request.versioning_scheme.get_versioned_viewname(
@ -366,22 +381,22 @@ class HyperlinkedRelatedField(RelatedField):
expected_viewname = self.view_name
if match.view_name != expected_viewname:
self.fail('incorrect_match')
self.fail("incorrect_match")
try:
return self.get_object(match.view_name, match.args, match.kwargs)
except (ObjectDoesNotExist, ObjectValueError, ObjectTypeError):
self.fail('does_not_exist')
self.fail("does_not_exist")
def to_representation(self, value):
assert 'request' in self.context, (
assert "request" in self.context, (
"`%s` requires the request in the serializer"
" context. Add `context={'request': request}` when instantiating "
"the serializer." % self.__class__.__name__
)
request = self.context['request']
format = self.context.get('format', None)
request = self.context["request"]
format = self.context.get("format", None)
# By default use whatever format is given for the current context
# unless the target is a different type to the source.
@ -400,13 +415,13 @@ class HyperlinkedRelatedField(RelatedField):
url = self.get_url(value, self.view_name, request, format)
except NoReverseMatch:
msg = (
'Could not resolve URL for hyperlinked relationship using '
"Could not resolve URL for hyperlinked relationship using "
'view name "%s". You may have failed to include the related '
'model in your API, or incorrectly configured the '
'`lookup_field` attribute on this field.'
"model in your API, or incorrectly configured the "
"`lookup_field` attribute on this field."
)
if value in ('', None):
value_string = {'': 'the empty string', None: 'None'}[value]
if value in ("", None):
value_string = {"": "the empty string", None: "None"}[value]
msg += (
" WARNING: The value of the field on the model instance "
"was %s, which may be why it didn't match any "
@ -429,9 +444,9 @@ class HyperlinkedIdentityField(HyperlinkedRelatedField):
"""
def __init__(self, view_name=None, **kwargs):
assert view_name is not None, 'The `view_name` argument is required.'
kwargs['read_only'] = True
kwargs['source'] = '*'
assert view_name is not None, "The `view_name` argument is required."
kwargs["read_only"] = True
kwargs["source"] = "*"
super(HyperlinkedIdentityField, self).__init__(view_name, **kwargs)
def use_pk_only_optimization(self):
@ -445,13 +460,14 @@ class SlugRelatedField(RelatedField):
A read-write field that represents the target of the relationship
by a unique 'slug' attribute.
"""
default_error_messages = {
'does_not_exist': _('Object with {slug_name}={value} does not exist.'),
'invalid': _('Invalid value.'),
"does_not_exist": _("Object with {slug_name}={value} does not exist."),
"invalid": _("Invalid value."),
}
def __init__(self, slug_field=None, **kwargs):
assert slug_field is not None, 'The `slug_field` argument is required.'
assert slug_field is not None, "The `slug_field` argument is required."
self.slug_field = slug_field
super(SlugRelatedField, self).__init__(**kwargs)
@ -459,9 +475,11 @@ class SlugRelatedField(RelatedField):
try:
return self.get_queryset().get(**{self.slug_field: data})
except ObjectDoesNotExist:
self.fail('does_not_exist', slug_name=self.slug_field, value=smart_text(data))
self.fail(
"does_not_exist", slug_name=self.slug_field, value=smart_text(data)
)
except (TypeError, ValueError):
self.fail('invalid')
self.fail("invalid")
def to_representation(self, obj):
return getattr(obj, self.slug_field)
@ -479,31 +497,32 @@ class ManyRelatedField(Field):
You shouldn't generally need to be using this class directly yourself,
and should instead simply set 'many=True' on the relationship.
"""
initial = []
default_empty_html = []
default_error_messages = {
'not_a_list': _('Expected a list of items but got type "{input_type}".'),
'empty': _('This list may not be empty.')
"not_a_list": _('Expected a list of items but got type "{input_type}".'),
"empty": _("This list may not be empty."),
}
html_cutoff = None
html_cutoff_text = None
def __init__(self, child_relation=None, *args, **kwargs):
self.child_relation = child_relation
self.allow_empty = kwargs.pop('allow_empty', True)
self.allow_empty = kwargs.pop("allow_empty", True)
cutoff_from_settings = api_settings.HTML_SELECT_CUTOFF
if cutoff_from_settings is not None:
cutoff_from_settings = int(cutoff_from_settings)
self.html_cutoff = kwargs.pop('html_cutoff', cutoff_from_settings)
self.html_cutoff = kwargs.pop("html_cutoff", cutoff_from_settings)
self.html_cutoff_text = kwargs.pop(
'html_cutoff_text',
self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT)
"html_cutoff_text",
self.html_cutoff_text or _(api_settings.HTML_SELECT_CUTOFF_TEXT),
)
assert child_relation is not None, '`child_relation` is a required argument.'
assert child_relation is not None, "`child_relation` is a required argument."
super(ManyRelatedField, self).__init__(*args, **kwargs)
self.child_relation.bind(field_name='', parent=self)
self.child_relation.bind(field_name="", parent=self)
def get_value(self, dictionary):
# We override the default field access in order to support
@ -511,36 +530,30 @@ class ManyRelatedField(Field):
if html.is_html_input(dictionary):
# Don't return [] if the update is partial
if self.field_name not in dictionary:
if getattr(self.root, 'partial', False):
if getattr(self.root, "partial", False):
return empty
return dictionary.getlist(self.field_name)
return dictionary.get(self.field_name, empty)
def to_internal_value(self, data):
if isinstance(data, six.text_type) or not hasattr(data, '__iter__'):
self.fail('not_a_list', input_type=type(data).__name__)
if isinstance(data, six.text_type) or not hasattr(data, "__iter__"):
self.fail("not_a_list", input_type=type(data).__name__)
if not self.allow_empty and len(data) == 0:
self.fail('empty')
self.fail("empty")
return [
self.child_relation.to_internal_value(item)
for item in data
]
return [self.child_relation.to_internal_value(item) for item in data]
def get_attribute(self, instance):
# Can't have any relationships if not created
if hasattr(instance, 'pk') and instance.pk is None:
if hasattr(instance, "pk") and instance.pk is None:
return []
relationship = get_attribute(instance, self.source_attrs)
return relationship.all() if hasattr(relationship, 'all') else relationship
return relationship.all() if hasattr(relationship, "all") else relationship
def to_representation(self, iterable):
return [
self.child_relation.to_representation(value)
for value in iterable
]
return [self.child_relation.to_representation(value) for value in iterable]
def get_choices(self, cutoff=None):
return self.child_relation.get_choices(cutoff)
@ -557,5 +570,5 @@ class ManyRelatedField(Field):
return iter_options(
self.get_choices(cutoff=self.html_cutoff),
cutoff=self.html_cutoff,
cutoff_text=self.html_cutoff_text
cutoff_text=self.html_cutoff_text,
)

File diff suppressed because it is too large Load Diff

View File

@ -30,8 +30,10 @@ def is_form_media_type(media_type):
Return True if the media type is a valid form media type.
"""
base_media_type, params = parse_header(media_type.encode(HTTP_HEADER_ENCODING))
return (base_media_type == 'application/x-www-form-urlencoded' or
base_media_type == 'multipart/form-data')
return (
base_media_type == "application/x-www-form-urlencoded"
or base_media_type == "multipart/form-data"
)
class override_method(object):
@ -49,12 +51,12 @@ class override_method(object):
self.view = view
self.request = request
self.method = method
self.action = getattr(view, 'action', None)
self.action = getattr(view, "action", None)
def __enter__(self):
self.view.request = clone_request(self.request, self.method)
# For viewsets we also set the `.action` attribute.
action_map = getattr(self.view, 'action_map', {})
action_map = getattr(self.view, "action_map", {})
self.view.action = action_map.get(self.method.lower())
return self.view.request
@ -86,6 +88,7 @@ class Empty(object):
Placeholder for unset attributes.
Cannot use `None`, as that may be a valid value.
"""
pass
@ -98,30 +101,32 @@ def clone_request(request, method):
Internal helper method to clone a request, replacing with a different
HTTP method. Used for checking permissions against other methods.
"""
ret = Request(request=request._request,
parsers=request.parsers,
authenticators=request.authenticators,
negotiator=request.negotiator,
parser_context=request.parser_context)
ret = Request(
request=request._request,
parsers=request.parsers,
authenticators=request.authenticators,
negotiator=request.negotiator,
parser_context=request.parser_context,
)
ret._data = request._data
ret._files = request._files
ret._full_data = request._full_data
ret._content_type = request._content_type
ret._stream = request._stream
ret.method = method
if hasattr(request, '_user'):
if hasattr(request, "_user"):
ret._user = request._user
if hasattr(request, '_auth'):
if hasattr(request, "_auth"):
ret._auth = request._auth
if hasattr(request, '_authenticator'):
if hasattr(request, "_authenticator"):
ret._authenticator = request._authenticator
if hasattr(request, 'accepted_renderer'):
if hasattr(request, "accepted_renderer"):
ret.accepted_renderer = request.accepted_renderer
if hasattr(request, 'accepted_media_type'):
if hasattr(request, "accepted_media_type"):
ret.accepted_media_type = request.accepted_media_type
if hasattr(request, 'version'):
if hasattr(request, "version"):
ret.version = request.version
if hasattr(request, 'versioning_scheme'):
if hasattr(request, "versioning_scheme"):
ret.versioning_scheme = request.versioning_scheme
return ret
@ -152,12 +157,19 @@ class Request(object):
authenticating the request's user.
"""
def __init__(self, request, parsers=None, authenticators=None,
negotiator=None, parser_context=None):
def __init__(
self,
request,
parsers=None,
authenticators=None,
negotiator=None,
parser_context=None,
):
assert isinstance(request, HttpRequest), (
'The `request` argument must be an instance of '
'`django.http.HttpRequest`, not `{}.{}`.'
.format(request.__class__.__module__, request.__class__.__name__)
"The `request` argument must be an instance of "
"`django.http.HttpRequest`, not `{}.{}`.".format(
request.__class__.__module__, request.__class__.__name__
)
)
self._request = request
@ -173,11 +185,11 @@ class Request(object):
if self.parser_context is None:
self.parser_context = {}
self.parser_context['request'] = self
self.parser_context['encoding'] = request.encoding or settings.DEFAULT_CHARSET
self.parser_context["request"] = self
self.parser_context["encoding"] = request.encoding or settings.DEFAULT_CHARSET
force_user = getattr(request, '_force_auth_user', None)
force_token = getattr(request, '_force_auth_token', None)
force_user = getattr(request, "_force_auth_user", None)
force_token = getattr(request, "_force_auth_token", None)
if force_user is not None or force_token is not None:
forced_auth = ForcedAuthentication(force_user, force_token)
self.authenticators = (forced_auth,)
@ -188,14 +200,14 @@ class Request(object):
@property
def content_type(self):
meta = self._request.META
return meta.get('CONTENT_TYPE', meta.get('HTTP_CONTENT_TYPE', ''))
return meta.get("CONTENT_TYPE", meta.get("HTTP_CONTENT_TYPE", ""))
@property
def stream(self):
"""
Returns an object that may be used to stream the request content.
"""
if not _hasattr(self, '_stream'):
if not _hasattr(self, "_stream"):
self._load_stream()
return self._stream
@ -208,7 +220,7 @@ class Request(object):
@property
def data(self):
if not _hasattr(self, '_full_data'):
if not _hasattr(self, "_full_data"):
self._load_data_and_files()
return self._full_data
@ -218,7 +230,7 @@ class Request(object):
Returns the user associated with the current request, as authenticated
by the authentication classes provided to the request.
"""
if not hasattr(self, '_user'):
if not hasattr(self, "_user"):
with wrap_attributeerrors():
self._authenticate()
return self._user
@ -242,7 +254,7 @@ class Request(object):
Returns any non-user authentication information associated with the
request, such as an authentication token.
"""
if not hasattr(self, '_auth'):
if not hasattr(self, "_auth"):
with wrap_attributeerrors():
self._authenticate()
return self._auth
@ -262,7 +274,7 @@ class Request(object):
Return the instance of the authentication instance class that was used
to authenticate the request, or `None`.
"""
if not hasattr(self, '_authenticator'):
if not hasattr(self, "_authenticator"):
with wrap_attributeerrors():
self._authenticate()
return self._authenticator
@ -271,7 +283,7 @@ class Request(object):
"""
Parses the request content into `self.data`.
"""
if not _hasattr(self, '_data'):
if not _hasattr(self, "_data"):
self._data, self._files = self._parse()
if self._files:
self._full_data = self._data.copy()
@ -292,7 +304,7 @@ class Request(object):
meta = self._request.META
try:
content_length = int(
meta.get('CONTENT_LENGTH', meta.get('HTTP_CONTENT_LENGTH', 0))
meta.get("CONTENT_LENGTH", meta.get("HTTP_CONTENT_LENGTH", 0))
)
except (ValueError, TypeError):
content_length = 0
@ -308,10 +320,7 @@ class Request(object):
"""
Return True if this requests supports parsing form data.
"""
form_media = (
'application/x-www-form-urlencoded',
'multipart/form-data'
)
form_media = ("application/x-www-form-urlencoded", "multipart/form-data")
return any([parser.media_type in form_media for parser in self.parsers])
def _parse(self):
@ -324,7 +333,7 @@ class Request(object):
try:
stream = self.stream
except RawPostDataException:
if not hasattr(self._request, '_post'):
if not hasattr(self._request, "_post"):
raise
# If request.POST has been accessed in middleware, and a method='POST'
# request was made with 'multipart/form-data', then the request stream
@ -335,7 +344,7 @@ class Request(object):
if stream is None or media_type is None:
if media_type and is_form_media_type(media_type):
empty_data = QueryDict('', encoding=self._request._encoding)
empty_data = QueryDict("", encoding=self._request._encoding)
else:
empty_data = {}
empty_files = MultiValueDict()
@ -353,7 +362,7 @@ class Request(object):
# re-raise. Ensures we don't simply repeat the error when
# attempting to render the browsable renderer response, or when
# logging the request or similar.
self._data = QueryDict('', encoding=self._request._encoding)
self._data = QueryDict("", encoding=self._request._encoding)
self._files = MultiValueDict()
self._full_data = self._data
raise
@ -416,33 +425,33 @@ class Request(object):
@property
def DATA(self):
raise NotImplementedError(
'`request.DATA` has been deprecated in favor of `request.data` '
'since version 3.0, and has been fully removed as of version 3.2.'
"`request.DATA` has been deprecated in favor of `request.data` "
"since version 3.0, and has been fully removed as of version 3.2."
)
@property
def POST(self):
# Ensure that request.POST uses our request parsing.
if not _hasattr(self, '_data'):
if not _hasattr(self, "_data"):
self._load_data_and_files()
if is_form_media_type(self.content_type):
return self._data
return QueryDict('', encoding=self._request._encoding)
return QueryDict("", encoding=self._request._encoding)
@property
def FILES(self):
# Leave this one alone for backwards compat with Django's request.FILES
# Different from the other two cases, which are not valid property
# names on the WSGIRequest class.
if not _hasattr(self, '_files'):
if not _hasattr(self, "_files"):
self._load_data_and_files()
return self._files
@property
def QUERY_PARAMS(self):
raise NotImplementedError(
'`request.QUERY_PARAMS` has been deprecated in favor of `request.query_params` '
'since version 3.0, and has been fully removed as of version 3.2.'
"`request.QUERY_PARAMS` has been deprecated in favor of `request.query_params` "
"since version 3.0, and has been fully removed as of version 3.2."
)
def force_plaintext_errors(self, value):

View File

@ -19,9 +19,15 @@ class Response(SimpleTemplateResponse):
arbitrary media types.
"""
def __init__(self, data=None, status=None,
template_name=None, headers=None,
exception=False, content_type=None):
def __init__(
self,
data=None,
status=None,
template_name=None,
headers=None,
exception=False,
content_type=None,
):
"""
Alters the init arguments slightly.
For example, drop 'template_name', and instead use 'data'.
@ -33,9 +39,9 @@ class Response(SimpleTemplateResponse):
if isinstance(data, Serializer):
msg = (
'You passed a Serializer instance as data, but '
'probably meant to pass serialized `.data` or '
'`.error`. representation.'
"You passed a Serializer instance as data, but "
"probably meant to pass serialized `.data` or "
"`.error`. representation."
)
raise AssertionError(msg)
@ -50,14 +56,14 @@ class Response(SimpleTemplateResponse):
@property
def rendered_content(self):
renderer = getattr(self, 'accepted_renderer', None)
accepted_media_type = getattr(self, 'accepted_media_type', None)
context = getattr(self, 'renderer_context', None)
renderer = getattr(self, "accepted_renderer", None)
accepted_media_type = getattr(self, "accepted_media_type", None)
context = getattr(self, "renderer_context", None)
assert renderer, ".accepted_renderer not set on Response"
assert accepted_media_type, ".accepted_media_type not set on Response"
assert context is not None, ".renderer_context not set on Response"
context['response'] = self
context["response"] = self
media_type = renderer.media_type
charset = renderer.charset
@ -67,18 +73,17 @@ class Response(SimpleTemplateResponse):
content_type = "{0}; charset={1}".format(media_type, charset)
elif content_type is None:
content_type = media_type
self['Content-Type'] = content_type
self["Content-Type"] = content_type
ret = renderer.render(self.data, accepted_media_type, context)
if isinstance(ret, six.text_type):
assert charset, (
'renderer returned unicode, and did not specify '
'a charset value.'
"renderer returned unicode, and did not specify " "a charset value."
)
return bytes(ret.encode(charset))
if not ret:
del self['Content-Type']
del self["Content-Type"]
return ret
@ -88,7 +93,7 @@ class Response(SimpleTemplateResponse):
Returns reason text corresponding to our HTTP response status code.
Provided for convenience.
"""
return responses.get(self.status_code, '')
return responses.get(self.status_code, "")
def __getstate__(self):
"""
@ -96,10 +101,15 @@ class Response(SimpleTemplateResponse):
"""
state = super(Response, self).__getstate__()
for key in (
'accepted_renderer', 'renderer_context', 'resolver_match',
'client', 'request', 'json', 'wsgi_request'
"accepted_renderer",
"renderer_context",
"resolver_match",
"client",
"request",
"json",
"wsgi_request",
):
if key in state:
del state[key]
state['_closable_objects'] = []
state["_closable_objects"] = []
return state

View File

@ -3,8 +3,7 @@ Provide urlresolver functions that return fully qualified URLs or view names
"""
from __future__ import unicode_literals
from django.urls import NoReverseMatch
from django.urls import reverse as django_reverse
from django.urls import NoReverseMatch, reverse as django_reverse
from django.utils import six
from django.utils.functional import lazy
@ -20,9 +19,7 @@ def preserve_builtin_query_params(url, request=None):
if request is None:
return url
overrides = [
api_settings.URL_FORMAT_OVERRIDE,
]
overrides = [api_settings.URL_FORMAT_OVERRIDE]
for param in overrides:
if param and (param in request.GET):
@ -38,7 +35,7 @@ def reverse(viewname, args=None, kwargs=None, request=None, format=None, **extra
to the versioning scheme instance, so that the resulting URL
can be modified if needed.
"""
scheme = getattr(request, 'versioning_scheme', None)
scheme = getattr(request, "versioning_scheme", None)
if scheme is not None:
try:
url = scheme.reverse(viewname, args, kwargs, request, format, **extra)
@ -59,7 +56,7 @@ def _reverse(viewname, args=None, kwargs=None, request=None, format=None, **extr
"""
if format is not None:
kwargs = kwargs or {}
kwargs['format'] = format
kwargs["format"] = format
url = django_reverse(viewname, args=args, kwargs=kwargs, **extra)
if request:
return request.build_absolute_uri(url)

View File

@ -25,9 +25,7 @@ from django.urls import NoReverseMatch
from django.utils import six
from django.utils.deprecation import RenameMethodsBase
from rest_framework import (
RemovedInDRF310Warning, RemovedInDRF311Warning, views
)
from rest_framework import RemovedInDRF310Warning, RemovedInDRF311Warning, views
from rest_framework.response import Response
from rest_framework.reverse import reverse
from rest_framework.schemas import SchemaGenerator
@ -35,8 +33,9 @@ from rest_framework.schemas.views import SchemaView
from rest_framework.settings import api_settings
from rest_framework.urlpatterns import format_suffix_patterns
Route = namedtuple('Route', ['url', 'mapping', 'name', 'detail', 'initkwargs'])
DynamicRoute = namedtuple('DynamicRoute', ['url', 'name', 'detail', 'initkwargs'])
Route = namedtuple("Route", ["url", "mapping", "name", "detail", "initkwargs"])
DynamicRoute = namedtuple("DynamicRoute", ["url", "name", "detail", "initkwargs"])
class DynamicDetailRoute(object):
@ -45,7 +44,8 @@ class DynamicDetailRoute(object):
"`DynamicDetailRoute` is deprecated and will be removed in 3.10 "
"in favor of `DynamicRoute`, which accepts a `detail` boolean. Use "
"`DynamicRoute(url, name, True, initkwargs)` instead.",
RemovedInDRF310Warning, stacklevel=2
RemovedInDRF310Warning,
stacklevel=2,
)
return DynamicRoute(url, name, True, initkwargs)
@ -56,7 +56,8 @@ class DynamicListRoute(object):
"`DynamicListRoute` is deprecated and will be removed in 3.10 in "
"favor of `DynamicRoute`, which accepts a `detail` boolean. Use "
"`DynamicRoute(url, name, False, initkwargs)` instead.",
RemovedInDRF310Warning, stacklevel=2
RemovedInDRF310Warning,
stacklevel=2,
)
return DynamicRoute(url, name, False, initkwargs)
@ -65,8 +66,8 @@ def escape_curly_brackets(url_path):
"""
Double brackets in regex of url_path for escape string formatting
"""
if ('{' and '}') in url_path:
url_path = url_path.replace('{', '{{').replace('}', '}}')
if ("{" and "}") in url_path:
url_path = url_path.replace("{", "{{").replace("}", "}}")
return url_path
@ -79,7 +80,7 @@ def flatten(list_of_lists):
class RenameRouterMethods(RenameMethodsBase):
renamed_methods = (
('get_default_base_name', 'get_default_basename', RemovedInDRF311Warning),
("get_default_base_name", "get_default_basename", RemovedInDRF311Warning),
)
@ -92,8 +93,9 @@ class BaseRouter(six.with_metaclass(RenameRouterMethods)):
msg = "The `base_name` argument is pending deprecation in favor of `basename`."
warnings.warn(msg, RemovedInDRF311Warning, 2)
assert not (basename and base_name), (
"Do not provide both the `basename` and `base_name` arguments.")
assert not (
basename and base_name
), "Do not provide both the `basename` and `base_name` arguments."
if basename is None:
basename = base_name
@ -103,7 +105,7 @@ class BaseRouter(six.with_metaclass(RenameRouterMethods)):
self.registry.append((prefix, viewset, basename))
# invalidate the urls cache
if hasattr(self, '_urls'):
if hasattr(self, "_urls"):
del self._urls
def get_default_basename(self, viewset):
@ -111,17 +113,17 @@ class BaseRouter(six.with_metaclass(RenameRouterMethods)):
If `basename` is not specified, attempt to automatically determine
it from the viewset.
"""
raise NotImplementedError('get_default_basename must be overridden')
raise NotImplementedError("get_default_basename must be overridden")
def get_urls(self):
"""
Return a list of URL patterns, given the registered viewsets.
"""
raise NotImplementedError('get_urls must be overridden')
raise NotImplementedError("get_urls must be overridden")
@property
def urls(self):
if not hasattr(self, '_urls'):
if not hasattr(self, "_urls"):
self._urls = self.get_urls()
return self._urls
@ -131,48 +133,45 @@ class SimpleRouter(BaseRouter):
routes = [
# List route.
Route(
url=r'^{prefix}{trailing_slash}$',
mapping={
'get': 'list',
'post': 'create'
},
name='{basename}-list',
url=r"^{prefix}{trailing_slash}$",
mapping={"get": "list", "post": "create"},
name="{basename}-list",
detail=False,
initkwargs={'suffix': 'List'}
initkwargs={"suffix": "List"},
),
# Dynamically generated list routes. Generated using
# @action(detail=False) decorator on methods of the viewset.
DynamicRoute(
url=r'^{prefix}/{url_path}{trailing_slash}$',
name='{basename}-{url_name}',
url=r"^{prefix}/{url_path}{trailing_slash}$",
name="{basename}-{url_name}",
detail=False,
initkwargs={}
initkwargs={},
),
# Detail route.
Route(
url=r'^{prefix}/{lookup}{trailing_slash}$',
url=r"^{prefix}/{lookup}{trailing_slash}$",
mapping={
'get': 'retrieve',
'put': 'update',
'patch': 'partial_update',
'delete': 'destroy'
"get": "retrieve",
"put": "update",
"patch": "partial_update",
"delete": "destroy",
},
name='{basename}-detail',
name="{basename}-detail",
detail=True,
initkwargs={'suffix': 'Instance'}
initkwargs={"suffix": "Instance"},
),
# Dynamically generated detail routes. Generated using
# @action(detail=True) decorator on methods of the viewset.
DynamicRoute(
url=r'^{prefix}/{lookup}/{url_path}{trailing_slash}$',
name='{basename}-{url_name}',
url=r"^{prefix}/{lookup}/{url_path}{trailing_slash}$",
name="{basename}-{url_name}",
detail=True,
initkwargs={}
initkwargs={},
),
]
def __init__(self, trailing_slash=True):
self.trailing_slash = '/' if trailing_slash else ''
self.trailing_slash = "/" if trailing_slash else ""
super(SimpleRouter, self).__init__()
def get_default_basename(self, viewset):
@ -180,11 +179,13 @@ class SimpleRouter(BaseRouter):
If `basename` is not specified, attempt to automatically determine
it from the viewset.
"""
queryset = getattr(viewset, 'queryset', None)
queryset = getattr(viewset, "queryset", None)
assert queryset is not None, '`basename` argument not specified, and could ' \
'not automatically determine the name from the viewset, as ' \
'it does not have a `.queryset` attribute.'
assert queryset is not None, (
"`basename` argument not specified, and could "
"not automatically determine the name from the viewset, as "
"it does not have a `.queryset` attribute."
)
return queryset.model._meta.object_name.lower()
@ -196,18 +197,29 @@ class SimpleRouter(BaseRouter):
"""
# converting to list as iterables are good for one pass, known host needs to be checked again and again for
# different functions.
known_actions = list(flatten([route.mapping.values() for route in self.routes if isinstance(route, Route)]))
known_actions = list(
flatten(
[
route.mapping.values()
for route in self.routes
if isinstance(route, Route)
]
)
)
extra_actions = viewset.get_extra_actions()
# checking action names against the known actions list
not_allowed = [
action.__name__ for action in extra_actions
action.__name__
for action in extra_actions
if action.__name__ in known_actions
]
if not_allowed:
msg = ('Cannot use the @action decorator on the following '
'methods, as they are existing routes: %s')
raise ImproperlyConfigured(msg % ', '.join(not_allowed))
msg = (
"Cannot use the @action decorator on the following "
"methods, as they are existing routes: %s"
)
raise ImproperlyConfigured(msg % ", ".join(not_allowed))
# partition detail and list actions
detail_actions = [action for action in extra_actions if action.detail]
@ -216,9 +228,13 @@ class SimpleRouter(BaseRouter):
routes = []
for route in self.routes:
if isinstance(route, DynamicRoute) and route.detail:
routes += [self._get_dynamic_route(route, action) for action in detail_actions]
routes += [
self._get_dynamic_route(route, action) for action in detail_actions
]
elif isinstance(route, DynamicRoute) and not route.detail:
routes += [self._get_dynamic_route(route, action) for action in list_actions]
routes += [
self._get_dynamic_route(route, action) for action in list_actions
]
else:
routes.append(route)
@ -231,9 +247,9 @@ class SimpleRouter(BaseRouter):
url_path = escape_curly_brackets(action.url_path)
return Route(
url=route.url.replace('{url_path}', url_path),
url=route.url.replace("{url_path}", url_path),
mapping=action.mapping,
name=route.name.replace('{url_name}', action.url_name),
name=route.name.replace("{url_name}", action.url_name),
detail=route.detail,
initkwargs=initkwargs,
)
@ -250,7 +266,7 @@ class SimpleRouter(BaseRouter):
bound_methods[method] = action
return bound_methods
def get_lookup_regex(self, viewset, lookup_prefix=''):
def get_lookup_regex(self, viewset, lookup_prefix=""):
"""
Given a viewset, return the portion of URL regex that is used
to match against a single instance.
@ -261,16 +277,16 @@ class SimpleRouter(BaseRouter):
https://github.com/alanjds/drf-nested-routers
"""
base_regex = '(?P<{lookup_prefix}{lookup_url_kwarg}>{lookup_value})'
base_regex = "(?P<{lookup_prefix}{lookup_url_kwarg}>{lookup_value})"
# Use `pk` as default field, unset set. Default regex should not
# consume `.json` style suffixes and should break at '/' boundaries.
lookup_field = getattr(viewset, 'lookup_field', 'pk')
lookup_url_kwarg = getattr(viewset, 'lookup_url_kwarg', None) or lookup_field
lookup_value = getattr(viewset, 'lookup_value_regex', '[^/.]+')
lookup_field = getattr(viewset, "lookup_field", "pk")
lookup_url_kwarg = getattr(viewset, "lookup_url_kwarg", None) or lookup_field
lookup_value = getattr(viewset, "lookup_value_regex", "[^/.]+")
return base_regex.format(
lookup_prefix=lookup_prefix,
lookup_url_kwarg=lookup_url_kwarg,
lookup_value=lookup_value
lookup_value=lookup_value,
)
def get_urls(self):
@ -292,23 +308,18 @@ class SimpleRouter(BaseRouter):
# Build the url pattern
regex = route.url.format(
prefix=prefix,
lookup=lookup,
trailing_slash=self.trailing_slash
prefix=prefix, lookup=lookup, trailing_slash=self.trailing_slash
)
# If there is no prefix, the first part of the url is probably
# controlled by project's urls.py and the router is in an app,
# so a slash in the beginning will (A) cause Django to give
# warnings and (B) generate URLS that will require using '//'.
if not prefix and regex[:2] == '^/':
regex = '^' + regex[2:]
if not prefix and regex[:2] == "^/":
regex = "^" + regex[2:]
initkwargs = route.initkwargs.copy()
initkwargs.update({
'basename': basename,
'detail': route.detail,
})
initkwargs.update({"basename": basename, "detail": route.detail})
view = viewset.as_view(mapping, **initkwargs)
name = route.name.format(basename=basename)
@ -321,6 +332,7 @@ class APIRootView(views.APIView):
"""
The default basic root view for DefaultRouter
"""
_ignore_model_permissions = True
schema = None # exclude from schema
api_root_dict = None
@ -331,14 +343,14 @@ class APIRootView(views.APIView):
namespace = request.resolver_match.namespace
for key, url_name in self.api_root_dict.items():
if namespace:
url_name = namespace + ':' + url_name
url_name = namespace + ":" + url_name
try:
ret[key] = reverse(
url_name,
args=args,
kwargs=kwargs,
request=request,
format=kwargs.get('format', None)
format=kwargs.get("format", None),
)
except NoReverseMatch:
# Don't bail out if eg. no list routes exist, only detail routes.
@ -352,17 +364,18 @@ class DefaultRouter(SimpleRouter):
The default router extends the SimpleRouter, but also adds in a default
API root view, and adds format suffix patterns to the URLs.
"""
include_root_view = True
include_format_suffixes = True
root_view_name = 'api-root'
root_view_name = "api-root"
default_schema_renderers = None
APIRootView = APIRootView
APISchemaView = SchemaView
SchemaGenerator = SchemaGenerator
def __init__(self, *args, **kwargs):
if 'root_renderers' in kwargs:
self.root_renderers = kwargs.pop('root_renderers')
if "root_renderers" in kwargs:
self.root_renderers = kwargs.pop("root_renderers")
else:
self.root_renderers = list(api_settings.DEFAULT_RENDERER_CLASSES)
super(DefaultRouter, self).__init__(*args, **kwargs)
@ -387,7 +400,7 @@ class DefaultRouter(SimpleRouter):
if self.include_root_view:
view = self.get_api_root_view(api_urls=urls)
root_url = url(r'^$', view, name=self.root_view_name)
root_url = url(r"^$", view, name=self.root_view_name)
urls.append(root_url)
if self.include_format_suffixes:

View File

@ -27,18 +27,29 @@ from .inspectors import AutoSchema, DefaultSchema, ManualSchema # noqa
def get_schema_view(
title=None, url=None, description=None, urlconf=None, renderer_classes=None,
public=False, patterns=None, generator_class=SchemaGenerator,
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES):
title=None,
url=None,
description=None,
urlconf=None,
renderer_classes=None,
public=False,
patterns=None,
generator_class=SchemaGenerator,
authentication_classes=api_settings.DEFAULT_AUTHENTICATION_CLASSES,
permission_classes=api_settings.DEFAULT_PERMISSION_CLASSES,
):
"""
Return a schema view.
"""
# Avoid import cycle on APIView
from .views import SchemaView
generator = generator_class(
title=title, url=url, description=description,
urlconf=urlconf, patterns=patterns,
title=title,
url=url,
description=description,
urlconf=urlconf,
patterns=patterns,
)
return SchemaView.as_view(
renderer_classes=renderer_classes,

View File

@ -15,7 +15,11 @@ from django.utils import six
from rest_framework import exceptions
from rest_framework.compat import (
URLPattern, URLResolver, coreapi, coreschema, get_original_route
URLPattern,
URLResolver,
coreapi,
coreschema,
get_original_route,
)
from rest_framework.request import clone_request
from rest_framework.settings import api_settings
@ -25,7 +29,7 @@ from .utils import is_list_view
def common_path(paths):
split_paths = [path.strip('/').split('/') for path in paths]
split_paths = [path.strip("/").split("/") for path in paths]
s1 = min(split_paths)
s2 = max(split_paths)
common = s1
@ -33,7 +37,7 @@ def common_path(paths):
if c != s2[i]:
common = s1[:i]
break
return '/' + '/'.join(common)
return "/" + "/".join(common)
def get_pk_name(model):
@ -47,7 +51,8 @@ def is_api_view(callback):
"""
# Avoid import cycle on APIView
from rest_framework.views import APIView
cls = getattr(callback, 'cls', None)
cls = getattr(callback, "cls", None)
return (cls is not None) and issubclass(cls, APIView)
@ -78,7 +83,7 @@ class LinkNode(OrderedDict):
current_val = self.methods_counter[preferred_key]
self.methods_counter[preferred_key] += 1
key = '{}_{}'.format(preferred_key, current_val)
key = "{}_{}".format(preferred_key, current_val)
if key not in self:
return key
@ -101,9 +106,7 @@ def insert_into(target, keys, value):
target.links.append((keys[-1], value))
except TypeError:
msg = INSERT_INTO_COLLISION_FMT.format(
value_url=value.url,
target_url=target.url,
keys=keys
value_url=value.url, target_url=target.url, keys=keys
)
raise ValueError(msg)
@ -119,24 +122,25 @@ def distribute_links(obj):
def is_custom_action(action):
return action not in {
'retrieve', 'list', 'create', 'update', 'partial_update', 'destroy'
"retrieve",
"list",
"create",
"update",
"partial_update",
"destroy",
}
def endpoint_ordering(endpoint):
path, method, callback = endpoint
method_priority = {
'GET': 0,
'POST': 1,
'PUT': 2,
'PATCH': 3,
'DELETE': 4
}.get(method, 5)
method_priority = {"GET": 0, "POST": 1, "PUT": 2, "PATCH": 3, "DELETE": 4}.get(
method, 5
)
return (path, method_priority)
_PATH_PARAMETER_COMPONENT_RE = re.compile(
r'<(?:(?P<converter>[^>:]+):)?(?P<parameter>\w+)>'
r"<(?:(?P<converter>[^>:]+):)?(?P<parameter>\w+)>"
)
@ -144,6 +148,7 @@ class EndpointEnumerator(object):
"""
A class to determine the available API endpoints that a project exposes.
"""
def __init__(self, patterns=None, urlconf=None):
if patterns is None:
if urlconf is None:
@ -159,7 +164,7 @@ class EndpointEnumerator(object):
self.patterns = patterns
def get_api_endpoints(self, patterns=None, prefix=''):
def get_api_endpoints(self, patterns=None, prefix=""):
"""
Return a list of all available API endpoints by inspecting the URL conf.
"""
@ -180,8 +185,7 @@ class EndpointEnumerator(object):
elif isinstance(pattern, URLResolver):
nested_endpoints = self.get_api_endpoints(
patterns=pattern.url_patterns,
prefix=path_regex
patterns=pattern.url_patterns, prefix=path_regex
)
api_endpoints.extend(nested_endpoints)
@ -196,7 +200,7 @@ class EndpointEnumerator(object):
path = simplify_regex(path_regex)
# Strip Django 2.0 convertors as they are incompatible with uritemplate format
path = re.sub(_PATH_PARAMETER_COMPONENT_RE, r'{\g<parameter>}', path)
path = re.sub(_PATH_PARAMETER_COMPONENT_RE, r"{\g<parameter>}", path)
return path
def should_include_endpoint(self, path, callback):
@ -209,11 +213,11 @@ class EndpointEnumerator(object):
if callback.cls.schema is None:
return False
if 'schema' in callback.initkwargs:
if callback.initkwargs['schema'] is None:
if "schema" in callback.initkwargs:
if callback.initkwargs["schema"] is None:
return False
if path.endswith('.{format}') or path.endswith('.{format}/'):
if path.endswith(".{format}") or path.endswith(".{format}/"):
return False # Ignore .json style URLs.
return True
@ -222,24 +226,24 @@ class EndpointEnumerator(object):
"""
Return a list of the valid HTTP methods for this endpoint.
"""
if hasattr(callback, 'actions'):
if hasattr(callback, "actions"):
actions = set(callback.actions)
http_method_names = set(callback.cls.http_method_names)
methods = [method.upper() for method in actions & http_method_names]
else:
methods = callback.cls().allowed_methods
return [method for method in methods if method not in ('OPTIONS', 'HEAD')]
return [method for method in methods if method not in ("OPTIONS", "HEAD")]
class SchemaGenerator(object):
# Map HTTP methods onto actions.
default_mapping = {
'get': 'retrieve',
'post': 'create',
'put': 'update',
'patch': 'partial_update',
'delete': 'destroy',
"get": "retrieve",
"post": "create",
"put": "update",
"patch": "partial_update",
"delete": "destroy",
}
endpoint_inspector_cls = EndpointEnumerator
@ -253,12 +257,14 @@ class SchemaGenerator(object):
# Set by 'SCHEMA_COERCE_PATH_PK'.
coerce_path_pk = None
def __init__(self, title=None, url=None, description=None, patterns=None, urlconf=None):
assert coreapi, '`coreapi` must be installed for schema support.'
assert coreschema, '`coreschema` must be installed for schema support.'
def __init__(
self, title=None, url=None, description=None, patterns=None, urlconf=None
):
assert coreapi, "`coreapi` must be installed for schema support."
assert coreschema, "`coreschema` must be installed for schema support."
if url and not url.endswith('/'):
url += '/'
if url and not url.endswith("/"):
url += "/"
self.coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
self.coerce_path_pk = api_settings.SCHEMA_COERCE_PATH_PK
@ -288,8 +294,7 @@ class SchemaGenerator(object):
distribute_links(links)
return coreapi.Document(
title=self.title, description=self.description,
url=url, content=links
title=self.title, description=self.description, url=url, content=links
)
def get_links(self, request=None):
@ -317,7 +322,7 @@ class SchemaGenerator(object):
if not self.has_view_permissions(path, method, view):
continue
link = view.schema.get_link(path, method, base_url=self.url)
subpath = path[len(prefix):]
subpath = path[len(prefix) :]
keys = self.get_keys(subpath, method, view)
insert_into(links, keys, link)
@ -342,35 +347,35 @@ class SchemaGenerator(object):
"""
prefixes = []
for path in paths:
components = path.strip('/').split('/')
components = path.strip("/").split("/")
initial_components = []
for component in components:
if '{' in component:
if "{" in component:
break
initial_components.append(component)
prefix = '/'.join(initial_components[:-1])
prefix = "/".join(initial_components[:-1])
if not prefix:
# We can just break early in the case that there's at least
# one URL that doesn't have a path prefix.
return '/'
prefixes.append('/' + prefix + '/')
return "/"
prefixes.append("/" + prefix + "/")
return common_path(prefixes)
def create_view(self, callback, method, request=None):
"""
Given a callback, return an actual view instance.
"""
view = callback.cls(**getattr(callback, 'initkwargs', {}))
view = callback.cls(**getattr(callback, "initkwargs", {}))
view.args = ()
view.kwargs = {}
view.format_kwarg = None
view.request = None
view.action_map = getattr(callback, 'actions', None)
view.action_map = getattr(callback, "actions", None)
actions = getattr(callback, 'actions', None)
actions = getattr(callback, "actions", None)
if actions is not None:
if method == 'OPTIONS':
view.action = 'metadata'
if method == "OPTIONS":
view.action = "metadata"
else:
view.action = actions.get(method.lower())
@ -398,14 +403,14 @@ class SchemaGenerator(object):
where possible. This is cleaner for an external representation.
(Ie. "this is an identifier", not "this is a database primary key")
"""
if not self.coerce_path_pk or '{pk}' not in path:
if not self.coerce_path_pk or "{pk}" not in path:
return path
model = getattr(getattr(view, 'queryset', None), 'model', None)
model = getattr(getattr(view, "queryset", None), "model", None)
if model:
field_name = get_pk_name(model)
else:
field_name = 'id'
return path.replace('{pk}', '{%s}' % field_name)
field_name = "id"
return path.replace("{pk}", "{%s}" % field_name)
# Method for generating the link layout....
@ -421,20 +426,20 @@ class SchemaGenerator(object):
/users/{pk}/groups/ ("users", "groups", "list"), ("users", "groups", "create")
/users/{pk}/groups/{pk}/ ("users", "groups", "read"), ("users", "groups", "update"), ("users", "groups", "delete")
"""
if hasattr(view, 'action'):
if hasattr(view, "action"):
# Viewsets have explicitly named actions.
action = view.action
else:
# Views have no associated action, so we determine one from the method.
if is_list_view(subpath, method, view):
action = 'list'
action = "list"
else:
action = self.default_mapping[method.lower()]
named_path_components = [
component for component
in subpath.strip('/').split('/')
if '{' not in component
component
for component in subpath.strip("/").split("/")
if "{" not in component
]
if is_custom_action(action):

View File

@ -21,46 +21,38 @@ from rest_framework.utils import formatting
from .utils import is_list_view
header_regex = re.compile('^[a-zA-Z][0-9A-Za-z_]*:')
header_regex = re.compile("^[a-zA-Z][0-9A-Za-z_]*:")
def field_to_schema(field):
title = force_text(field.label) if field.label else ''
description = force_text(field.help_text) if field.help_text else ''
title = force_text(field.label) if field.label else ""
description = force_text(field.help_text) if field.help_text else ""
if isinstance(field, (serializers.ListSerializer, serializers.ListField)):
child_schema = field_to_schema(field.child)
return coreschema.Array(
items=child_schema,
title=title,
description=description
items=child_schema, title=title, description=description
)
elif isinstance(field, serializers.DictField):
return coreschema.Object(
title=title,
description=description
)
return coreschema.Object(title=title, description=description)
elif isinstance(field, serializers.Serializer):
return coreschema.Object(
properties=OrderedDict([
(key, field_to_schema(value))
for key, value
in field.fields.items()
]),
properties=OrderedDict(
[(key, field_to_schema(value)) for key, value in field.fields.items()]
),
title=title,
description=description
description=description,
)
elif isinstance(field, serializers.ManyRelatedField):
related_field_schema = field_to_schema(field.child_relation)
return coreschema.Array(
items=related_field_schema,
title=title,
description=description
items=related_field_schema, title=title, description=description
)
elif isinstance(field, serializers.PrimaryKeyRelatedField):
schema_cls = coreschema.String
model = getattr(field.queryset, 'model', None)
model = getattr(field.queryset, "model", None)
if model is not None:
model_field = model._meta.pk
if isinstance(model_field, models.AutoField):
@ -72,13 +64,11 @@ def field_to_schema(field):
return coreschema.Array(
items=coreschema.Enum(enum=list(field.choices)),
title=title,
description=description
description=description,
)
elif isinstance(field, serializers.ChoiceField):
return coreschema.Enum(
enum=list(field.choices),
title=title,
description=description
enum=list(field.choices), title=title, description=description
)
elif isinstance(field, serializers.BooleanField):
return coreschema.Boolean(title=title, description=description)
@ -87,25 +77,17 @@ def field_to_schema(field):
elif isinstance(field, serializers.IntegerField):
return coreschema.Integer(title=title, description=description)
elif isinstance(field, serializers.DateField):
return coreschema.String(
title=title,
description=description,
format='date'
)
return coreschema.String(title=title, description=description, format="date")
elif isinstance(field, serializers.DateTimeField):
return coreschema.String(
title=title,
description=description,
format='date-time'
title=title, description=description, format="date-time"
)
elif isinstance(field, serializers.JSONField):
return coreschema.Object(title=title, description=description)
if field.style.get('base_template') == 'textarea.html':
if field.style.get("base_template") == "textarea.html":
return coreschema.String(
title=title,
description=description,
format='textarea'
title=title, description=description, format="textarea"
)
return coreschema.String(title=title, description=description)
@ -113,15 +95,14 @@ def field_to_schema(field):
def get_pk_description(model, model_field):
if isinstance(model_field, models.AutoField):
value_type = _('unique integer value')
value_type = _("unique integer value")
elif isinstance(model_field, models.UUIDField):
value_type = _('UUID string')
value_type = _("UUID string")
else:
value_type = _('unique value')
value_type = _("unique value")
return _('A {value_type} identifying this {name}.').format(
value_type=value_type,
name=model._meta.verbose_name,
return _("A {value_type} identifying this {name}.").format(
value_type=value_type, name=model._meta.verbose_name
)
@ -200,6 +181,7 @@ class AutoSchema(ViewInspector):
Responsible for per-view introspection and schema generation.
"""
def __init__(self, manual_fields=None):
"""
Parameters:
@ -221,14 +203,14 @@ class AutoSchema(ViewInspector):
manual_fields = self.get_manual_fields(path, method)
fields = self.update_fields(fields, manual_fields)
if fields and any([field.location in ('form', 'body') for field in fields]):
if fields and any([field.location in ("form", "body") for field in fields]):
encoding = self.get_encoding(path, method)
else:
encoding = None
description = self.get_description(path, method)
if base_url and path.startswith('/'):
if base_url and path.startswith("/"):
path = path[1:]
return coreapi.Link(
@ -236,7 +218,7 @@ class AutoSchema(ViewInspector):
action=method.lower(),
encoding=encoding,
fields=fields,
description=description
description=description,
)
def get_description(self, path, method):
@ -248,25 +230,31 @@ class AutoSchema(ViewInspector):
"""
view = self.view
method_name = getattr(view, 'action', method.lower())
method_name = getattr(view, "action", method.lower())
method_docstring = getattr(view, method_name, None).__doc__
if method_docstring:
# An explicit docstring on the method or action.
return self._get_description_section(view, method.lower(), formatting.dedent(smart_text(method_docstring)))
return self._get_description_section(
view, method.lower(), formatting.dedent(smart_text(method_docstring))
)
else:
return self._get_description_section(view, getattr(view, 'action', method.lower()), view.get_view_description())
return self._get_description_section(
view,
getattr(view, "action", method.lower()),
view.get_view_description(),
)
def _get_description_section(self, view, header, description):
lines = [line for line in description.splitlines()]
current_section = ''
sections = {'': ''}
current_section = ""
sections = {"": ""}
for line in lines:
if header_regex.match(line):
current_section, seperator, lead = line.partition(':')
current_section, seperator, lead = line.partition(":")
sections[current_section] = lead.strip()
else:
sections[current_section] += '\n' + line
sections[current_section] += "\n" + line
# TODO: SCHEMA_COERCE_METHOD_NAMES appears here and in `SchemaGenerator.get_keys`
coerce_method_names = api_settings.SCHEMA_COERCE_METHOD_NAMES
@ -275,7 +263,7 @@ class AutoSchema(ViewInspector):
if header in coerce_method_names:
if coerce_method_names[header] in sections:
return sections[coerce_method_names[header]].strip()
return sections[''].strip()
return sections[""].strip()
def get_path_fields(self, path, method):
"""
@ -283,12 +271,12 @@ class AutoSchema(ViewInspector):
templated path variables.
"""
view = self.view
model = getattr(getattr(view, 'queryset', None), 'model', None)
model = getattr(getattr(view, "queryset", None), "model", None)
fields = []
for variable in uritemplate.variables(path):
title = ''
description = ''
title = ""
description = ""
schema_cls = coreschema.String
kwargs = {}
if model is not None:
@ -306,16 +294,19 @@ class AutoSchema(ViewInspector):
elif model_field is not None and model_field.primary_key:
description = get_pk_description(model, model_field)
if hasattr(view, 'lookup_value_regex') and view.lookup_field == variable:
kwargs['pattern'] = view.lookup_value_regex
if (
hasattr(view, "lookup_value_regex")
and view.lookup_field == variable
):
kwargs["pattern"] = view.lookup_value_regex
elif isinstance(model_field, models.AutoField):
schema_cls = coreschema.Integer
field = coreapi.Field(
name=variable,
location='path',
location="path",
required=True,
schema=schema_cls(title=title, description=description, **kwargs)
schema=schema_cls(title=title, description=description, **kwargs),
)
fields.append(field)
@ -328,28 +319,29 @@ class AutoSchema(ViewInspector):
"""
view = self.view
if method not in ('PUT', 'PATCH', 'POST'):
if method not in ("PUT", "PATCH", "POST"):
return []
if not hasattr(view, 'get_serializer'):
if not hasattr(view, "get_serializer"):
return []
try:
serializer = view.get_serializer()
except exceptions.APIException:
serializer = None
warnings.warn('{}.get_serializer() raised an exception during '
'schema generation. Serializer fields will not be '
'generated for {} {}.'
.format(view.__class__.__name__, method, path))
warnings.warn(
"{}.get_serializer() raised an exception during "
"schema generation. Serializer fields will not be "
"generated for {} {}.".format(view.__class__.__name__, method, path)
)
if isinstance(serializer, serializers.ListSerializer):
return [
coreapi.Field(
name='data',
location='body',
name="data",
location="body",
required=True,
schema=coreschema.Array()
schema=coreschema.Array(),
)
]
@ -361,12 +353,12 @@ class AutoSchema(ViewInspector):
if field.read_only or isinstance(field, serializers.HiddenField):
continue
required = field.required and method != 'PATCH'
required = field.required and method != "PATCH"
field = coreapi.Field(
name=field.field_name,
location='form',
location="form",
required=required,
schema=field_to_schema(field)
schema=field_to_schema(field),
)
fields.append(field)
@ -378,7 +370,7 @@ class AutoSchema(ViewInspector):
if not is_list_view(path, method, view):
return []
pagination = getattr(view, 'pagination_class', None)
pagination = getattr(view, "pagination_class", None)
if not pagination:
return []
@ -397,11 +389,17 @@ class AutoSchema(ViewInspector):
Note: Introduced in v3.7: Initially "private" (i.e. with leading underscore)
to allow changes based on user experience.
"""
if getattr(self.view, 'filter_backends', None) is None:
if getattr(self.view, "filter_backends", None) is None:
return False
if hasattr(self.view, 'action'):
return self.view.action in ["list", "retrieve", "update", "partial_update", "destroy"]
if hasattr(self.view, "action"):
return self.view.action in [
"list",
"retrieve",
"update",
"partial_update",
"destroy",
]
return method.lower() in ["get", "put", "patch", "delete"]
@ -447,18 +445,18 @@ class AutoSchema(ViewInspector):
# Core API supports the following request encodings over HTTP...
supported_media_types = {
'application/json',
'application/x-www-form-urlencoded',
'multipart/form-data',
"application/json",
"application/x-www-form-urlencoded",
"multipart/form-data",
}
parser_classes = getattr(view, 'parser_classes', [])
parser_classes = getattr(view, "parser_classes", [])
for parser_class in parser_classes:
media_type = getattr(parser_class, 'media_type', None)
media_type = getattr(parser_class, "media_type", None)
if media_type in supported_media_types:
return media_type
# Raw binary uploads are supported with "application/octet-stream"
if media_type == '*/*':
return 'application/octet-stream'
if media_type == "*/*":
return "application/octet-stream"
return None
@ -468,7 +466,8 @@ class ManualSchema(ViewInspector):
Allows providing a list of coreapi.Fields,
plus an optional description.
"""
def __init__(self, fields, description='', encoding=None):
def __init__(self, fields, description="", encoding=None):
"""
Parameters:
@ -476,14 +475,16 @@ class ManualSchema(ViewInspector):
* `description`: String description for view. Optional.
"""
super(ManualSchema, self).__init__()
assert all(isinstance(f, coreapi.Field) for f in fields), "`fields` must be a list of coreapi.Field instances"
assert all(
isinstance(f, coreapi.Field) for f in fields
), "`fields` must be a list of coreapi.Field instances"
self._fields = fields
self._description = description
self._encoding = encoding
def get_link(self, path, method, base_url):
if base_url and path.startswith('/'):
if base_url and path.startswith("/"):
path = path[1:]
return coreapi.Link(
@ -491,21 +492,22 @@ class ManualSchema(ViewInspector):
action=method.lower(),
encoding=self._encoding,
fields=self._fields,
description=self._description
description=self._description,
)
class DefaultSchema(ViewInspector):
"""Allows overriding AutoSchema using DEFAULT_SCHEMA_CLASS setting"""
def __get__(self, instance, owner):
result = super(DefaultSchema, self).__get__(instance, owner)
if not isinstance(result, DefaultSchema):
return result
inspector_class = api_settings.DEFAULT_SCHEMA_CLASS
assert issubclass(inspector_class, ViewInspector), (
"DEFAULT_SCHEMA_CLASS must be set to a ViewInspector (usually an AutoSchema) subclass"
)
assert issubclass(
inspector_class, ViewInspector
), "DEFAULT_SCHEMA_CLASS must be set to a ViewInspector (usually an AutoSchema) subclass"
inspector = inspector_class()
inspector.view = instance
return inspector

View File

@ -10,15 +10,15 @@ def is_list_view(path, method, view):
"""
Return True if the given path/method appears to represent a list view.
"""
if hasattr(view, 'action'):
if hasattr(view, "action"):
# Viewsets have an explicitly defined action, which we can inspect.
return view.action == 'list'
return view.action == "list"
if method.lower() != 'get':
if method.lower() != "get":
return False
if isinstance(view, RetrieveModelMixin):
return False
path_components = path.strip('/').split('/')
if path_components and '{' in path_components[-1]:
path_components = path.strip("/").split("/")
if path_components and "{" in path_components[-1]:
return False
return True

View File

@ -21,7 +21,7 @@ class SchemaView(APIView):
if self.renderer_classes is None:
self.renderer_classes = [
renderers.OpenAPIRenderer,
renderers.CoreJSONRenderer
renderers.CoreJSONRenderer,
]
if renderers.BrowsableAPIRenderer in api_settings.DEFAULT_RENDERER_CLASSES:
self.renderer_classes += [renderers.BrowsableAPIRenderer]

File diff suppressed because it is too large Load Diff

View File

@ -28,135 +28,109 @@ from django.utils import six
from rest_framework import ISO_8601
DEFAULTS = {
# Base API policies
'DEFAULT_RENDERER_CLASSES': (
'rest_framework.renderers.JSONRenderer',
'rest_framework.renderers.BrowsableAPIRenderer',
"DEFAULT_RENDERER_CLASSES": (
"rest_framework.renderers.JSONRenderer",
"rest_framework.renderers.BrowsableAPIRenderer",
),
'DEFAULT_PARSER_CLASSES': (
'rest_framework.parsers.JSONParser',
'rest_framework.parsers.FormParser',
'rest_framework.parsers.MultiPartParser'
"DEFAULT_PARSER_CLASSES": (
"rest_framework.parsers.JSONParser",
"rest_framework.parsers.FormParser",
"rest_framework.parsers.MultiPartParser",
),
'DEFAULT_AUTHENTICATION_CLASSES': (
'rest_framework.authentication.SessionAuthentication',
'rest_framework.authentication.BasicAuthentication'
"DEFAULT_AUTHENTICATION_CLASSES": (
"rest_framework.authentication.SessionAuthentication",
"rest_framework.authentication.BasicAuthentication",
),
'DEFAULT_PERMISSION_CLASSES': (
'rest_framework.permissions.AllowAny',
),
'DEFAULT_THROTTLE_CLASSES': (),
'DEFAULT_CONTENT_NEGOTIATION_CLASS': 'rest_framework.negotiation.DefaultContentNegotiation',
'DEFAULT_METADATA_CLASS': 'rest_framework.metadata.SimpleMetadata',
'DEFAULT_VERSIONING_CLASS': None,
"DEFAULT_PERMISSION_CLASSES": ("rest_framework.permissions.AllowAny",),
"DEFAULT_THROTTLE_CLASSES": (),
"DEFAULT_CONTENT_NEGOTIATION_CLASS": "rest_framework.negotiation.DefaultContentNegotiation",
"DEFAULT_METADATA_CLASS": "rest_framework.metadata.SimpleMetadata",
"DEFAULT_VERSIONING_CLASS": None,
# Generic view behavior
'DEFAULT_PAGINATION_CLASS': None,
'DEFAULT_FILTER_BACKENDS': (),
"DEFAULT_PAGINATION_CLASS": None,
"DEFAULT_FILTER_BACKENDS": (),
# Schema
'DEFAULT_SCHEMA_CLASS': 'rest_framework.schemas.AutoSchema',
"DEFAULT_SCHEMA_CLASS": "rest_framework.schemas.AutoSchema",
# Throttling
'DEFAULT_THROTTLE_RATES': {
'user': None,
'anon': None,
},
'NUM_PROXIES': None,
"DEFAULT_THROTTLE_RATES": {"user": None, "anon": None},
"NUM_PROXIES": None,
# Pagination
'PAGE_SIZE': None,
"PAGE_SIZE": None,
# Filtering
'SEARCH_PARAM': 'search',
'ORDERING_PARAM': 'ordering',
"SEARCH_PARAM": "search",
"ORDERING_PARAM": "ordering",
# Versioning
'DEFAULT_VERSION': None,
'ALLOWED_VERSIONS': None,
'VERSION_PARAM': 'version',
"DEFAULT_VERSION": None,
"ALLOWED_VERSIONS": None,
"VERSION_PARAM": "version",
# Authentication
'UNAUTHENTICATED_USER': 'django.contrib.auth.models.AnonymousUser',
'UNAUTHENTICATED_TOKEN': None,
"UNAUTHENTICATED_USER": "django.contrib.auth.models.AnonymousUser",
"UNAUTHENTICATED_TOKEN": None,
# View configuration
'VIEW_NAME_FUNCTION': 'rest_framework.views.get_view_name',
'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.views.get_view_description',
"VIEW_NAME_FUNCTION": "rest_framework.views.get_view_name",
"VIEW_DESCRIPTION_FUNCTION": "rest_framework.views.get_view_description",
# Exception handling
'EXCEPTION_HANDLER': 'rest_framework.views.exception_handler',
'NON_FIELD_ERRORS_KEY': 'non_field_errors',
"EXCEPTION_HANDLER": "rest_framework.views.exception_handler",
"NON_FIELD_ERRORS_KEY": "non_field_errors",
# Testing
'TEST_REQUEST_RENDERER_CLASSES': (
'rest_framework.renderers.MultiPartRenderer',
'rest_framework.renderers.JSONRenderer'
"TEST_REQUEST_RENDERER_CLASSES": (
"rest_framework.renderers.MultiPartRenderer",
"rest_framework.renderers.JSONRenderer",
),
'TEST_REQUEST_DEFAULT_FORMAT': 'multipart',
"TEST_REQUEST_DEFAULT_FORMAT": "multipart",
# Hyperlink settings
'URL_FORMAT_OVERRIDE': 'format',
'FORMAT_SUFFIX_KWARG': 'format',
'URL_FIELD_NAME': 'url',
"URL_FORMAT_OVERRIDE": "format",
"FORMAT_SUFFIX_KWARG": "format",
"URL_FIELD_NAME": "url",
# Input and output formats
'DATE_FORMAT': ISO_8601,
'DATE_INPUT_FORMATS': (ISO_8601,),
'DATETIME_FORMAT': ISO_8601,
'DATETIME_INPUT_FORMATS': (ISO_8601,),
'TIME_FORMAT': ISO_8601,
'TIME_INPUT_FORMATS': (ISO_8601,),
"DATE_FORMAT": ISO_8601,
"DATE_INPUT_FORMATS": (ISO_8601,),
"DATETIME_FORMAT": ISO_8601,
"DATETIME_INPUT_FORMATS": (ISO_8601,),
"TIME_FORMAT": ISO_8601,
"TIME_INPUT_FORMATS": (ISO_8601,),
# Encoding
'UNICODE_JSON': True,
'COMPACT_JSON': True,
'STRICT_JSON': True,
'COERCE_DECIMAL_TO_STRING': True,
'UPLOADED_FILES_USE_URL': True,
"UNICODE_JSON": True,
"COMPACT_JSON": True,
"STRICT_JSON": True,
"COERCE_DECIMAL_TO_STRING": True,
"UPLOADED_FILES_USE_URL": True,
# Browseable API
'HTML_SELECT_CUTOFF': 1000,
'HTML_SELECT_CUTOFF_TEXT': "More than {count} items...",
"HTML_SELECT_CUTOFF": 1000,
"HTML_SELECT_CUTOFF_TEXT": "More than {count} items...",
# Schemas
'SCHEMA_COERCE_PATH_PK': True,
'SCHEMA_COERCE_METHOD_NAMES': {
'retrieve': 'read',
'destroy': 'delete'
},
"SCHEMA_COERCE_PATH_PK": True,
"SCHEMA_COERCE_METHOD_NAMES": {"retrieve": "read", "destroy": "delete"},
}
# List of settings that may be in string import notation.
IMPORT_STRINGS = (
'DEFAULT_RENDERER_CLASSES',
'DEFAULT_PARSER_CLASSES',
'DEFAULT_AUTHENTICATION_CLASSES',
'DEFAULT_PERMISSION_CLASSES',
'DEFAULT_THROTTLE_CLASSES',
'DEFAULT_CONTENT_NEGOTIATION_CLASS',
'DEFAULT_METADATA_CLASS',
'DEFAULT_VERSIONING_CLASS',
'DEFAULT_PAGINATION_CLASS',
'DEFAULT_FILTER_BACKENDS',
'DEFAULT_SCHEMA_CLASS',
'EXCEPTION_HANDLER',
'TEST_REQUEST_RENDERER_CLASSES',
'UNAUTHENTICATED_USER',
'UNAUTHENTICATED_TOKEN',
'VIEW_NAME_FUNCTION',
'VIEW_DESCRIPTION_FUNCTION'
"DEFAULT_RENDERER_CLASSES",
"DEFAULT_PARSER_CLASSES",
"DEFAULT_AUTHENTICATION_CLASSES",
"DEFAULT_PERMISSION_CLASSES",
"DEFAULT_THROTTLE_CLASSES",
"DEFAULT_CONTENT_NEGOTIATION_CLASS",
"DEFAULT_METADATA_CLASS",
"DEFAULT_VERSIONING_CLASS",
"DEFAULT_PAGINATION_CLASS",
"DEFAULT_FILTER_BACKENDS",
"DEFAULT_SCHEMA_CLASS",
"EXCEPTION_HANDLER",
"TEST_REQUEST_RENDERER_CLASSES",
"UNAUTHENTICATED_USER",
"UNAUTHENTICATED_TOKEN",
"VIEW_NAME_FUNCTION",
"VIEW_DESCRIPTION_FUNCTION",
)
# List of settings that have been removed
REMOVED_SETTINGS = (
"PAGINATE_BY", "PAGINATE_BY_PARAM", "MAX_PAGINATE_BY",
)
REMOVED_SETTINGS = ("PAGINATE_BY", "PAGINATE_BY_PARAM", "MAX_PAGINATE_BY")
def perform_import(val, setting_name):
@ -179,11 +153,16 @@ def import_from_string(val, setting_name):
"""
try:
# Nod to tastypie's use of importlib.
module_path, class_name = val.rsplit('.', 1)
module_path, class_name = val.rsplit(".", 1)
module = import_module(module_path)
return getattr(module, class_name)
except (ImportError, AttributeError) as e:
msg = "Could not import '%s' for API setting '%s'. %s: %s." % (val, setting_name, e.__class__.__name__, e)
msg = "Could not import '%s' for API setting '%s'. %s: %s." % (
val,
setting_name,
e.__class__.__name__,
e,
)
raise ImportError(msg)
@ -198,6 +177,7 @@ class APISettings(object):
Any setting with string import paths will be automatically resolved
and return the class, rather than the string literal.
"""
def __init__(self, user_settings=None, defaults=None, import_strings=None):
if user_settings:
self._user_settings = self.__check_user_settings(user_settings)
@ -207,8 +187,8 @@ class APISettings(object):
@property
def user_settings(self):
if not hasattr(self, '_user_settings'):
self._user_settings = getattr(settings, 'REST_FRAMEWORK', {})
if not hasattr(self, "_user_settings"):
self._user_settings = getattr(settings, "REST_FRAMEWORK", {})
return self._user_settings
def __getattr__(self, attr):
@ -235,23 +215,26 @@ class APISettings(object):
SETTINGS_DOC = "https://www.django-rest-framework.org/api-guide/settings/"
for setting in REMOVED_SETTINGS:
if setting in user_settings:
raise RuntimeError("The '%s' setting has been removed. Please refer to '%s' for available settings." % (setting, SETTINGS_DOC))
raise RuntimeError(
"The '%s' setting has been removed. Please refer to '%s' for available settings."
% (setting, SETTINGS_DOC)
)
return user_settings
def reload(self):
for attr in self._cached_attrs:
delattr(self, attr)
self._cached_attrs.clear()
if hasattr(self, '_user_settings'):
delattr(self, '_user_settings')
if hasattr(self, "_user_settings"):
delattr(self, "_user_settings")
api_settings = APISettings(None, DEFAULTS, IMPORT_STRINGS)
def reload_api_settings(*args, **kwargs):
setting = kwargs['setting']
if setting == 'REST_FRAMEWORK':
setting = kwargs["setting"]
if setting == "REST_FRAMEWORK":
api_settings.reload()

View File

@ -15,22 +15,23 @@ from rest_framework.compat import apply_markdown, pygments_highlight
from rest_framework.renderers import HTMLFormRenderer
from rest_framework.utils.urls import replace_query_param
register = template.Library()
# Regex for adding classes to html snippets
class_re = re.compile(r'(?<=class=["\'])(.*)(?=["\'])')
@register.tag(name='code')
@register.tag(name="code")
def highlight_code(parser, token):
code = token.split_contents()[-1]
nodelist = parser.parse(('endcode',))
nodelist = parser.parse(("endcode",))
parser.delete_first_token()
return CodeNode(code, nodelist)
class CodeNode(template.Node):
style = 'emacs'
style = "emacs"
def __init__(self, lang, code):
self.lang = lang
@ -43,24 +44,17 @@ class CodeNode(template.Node):
@register.filter()
def with_location(fields, location):
return [
field for field in fields
if field.location == location
]
return [field for field in fields if field.location == location]
@register.simple_tag
def form_for_link(link):
import coreschema
properties = OrderedDict([
(field.name, field.schema or coreschema.String())
for field in link.fields
])
required = [
field.name
for field in link.fields
if field.required
]
properties = OrderedDict(
[(field.name, field.schema or coreschema.String()) for field in link.fields]
)
required = [field.name for field in link.fields if field.required]
schema = coreschema.Object(properties=properties, required=required)
return mark_safe(coreschema.render_to_form(schema))
@ -79,14 +73,14 @@ def get_pagination_html(pager):
@register.simple_tag
def render_form(serializer, template_pack=None):
style = {'template_pack': template_pack} if template_pack else {}
style = {"template_pack": template_pack} if template_pack else {}
renderer = HTMLFormRenderer()
return renderer.render(serializer.data, None, {'style': style})
return renderer.render(serializer.data, None, {"style": style})
@register.simple_tag
def render_field(field, style):
renderer = style.get('renderer', HTMLFormRenderer())
renderer = style.get("renderer", HTMLFormRenderer())
return renderer.render_field(field, style)
@ -96,9 +90,9 @@ def optional_login(request):
Include a login snippet if REST framework's login view is in the URLconf.
"""
try:
login_url = reverse('rest_framework:login')
login_url = reverse("rest_framework:login")
except NoReverseMatch:
return ''
return ""
snippet = "<li><a href='{href}?next={next}'>Log in</a></li>"
snippet = format_html(snippet, href=login_url, next=escape(request.path))
@ -112,9 +106,9 @@ def optional_docs_login(request):
Include a login snippet if REST framework's login view is in the URLconf.
"""
try:
login_url = reverse('rest_framework:login')
login_url = reverse("rest_framework:login")
except NoReverseMatch:
return 'log in'
return "log in"
snippet = "<a href='{href}?next={next}'>log in</a>"
snippet = format_html(snippet, href=login_url, next=escape(request.path))
@ -128,7 +122,7 @@ def optional_logout(request, user):
Include a logout snippet if REST framework's logout view is in the URLconf.
"""
try:
logout_url = reverse('rest_framework:logout')
logout_url = reverse("rest_framework:logout")
except NoReverseMatch:
snippet = format_html('<li class="navbar-text">{user}</li>', user=escape(user))
return mark_safe(snippet)
@ -142,7 +136,9 @@ def optional_logout(request, user):
<li><a href='{href}?next={next}'>Log out</a></li>
</ul>
</li>"""
snippet = format_html(snippet, user=escape(user), href=logout_url, next=escape(request.path))
snippet = format_html(
snippet, user=escape(user), href=logout_url, next=escape(request.path)
)
return mark_safe(snippet)
@ -160,16 +156,13 @@ def add_query_param(request, key, val):
@register.filter
def as_string(value):
if value is None:
return ''
return '%s' % value
return ""
return "%s" % value
@register.filter
def as_list_of_strings(value):
return [
'' if (item is None) else ('%s' % item)
for item in value
]
return ["" if (item is None) else ("%s" % item) for item in value]
@register.filter
@ -190,45 +183,52 @@ def add_class(value, css_class):
html = six.text_type(value)
match = class_re.search(html)
if match:
m = re.search(r'^%s$|^%s\s|\s%s\s|\s%s$' % (css_class, css_class,
css_class, css_class),
match.group(1))
m = re.search(
r"^%s$|^%s\s|\s%s\s|\s%s$" % (css_class, css_class, css_class, css_class),
match.group(1),
)
if not m:
return mark_safe(class_re.sub(match.group(1) + " " + css_class,
html))
return mark_safe(class_re.sub(match.group(1) + " " + css_class, html))
else:
return mark_safe(html.replace('>', ' class="%s">' % css_class, 1))
return mark_safe(html.replace(">", ' class="%s">' % css_class, 1))
return value
@register.filter
def format_value(value):
if getattr(value, 'is_hyperlink', False):
if getattr(value, "is_hyperlink", False):
name = six.text_type(value.obj)
return mark_safe('<a href=%s>%s</a>' % (value, escape(name)))
return mark_safe("<a href=%s>%s</a>" % (value, escape(name)))
if value is None or isinstance(value, bool):
return mark_safe('<code>%s</code>' % {True: 'true', False: 'false', None: 'null'}[value])
return mark_safe(
"<code>%s</code>" % {True: "true", False: "false", None: "null"}[value]
)
elif isinstance(value, list):
if any([isinstance(item, (list, dict)) for item in value]):
template = loader.get_template('rest_framework/admin/list_value.html')
template = loader.get_template("rest_framework/admin/list_value.html")
else:
template = loader.get_template('rest_framework/admin/simple_list_value.html')
context = {'value': value}
template = loader.get_template(
"rest_framework/admin/simple_list_value.html"
)
context = {"value": value}
return template.render(context)
elif isinstance(value, dict):
template = loader.get_template('rest_framework/admin/dict_value.html')
context = {'value': value}
template = loader.get_template("rest_framework/admin/dict_value.html")
context = {"value": value}
return template.render(context)
elif isinstance(value, six.string_types):
if (
(value.startswith('http:') or value.startswith('https:')) and not
re.search(r'\s', value)
if (value.startswith("http:") or value.startswith("https:")) and not re.search(
r"\s", value
):
return mark_safe('<a href="{value}">{value}</a>'.format(value=escape(value)))
elif '@' in value and not re.search(r'\s', value):
return mark_safe('<a href="mailto:{value}">{value}</a>'.format(value=escape(value)))
elif '\n' in value:
return mark_safe('<pre>%s</pre>' % escape(value))
return mark_safe(
'<a href="{value}">{value}</a>'.format(value=escape(value))
)
elif "@" in value and not re.search(r"\s", value):
return mark_safe(
'<a href="mailto:{value}">{value}</a>'.format(value=escape(value))
)
elif "\n" in value:
return mark_safe("<pre>%s</pre>" % escape(value))
return six.text_type(value)
@ -266,7 +266,7 @@ def schema_links(section, sec_key=None):
"""
Recursively find every link in a schema, even nested.
"""
NESTED_FORMAT = '%s > %s' # this format is used in docs/js/api.js:normalizeKeys
NESTED_FORMAT = "%s > %s" # this format is used in docs/js/api.js:normalizeKeys
links = section.links
if section.data:
data = section.data.items()
@ -287,20 +287,30 @@ def schema_links(section, sec_key=None):
@register.filter
def add_nested_class(value):
if isinstance(value, dict):
return 'class=nested'
if isinstance(value, list) and any([isinstance(item, (list, dict)) for item in value]):
return 'class=nested'
return ''
return "class=nested"
if isinstance(value, list) and any(
[isinstance(item, (list, dict)) for item in value]
):
return "class=nested"
return ""
# Bunch of stuff cloned from urlize
TRAILING_PUNCTUATION = ['.', ',', ':', ';', '.)', '"', "']", "'}", "'"]
WRAPPING_PUNCTUATION = [('(', ')'), ('<', '>'), ('[', ']'), ('&lt;', '&gt;'),
('"', '"'), ("'", "'")]
word_split_re = re.compile(r'(\s+)')
simple_url_re = re.compile(r'^https?://\[?\w', re.IGNORECASE)
simple_url_2_re = re.compile(r'^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net|org)$', re.IGNORECASE)
simple_email_re = re.compile(r'^\S+@\S+\.\S+$')
TRAILING_PUNCTUATION = [".", ",", ":", ";", ".)", '"', "']", "'}", "'"]
WRAPPING_PUNCTUATION = [
("(", ")"),
("<", ">"),
("[", "]"),
("&lt;", "&gt;"),
('"', '"'),
("'", "'"),
]
word_split_re = re.compile(r"(\s+)")
simple_url_re = re.compile(r"^https?://\[?\w", re.IGNORECASE)
simple_url_2_re = re.compile(
r"^www\.|^(?!http)\w[^@]+\.(com|edu|gov|int|mil|net|org)$", re.IGNORECASE
)
simple_email_re = re.compile(r"^\S+@\S+\.\S+$")
def smart_urlquote_wrapper(matched_url):
@ -332,8 +342,13 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
If autoescape is True, the link text and URLs will get autoescaped.
"""
def trim_url(x, limit=trim_url_limit):
return limit is not None and (len(x) > limit and ('%s...' % x[:max(0, limit - 3)])) or x
return (
limit is not None
and (len(x) > limit and ("%s..." % x[: max(0, limit - 3)]))
or x
)
safe_input = isinstance(text, SafeData)
@ -344,40 +359,40 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
words = word_split_re.split(force_text(text))
for i, word in enumerate(words):
if '.' in word or '@' in word or ':' in word:
if "." in word or "@" in word or ":" in word:
# Deal with punctuation.
lead, middle, trail = '', word, ''
lead, middle, trail = "", word, ""
for punctuation in TRAILING_PUNCTUATION:
if middle.endswith(punctuation):
middle = middle[:-len(punctuation)]
middle = middle[: -len(punctuation)]
trail = punctuation + trail
for opening, closing in WRAPPING_PUNCTUATION:
if middle.startswith(opening):
middle = middle[len(opening):]
middle = middle[len(opening) :]
lead = lead + opening
# Keep parentheses at the end only if they're balanced.
if (
middle.endswith(closing) and
middle.count(closing) == middle.count(opening) + 1
middle.endswith(closing)
and middle.count(closing) == middle.count(opening) + 1
):
middle = middle[:-len(closing)]
middle = middle[: -len(closing)]
trail = closing + trail
# Make URL we want to point to.
url = None
nofollow_attr = ' rel="nofollow"' if nofollow else ''
nofollow_attr = ' rel="nofollow"' if nofollow else ""
if simple_url_re.match(middle):
url = smart_urlquote_wrapper(middle)
elif simple_url_2_re.match(middle):
url = smart_urlquote_wrapper('http://%s' % middle)
elif ':' not in middle and simple_email_re.match(middle):
local, domain = middle.rsplit('@', 1)
url = smart_urlquote_wrapper("http://%s" % middle)
elif ":" not in middle and simple_email_re.match(middle):
local, domain = middle.rsplit("@", 1)
try:
domain = domain.encode('idna').decode('ascii')
domain = domain.encode("idna").decode("ascii")
except UnicodeError:
continue
url = 'mailto:%s@%s' % (local, domain)
nofollow_attr = ''
url = "mailto:%s@%s" % (local, domain)
nofollow_attr = ""
# Make link.
if url:
@ -385,12 +400,12 @@ def urlize_quoted_links(text, trim_url_limit=None, nofollow=True, autoescape=Tru
lead, trail = conditional_escape(lead), conditional_escape(trail)
url, trimmed = conditional_escape(url), conditional_escape(trimmed)
middle = '<a href="%s"%s>%s</a>' % (url, nofollow_attr, trimmed)
words[i] = '%s%s%s' % (lead, middle, trail)
words[i] = "%s%s%s" % (lead, middle, trail)
else:
words[i] = conditional_escape(word)
else:
words[i] = conditional_escape(word)
return mark_safe(''.join(words))
return mark_safe("".join(words))
@register.filter
@ -399,6 +414,6 @@ def break_long_headers(header):
Breaks headers longer than 160 characters (~page length)
when possible (are comma separated)
"""
if len(header) > 160 and ',' in header:
header = mark_safe('<br> ' + ', <br>'.join(header.split(',')))
if len(header) > 160 and "," in header:
header = mark_safe("<br> " + ", <br>".join(header.split(",")))
return header

View File

@ -11,9 +11,11 @@ from django.conf import settings
from django.core.exceptions import ImproperlyConfigured
from django.core.handlers.wsgi import WSGIHandler
from django.test import override_settings, testcases
from django.test.client import Client as DjangoClient
from django.test.client import ClientHandler
from django.test.client import RequestFactory as DjangoRequestFactory
from django.test.client import (
Client as DjangoClient,
ClientHandler,
RequestFactory as DjangoRequestFactory,
)
from django.utils import six
from django.utils.encoding import force_bytes
from django.utils.http import urlencode
@ -28,6 +30,7 @@ def force_authenticate(request, user=None, token=None):
if requests is not None:
class HeaderDict(requests.packages.urllib3._collections.HTTPHeaderDict):
def get_all(self, key, default):
return self.getheaders(key)
@ -48,6 +51,7 @@ if requests is not None:
A transport adapter for `requests`, that makes requests via the
Django WSGI app, rather than making actual HTTP requests over the network.
"""
def __init__(self):
self.app = WSGIHandler()
self.factory = DjangoRequestFactory()
@ -62,19 +66,19 @@ if requests is not None:
# Set request content, if any exists.
if request.body is not None:
if hasattr(request.body, 'read'):
kwargs['data'] = request.body.read()
if hasattr(request.body, "read"):
kwargs["data"] = request.body.read()
else:
kwargs['data'] = request.body
if 'content-type' in request.headers:
kwargs['content_type'] = request.headers['content-type']
kwargs["data"] = request.body
if "content-type" in request.headers:
kwargs["content_type"] = request.headers["content-type"]
# Set request headers.
for key, value in request.headers.items():
key = key.upper()
if key in ('CONNECTION', 'CONTENT-LENGTH', 'CONTENT-TYPE'):
if key in ("CONNECTION", "CONTENT-LENGTH", "CONTENT-TYPE"):
continue
kwargs['HTTP_%s' % key.replace('-', '_')] = value
kwargs["HTTP_%s" % key.replace("-", "_")] = value
return self.factory.generic(method, url, **kwargs).environ
@ -85,20 +89,20 @@ if requests is not None:
raw_kwargs = {}
def start_response(wsgi_status, wsgi_headers):
status, _, reason = wsgi_status.partition(' ')
raw_kwargs['status'] = int(status)
raw_kwargs['reason'] = reason
raw_kwargs['headers'] = wsgi_headers
raw_kwargs['version'] = 11
raw_kwargs['preload_content'] = False
raw_kwargs['original_response'] = MockOriginalResponse(wsgi_headers)
status, _, reason = wsgi_status.partition(" ")
raw_kwargs["status"] = int(status)
raw_kwargs["reason"] = reason
raw_kwargs["headers"] = wsgi_headers
raw_kwargs["version"] = 11
raw_kwargs["preload_content"] = False
raw_kwargs["original_response"] = MockOriginalResponse(wsgi_headers)
# Make the outgoing request via WSGI.
environ = self.get_environ(request)
wsgi_response = self.app(environ, start_response)
# Build the underlying urllib3.HTTPResponse
raw_kwargs['body'] = io.BytesIO(b''.join(wsgi_response))
raw_kwargs["body"] = io.BytesIO(b"".join(wsgi_response))
raw = requests.packages.urllib3.HTTPResponse(**raw_kwargs)
# Build the requests.Response
@ -111,33 +115,47 @@ if requests is not None:
def __init__(self, *args, **kwargs):
super(RequestsClient, self).__init__(*args, **kwargs)
adapter = DjangoTestAdapter()
self.mount('http://', adapter)
self.mount('https://', adapter)
self.mount("http://", adapter)
self.mount("https://", adapter)
def request(self, method, url, *args, **kwargs):
if not url.startswith('http'):
raise ValueError('Missing "http:" or "https:". Use a fully qualified URL, eg "http://testserver%s"' % url)
if not url.startswith("http"):
raise ValueError(
'Missing "http:" or "https:". Use a fully qualified URL, eg "http://testserver%s"'
% url
)
return super(RequestsClient, self).request(method, url, *args, **kwargs)
else:
def RequestsClient(*args, **kwargs):
raise ImproperlyConfigured('requests must be installed in order to use RequestsClient.')
raise ImproperlyConfigured(
"requests must be installed in order to use RequestsClient."
)
if coreapi is not None:
class CoreAPIClient(coreapi.Client):
def __init__(self, *args, **kwargs):
self._session = RequestsClient()
kwargs['transports'] = [coreapi.transports.HTTPTransport(session=self.session)]
kwargs["transports"] = [
coreapi.transports.HTTPTransport(session=self.session)
]
return super(CoreAPIClient, self).__init__(*args, **kwargs)
@property
def session(self):
return self._session
else:
def CoreAPIClient(*args, **kwargs):
raise ImproperlyConfigured('coreapi must be installed in order to use CoreAPIClient.')
raise ImproperlyConfigured(
"coreapi must be installed in order to use CoreAPIClient."
)
class APIRequestFactory(DjangoRequestFactory):
@ -157,11 +175,11 @@ class APIRequestFactory(DjangoRequestFactory):
"""
if data is None:
return ('', content_type)
return ("", content_type)
assert format is None or content_type is None, (
'You may not set both `format` and `content_type`.'
)
assert (
format is None or content_type is None
), "You may not set both `format` and `content_type`."
if content_type:
# Content type specified explicitly, treat data as a raw bytestring
@ -175,7 +193,7 @@ class APIRequestFactory(DjangoRequestFactory):
"Set TEST_REQUEST_RENDERER_CLASSES to enable "
"extra request formats.".format(
format,
', '.join(["'" + fmt + "'" for fmt in self.renderer_classes])
", ".join(["'" + fmt + "'" for fmt in self.renderer_classes]),
)
)
@ -195,47 +213,53 @@ class APIRequestFactory(DjangoRequestFactory):
return ret, content_type
def get(self, path, data=None, **extra):
r = {
'QUERY_STRING': urlencode(data or {}, doseq=True),
}
if not data and '?' in path:
r = {"QUERY_STRING": urlencode(data or {}, doseq=True)}
if not data and "?" in path:
# Fix to support old behavior where you have the arguments in the
# url. See #1461.
query_string = force_bytes(path.split('?')[1])
query_string = force_bytes(path.split("?")[1])
if six.PY3:
query_string = query_string.decode('iso-8859-1')
r['QUERY_STRING'] = query_string
query_string = query_string.decode("iso-8859-1")
r["QUERY_STRING"] = query_string
r.update(extra)
return self.generic('GET', path, **r)
return self.generic("GET", path, **r)
def post(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type)
return self.generic('POST', path, data, content_type, **extra)
return self.generic("POST", path, data, content_type, **extra)
def put(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type)
return self.generic('PUT', path, data, content_type, **extra)
return self.generic("PUT", path, data, content_type, **extra)
def patch(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type)
return self.generic('PATCH', path, data, content_type, **extra)
return self.generic("PATCH", path, data, content_type, **extra)
def delete(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type)
return self.generic('DELETE', path, data, content_type, **extra)
return self.generic("DELETE", path, data, content_type, **extra)
def options(self, path, data=None, format=None, content_type=None, **extra):
data, content_type = self._encode_data(data, format, content_type)
return self.generic('OPTIONS', path, data, content_type, **extra)
return self.generic("OPTIONS", path, data, content_type, **extra)
def generic(self, method, path, data='',
content_type='application/octet-stream', secure=False, **extra):
def generic(
self,
method,
path,
data="",
content_type="application/octet-stream",
secure=False,
**extra
):
# Include the CONTENT_TYPE, regardless of whether or not data is empty.
if content_type is not None:
extra['CONTENT_TYPE'] = str(content_type)
extra["CONTENT_TYPE"] = str(content_type)
return super(APIRequestFactory, self).generic(
method, path, data, content_type, secure, **extra)
method, path, data, content_type, secure, **extra
)
def request(self, **kwargs):
request = super(APIRequestFactory, self).request(**kwargs)
@ -294,42 +318,52 @@ class APIClient(APIRequestFactory, DjangoClient):
response = self._handle_redirects(response, **extra)
return response
def post(self, path, data=None, format=None, content_type=None,
follow=False, **extra):
def post(
self, path, data=None, format=None, content_type=None, follow=False, **extra
):
response = super(APIClient, self).post(
path, data=data, format=format, content_type=content_type, **extra)
path, data=data, format=format, content_type=content_type, **extra
)
if follow:
response = self._handle_redirects(response, **extra)
return response
def put(self, path, data=None, format=None, content_type=None,
follow=False, **extra):
def put(
self, path, data=None, format=None, content_type=None, follow=False, **extra
):
response = super(APIClient, self).put(
path, data=data, format=format, content_type=content_type, **extra)
path, data=data, format=format, content_type=content_type, **extra
)
if follow:
response = self._handle_redirects(response, **extra)
return response
def patch(self, path, data=None, format=None, content_type=None,
follow=False, **extra):
def patch(
self, path, data=None, format=None, content_type=None, follow=False, **extra
):
response = super(APIClient, self).patch(
path, data=data, format=format, content_type=content_type, **extra)
path, data=data, format=format, content_type=content_type, **extra
)
if follow:
response = self._handle_redirects(response, **extra)
return response
def delete(self, path, data=None, format=None, content_type=None,
follow=False, **extra):
def delete(
self, path, data=None, format=None, content_type=None, follow=False, **extra
):
response = super(APIClient, self).delete(
path, data=data, format=format, content_type=content_type, **extra)
path, data=data, format=format, content_type=content_type, **extra
)
if follow:
response = self._handle_redirects(response, **extra)
return response
def options(self, path, data=None, format=None, content_type=None,
follow=False, **extra):
def options(
self, path, data=None, format=None, content_type=None, follow=False, **extra
):
response = super(APIClient, self).options(
path, data=data, format=format, content_type=content_type, **extra)
path, data=data, format=format, content_type=content_type, **extra
)
if follow:
response = self._handle_redirects(response, **extra)
return response
@ -377,13 +411,14 @@ class URLPatternsTestCase(testcases.SimpleTestCase):
def test_something_else(self):
...
"""
@classmethod
def setUpClass(cls):
# Get the module of the TestCase subclass
cls._module = import_module(cls.__module__)
cls._override = override_settings(ROOT_URLCONF=cls.__module__)
if hasattr(cls._module, 'urlpatterns'):
if hasattr(cls._module, "urlpatterns"):
cls._module_urlpatterns = cls._module.urlpatterns
cls._module.urlpatterns = cls.urlpatterns
@ -396,7 +431,7 @@ class URLPatternsTestCase(testcases.SimpleTestCase):
super(URLPatternsTestCase, cls).tearDownClass()
cls._override.disable()
if hasattr(cls, '_module_urlpatterns'):
if hasattr(cls, "_module_urlpatterns"):
cls._module.urlpatterns = cls._module_urlpatterns
else:
del cls._module.urlpatterns

View File

@ -20,7 +20,7 @@ class BaseThrottle(object):
"""
Return `True` if the request should be allowed, `False` otherwise.
"""
raise NotImplementedError('.allow_request() must be overridden')
raise NotImplementedError(".allow_request() must be overridden")
def get_ident(self, request):
"""
@ -28,18 +28,18 @@ class BaseThrottle(object):
if present and number of proxies is > 0. If not use all of
HTTP_X_FORWARDED_FOR if it is available, if not use REMOTE_ADDR.
"""
xff = request.META.get('HTTP_X_FORWARDED_FOR')
remote_addr = request.META.get('REMOTE_ADDR')
xff = request.META.get("HTTP_X_FORWARDED_FOR")
remote_addr = request.META.get("REMOTE_ADDR")
num_proxies = api_settings.NUM_PROXIES
if num_proxies is not None:
if num_proxies == 0 or xff is None:
return remote_addr
addrs = xff.split(',')
addrs = xff.split(",")
client_addr = addrs[-min(num_proxies, len(addrs))]
return client_addr.strip()
return ''.join(xff.split()) if xff else remote_addr
return "".join(xff.split()) if xff else remote_addr
def wait(self):
"""
@ -61,14 +61,15 @@ class SimpleRateThrottle(BaseThrottle):
Previous request information used for throttling is stored in the cache.
"""
cache = default_cache
timer = time.time
cache_format = 'throttle_%(scope)s_%(ident)s'
cache_format = "throttle_%(scope)s_%(ident)s"
scope = None
THROTTLE_RATES = api_settings.DEFAULT_THROTTLE_RATES
def __init__(self):
if not getattr(self, 'rate', None):
if not getattr(self, "rate", None):
self.rate = self.get_rate()
self.num_requests, self.duration = self.parse_rate(self.rate)
@ -79,15 +80,17 @@ class SimpleRateThrottle(BaseThrottle):
May return `None` if the request should not be throttled.
"""
raise NotImplementedError('.get_cache_key() must be overridden')
raise NotImplementedError(".get_cache_key() must be overridden")
def get_rate(self):
"""
Determine the string representation of the allowed request rate.
"""
if not getattr(self, 'scope', None):
msg = ("You must set either `.scope` or `.rate` for '%s' throttle" %
self.__class__.__name__)
if not getattr(self, "scope", None):
msg = (
"You must set either `.scope` or `.rate` for '%s' throttle"
% self.__class__.__name__
)
raise ImproperlyConfigured(msg)
try:
@ -103,9 +106,9 @@ class SimpleRateThrottle(BaseThrottle):
"""
if rate is None:
return (None, None)
num, period = rate.split('/')
num, period = rate.split("/")
num_requests = int(num)
duration = {'s': 1, 'm': 60, 'h': 3600, 'd': 86400}[period[0]]
duration = {"s": 1, "m": 60, "h": 3600, "d": 86400}[period[0]]
return (num_requests, duration)
def allow_request(self, request, view):
@ -170,15 +173,16 @@ class AnonRateThrottle(SimpleRateThrottle):
The IP address of the request will be used as the unique cache key.
"""
scope = 'anon'
scope = "anon"
def get_cache_key(self, request, view):
if request.user.is_authenticated:
return None # Only throttle unauthenticated requests.
return self.cache_format % {
'scope': self.scope,
'ident': self.get_ident(request)
"scope": self.scope,
"ident": self.get_ident(request),
}
@ -190,7 +194,8 @@ class UserRateThrottle(SimpleRateThrottle):
authenticated. For anonymous requests, the IP address of the request will
be used.
"""
scope = 'user'
scope = "user"
def get_cache_key(self, request, view):
if request.user.is_authenticated:
@ -198,10 +203,7 @@ class UserRateThrottle(SimpleRateThrottle):
else:
ident = self.get_ident(request)
return self.cache_format % {
'scope': self.scope,
'ident': ident
}
return self.cache_format % {"scope": self.scope, "ident": ident}
class ScopedRateThrottle(SimpleRateThrottle):
@ -211,7 +213,8 @@ class ScopedRateThrottle(SimpleRateThrottle):
throttled. The unique cache key will be generated by concatenating the
user id of the request, and the scope of the view being accessed.
"""
scope_attr = 'throttle_scope'
scope_attr = "throttle_scope"
def __init__(self):
# Override the usual SimpleRateThrottle, because we can't determine
@ -246,7 +249,4 @@ class ScopedRateThrottle(SimpleRateThrottle):
else:
ident = self.get_ident(request)
return self.cache_format % {
'scope': self.scope,
'ident': ident
}
return self.cache_format % {"scope": self.scope, "ident": ident}

View File

@ -3,7 +3,11 @@ from __future__ import unicode_literals
from django.conf.urls import include, url
from rest_framework.compat import (
URLResolver, get_regex_pattern, is_route_pattern, path, register_converter
URLResolver,
get_regex_pattern,
is_route_pattern,
path,
register_converter,
)
from rest_framework.settings import api_settings
@ -13,7 +17,7 @@ def _get_format_path_converter(suffix_kwarg, allowed):
if len(allowed) == 1:
allowed_pattern = allowed[0]
else:
allowed_pattern = '(?:%s)' % '|'.join(allowed)
allowed_pattern = "(?:%s)" % "|".join(allowed)
suffix_pattern = r"\.%s/?" % allowed_pattern
else:
suffix_pattern = r"\.[a-z0-9]+/?"
@ -22,19 +26,21 @@ def _get_format_path_converter(suffix_kwarg, allowed):
regex = suffix_pattern
def to_python(self, value):
return value.strip('./')
return value.strip("./")
def to_url(self, value):
return '.' + value + '/'
return "." + value + "/"
converter_name = 'drf_format_suffix'
converter_name = "drf_format_suffix"
if allowed:
converter_name += '_' + '_'.join(allowed)
converter_name += "_" + "_".join(allowed)
return converter_name, FormatSuffixConverter
def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required, suffix_route=None):
def apply_suffix_patterns(
urlpatterns, suffix_pattern, suffix_required, suffix_route=None
):
ret = []
for urlpattern in urlpatterns:
if isinstance(urlpattern, URLResolver):
@ -44,23 +50,28 @@ def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required, suffix_r
app_name = urlpattern.app_name
kwargs = urlpattern.default_kwargs
# Add in the included patterns, after applying the suffixes
patterns = apply_suffix_patterns(urlpattern.url_patterns,
suffix_pattern,
suffix_required,
suffix_route)
patterns = apply_suffix_patterns(
urlpattern.url_patterns, suffix_pattern, suffix_required, suffix_route
)
# if the original pattern was a RoutePattern we need to preserve it
if is_route_pattern(urlpattern):
assert path is not None
route = str(urlpattern.pattern)
new_pattern = path(route, include((patterns, app_name), namespace), kwargs)
new_pattern = path(
route, include((patterns, app_name), namespace), kwargs
)
else:
new_pattern = url(regex, include((patterns, app_name), namespace), kwargs)
new_pattern = url(
regex, include((patterns, app_name), namespace), kwargs
)
ret.append(new_pattern)
else:
# Regular URL pattern
regex = get_regex_pattern(urlpattern).rstrip('$').rstrip('/') + suffix_pattern
regex = (
get_regex_pattern(urlpattern).rstrip("$").rstrip("/") + suffix_pattern
)
view = urlpattern.callback
kwargs = urlpattern.default_args
name = urlpattern.name
@ -72,7 +83,7 @@ def apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required, suffix_r
if is_route_pattern(urlpattern):
assert path is not None
assert suffix_route is not None
route = str(urlpattern.pattern).rstrip('$').rstrip('/') + suffix_route
route = str(urlpattern.pattern).rstrip("$").rstrip("/") + suffix_route
new_pattern = path(route, view, kwargs, name)
else:
new_pattern = url(regex, view, kwargs, name)
@ -103,17 +114,21 @@ def format_suffix_patterns(urlpatterns, suffix_required=False, allowed=None):
if len(allowed) == 1:
allowed_pattern = allowed[0]
else:
allowed_pattern = '(%s)' % '|'.join(allowed)
suffix_pattern = r'\.(?P<%s>%s)/?$' % (suffix_kwarg, allowed_pattern)
allowed_pattern = "(%s)" % "|".join(allowed)
suffix_pattern = r"\.(?P<%s>%s)/?$" % (suffix_kwarg, allowed_pattern)
else:
suffix_pattern = r'\.(?P<%s>[a-z0-9]+)/?$' % suffix_kwarg
suffix_pattern = r"\.(?P<%s>[a-z0-9]+)/?$" % suffix_kwarg
if path and register_converter:
converter_name, suffix_converter = _get_format_path_converter(suffix_kwarg, allowed)
converter_name, suffix_converter = _get_format_path_converter(
suffix_kwarg, allowed
)
register_converter(suffix_converter, converter_name)
suffix_route = '<%s:%s>' % (converter_name, suffix_kwarg)
suffix_route = "<%s:%s>" % (converter_name, suffix_kwarg)
else:
suffix_route = None
return apply_suffix_patterns(urlpatterns, suffix_pattern, suffix_required, suffix_route)
return apply_suffix_patterns(
urlpatterns, suffix_pattern, suffix_required, suffix_route
)

View File

@ -16,8 +16,13 @@ from __future__ import unicode_literals
from django.conf.urls import url
from django.contrib.auth import views
app_name = 'rest_framework'
app_name = "rest_framework"
urlpatterns = [
url(r'^login/$', views.LoginView.as_view(template_name='rest_framework/login.html'), name='login'),
url(r'^logout/$', views.LogoutView.as_view(), name='logout'),
url(
r"^login/$",
views.LoginView.as_view(template_name="rest_framework/login.html"),
name="login",
),
url(r"^logout/$", views.LogoutView.as_view(), name="logout"),
]

View File

@ -23,8 +23,8 @@ def get_breadcrumbs(url, request=None):
else:
# Check if this is a REST framework view,
# and if so add it to the breadcrumbs
cls = getattr(view, 'cls', None)
initkwargs = getattr(view, 'initkwargs', {})
cls = getattr(view, "cls", None)
initkwargs = getattr(view, "initkwargs", {})
if cls is not None and issubclass(cls, APIView):
# Don't list the same view twice in a row.
# Probably an optional trailing slash.
@ -35,21 +35,21 @@ def get_breadcrumbs(url, request=None):
breadcrumbs_list.insert(0, (name, insert_url))
seen.append(view)
if url == '':
if url == "":
# All done
return breadcrumbs_list
elif url.endswith('/'):
elif url.endswith("/"):
# Drop trailing slash off the end and continue to try to
# resolve more breadcrumbs
url = url.rstrip('/')
url = url.rstrip("/")
return breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen)
# Drop trailing non-slash off the end and continue to try to
# resolve more breadcrumbs
url = url[:url.rfind('/') + 1]
url = url[: url.rfind("/") + 1]
return breadcrumbs_recursive(url, breadcrumbs_list, prefix, seen)
prefix = get_script_prefix().rstrip('/')
url = url[len(prefix):]
prefix = get_script_prefix().rstrip("/")
url = url[len(prefix) :]
return breadcrumbs_recursive(url, [], prefix, [])

View File

@ -21,6 +21,7 @@ class JSONEncoder(json.JSONEncoder):
JSONEncoder subclass that knows how to encode date/time/timedelta,
decimal types, generators and other basic python objects.
"""
def default(self, obj):
# For Date Time string spec, see ECMA 262
# https://ecma-international.org/ecma-262/5.1/#sec-15.9.1.15
@ -28,8 +29,8 @@ class JSONEncoder(json.JSONEncoder):
return force_text(obj)
elif isinstance(obj, datetime.datetime):
representation = obj.isoformat()
if representation.endswith('+00:00'):
representation = representation[:-6] + 'Z'
if representation.endswith("+00:00"):
representation = representation[:-6] + "Z"
return representation
elif isinstance(obj, datetime.date):
return obj.isoformat()
@ -49,20 +50,22 @@ class JSONEncoder(json.JSONEncoder):
return tuple(obj)
elif isinstance(obj, bytes):
# Best-effort for binary blobs. See #4187.
return obj.decode('utf-8')
elif hasattr(obj, 'tolist'):
return obj.decode("utf-8")
elif hasattr(obj, "tolist"):
# Numpy arrays and array scalars.
return obj.tolist()
elif (coreapi is not None) and isinstance(obj, (coreapi.Document, coreapi.Error)):
elif (coreapi is not None) and isinstance(
obj, (coreapi.Document, coreapi.Error)
):
raise RuntimeError(
'Cannot return a coreapi object from a JSON view. '
'You should be using a schema renderer instead for this view.'
"Cannot return a coreapi object from a JSON view. "
"You should be using a schema renderer instead for this view."
)
elif hasattr(obj, '__getitem__'):
elif hasattr(obj, "__getitem__"):
try:
return dict(obj)
except Exception:
pass
elif hasattr(obj, '__iter__'):
elif hasattr(obj, "__iter__"):
return tuple(item for item in obj)
return super(JSONEncoder, self).default(obj)

View File

@ -11,8 +11,12 @@ from django.utils.text import capfirst
from rest_framework.compat import postgres_fields
from rest_framework.validators import UniqueValidator
NUMERIC_FIELD_TYPES = (
models.IntegerField, models.FloatField, models.DecimalField, models.DurationField,
models.IntegerField,
models.FloatField,
models.DecimalField,
models.DurationField,
)
@ -23,11 +27,12 @@ class ClassLookupDict(object):
hierarchy in method resolution order, and returns the first matching value
from the dictionary or raises a KeyError if nothing matches.
"""
def __init__(self, mapping):
self.mapping = mapping
def __getitem__(self, key):
if hasattr(key, '_proxy_class'):
if hasattr(key, "_proxy_class"):
# Deal with proxy classes. Ie. BoundField behaves as if it
# is a Field instance when using ClassLookupDict.
base_class = key._proxy_class
@ -37,7 +42,7 @@ class ClassLookupDict(object):
for cls in inspect.getmro(base_class):
if cls in self.mapping:
return self.mapping[cls]
raise KeyError('Class %s not found in lookup.' % base_class.__name__)
raise KeyError("Class %s not found in lookup." % base_class.__name__)
def __setitem__(self, key, value):
self.mapping[key] = value
@ -48,7 +53,7 @@ def needs_label(model_field, field_name):
Returns `True` if the label based on the model's verbose name
is not equal to the default label it would have based on it's field name.
"""
default_label = field_name.replace('_', ' ').capitalize()
default_label = field_name.replace("_", " ").capitalize()
return capfirst(model_field.verbose_name) != default_label
@ -57,9 +62,9 @@ def get_detail_view_name(model):
Given a model class, return the view name to use for URL relationships
that refer to instances of the model.
"""
return '%(model_name)s-detail' % {
'app_label': model._meta.app_label,
'model_name': model._meta.object_name.lower()
return "%(model_name)s-detail" % {
"app_label": model._meta.app_label,
"model_name": model._meta.object_name.lower(),
}
@ -72,84 +77,98 @@ def get_field_kwargs(field_name, model_field):
# The following will only be used by ModelField classes.
# Gets removed for everything else.
kwargs['model_field'] = model_field
kwargs["model_field"] = model_field
if model_field.verbose_name and needs_label(model_field, field_name):
kwargs['label'] = capfirst(model_field.verbose_name)
kwargs["label"] = capfirst(model_field.verbose_name)
if model_field.help_text:
kwargs['help_text'] = model_field.help_text
kwargs["help_text"] = model_field.help_text
max_digits = getattr(model_field, 'max_digits', None)
max_digits = getattr(model_field, "max_digits", None)
if max_digits is not None:
kwargs['max_digits'] = max_digits
kwargs["max_digits"] = max_digits
decimal_places = getattr(model_field, 'decimal_places', None)
decimal_places = getattr(model_field, "decimal_places", None)
if decimal_places is not None:
kwargs['decimal_places'] = decimal_places
kwargs["decimal_places"] = decimal_places
if isinstance(model_field, models.SlugField):
kwargs['allow_unicode'] = model_field.allow_unicode
kwargs["allow_unicode"] = model_field.allow_unicode
if isinstance(model_field, models.TextField) or (postgres_fields and isinstance(model_field, postgres_fields.JSONField)):
kwargs['style'] = {'base_template': 'textarea.html'}
if isinstance(model_field, models.TextField) or (
postgres_fields and isinstance(model_field, postgres_fields.JSONField)
):
kwargs["style"] = {"base_template": "textarea.html"}
if isinstance(model_field, models.AutoField) or not model_field.editable:
# If this field is read-only, then return early.
# Further keyword arguments are not valid.
kwargs['read_only'] = True
kwargs["read_only"] = True
return kwargs
if model_field.has_default() or model_field.blank or model_field.null:
kwargs['required'] = False
kwargs["required"] = False
if model_field.null and not isinstance(model_field, models.NullBooleanField):
kwargs['allow_null'] = True
kwargs["allow_null"] = True
if model_field.blank and (isinstance(model_field, (models.CharField, models.TextField))):
kwargs['allow_blank'] = True
if model_field.blank and (
isinstance(model_field, (models.CharField, models.TextField))
):
kwargs["allow_blank"] = True
if isinstance(model_field, models.FilePathField):
kwargs['path'] = model_field.path
kwargs["path"] = model_field.path
if model_field.match is not None:
kwargs['match'] = model_field.match
kwargs["match"] = model_field.match
if model_field.recursive is not False:
kwargs['recursive'] = model_field.recursive
kwargs["recursive"] = model_field.recursive
if model_field.allow_files is not True:
kwargs['allow_files'] = model_field.allow_files
kwargs["allow_files"] = model_field.allow_files
if model_field.allow_folders is not False:
kwargs['allow_folders'] = model_field.allow_folders
kwargs["allow_folders"] = model_field.allow_folders
if model_field.choices:
kwargs['choices'] = model_field.choices
kwargs["choices"] = model_field.choices
else:
# Ensure that max_value is passed explicitly as a keyword arg,
# rather than as a validator.
max_value = next((
validator.limit_value for validator in validator_kwarg
if isinstance(validator, validators.MaxValueValidator)
), None)
max_value = next(
(
validator.limit_value
for validator in validator_kwarg
if isinstance(validator, validators.MaxValueValidator)
),
None,
)
if max_value is not None and isinstance(model_field, NUMERIC_FIELD_TYPES):
kwargs['max_value'] = max_value
kwargs["max_value"] = max_value
validator_kwarg = [
validator for validator in validator_kwarg
validator
for validator in validator_kwarg
if not isinstance(validator, validators.MaxValueValidator)
]
# Ensure that min_value is passed explicitly as a keyword arg,
# rather than as a validator.
min_value = next((
validator.limit_value for validator in validator_kwarg
if isinstance(validator, validators.MinValueValidator)
), None)
min_value = next(
(
validator.limit_value
for validator in validator_kwarg
if isinstance(validator, validators.MinValueValidator)
),
None,
)
if min_value is not None and isinstance(model_field, NUMERIC_FIELD_TYPES):
kwargs['min_value'] = min_value
kwargs["min_value"] = min_value
validator_kwarg = [
validator for validator in validator_kwarg
validator
for validator in validator_kwarg
if not isinstance(validator, validators.MinValueValidator)
]
@ -157,7 +176,8 @@ def get_field_kwargs(field_name, model_field):
# as it is explicitly added in.
if isinstance(model_field, models.URLField):
validator_kwarg = [
validator for validator in validator_kwarg
validator
for validator in validator_kwarg
if not isinstance(validator, validators.URLValidator)
]
@ -165,67 +185,79 @@ def get_field_kwargs(field_name, model_field):
# as it is explicitly added in.
if isinstance(model_field, models.EmailField):
validator_kwarg = [
validator for validator in validator_kwarg
validator
for validator in validator_kwarg
if validator is not validators.validate_email
]
# SlugField do not need to include the 'validate_slug' argument,
if isinstance(model_field, models.SlugField):
validator_kwarg = [
validator for validator in validator_kwarg
validator
for validator in validator_kwarg
if validator is not validators.validate_slug
]
# IPAddressField do not need to include the 'validate_ipv46_address' argument,
if isinstance(model_field, models.GenericIPAddressField):
validator_kwarg = [
validator for validator in validator_kwarg
validator
for validator in validator_kwarg
if validator is not validators.validate_ipv46_address
]
# Our decimal validation is handled in the field code, not validator code.
if isinstance(model_field, models.DecimalField):
validator_kwarg = [
validator for validator in validator_kwarg
validator
for validator in validator_kwarg
if not isinstance(validator, validators.DecimalValidator)
]
# Ensure that max_length is passed explicitly as a keyword arg,
# rather than as a validator.
max_length = getattr(model_field, 'max_length', None)
if max_length is not None and (isinstance(model_field, (models.CharField, models.TextField, models.FileField))):
kwargs['max_length'] = max_length
max_length = getattr(model_field, "max_length", None)
if max_length is not None and (
isinstance(model_field, (models.CharField, models.TextField, models.FileField))
):
kwargs["max_length"] = max_length
validator_kwarg = [
validator for validator in validator_kwarg
validator
for validator in validator_kwarg
if not isinstance(validator, validators.MaxLengthValidator)
]
# Ensure that min_length is passed explicitly as a keyword arg,
# rather than as a validator.
min_length = next((
validator.limit_value for validator in validator_kwarg
if isinstance(validator, validators.MinLengthValidator)
), None)
min_length = next(
(
validator.limit_value
for validator in validator_kwarg
if isinstance(validator, validators.MinLengthValidator)
),
None,
)
if min_length is not None and isinstance(model_field, models.CharField):
kwargs['min_length'] = min_length
kwargs["min_length"] = min_length
validator_kwarg = [
validator for validator in validator_kwarg
validator
for validator in validator_kwarg
if not isinstance(validator, validators.MinLengthValidator)
]
if getattr(model_field, 'unique', False):
unique_error_message = model_field.error_messages.get('unique', None)
if getattr(model_field, "unique", False):
unique_error_message = model_field.error_messages.get("unique", None)
if unique_error_message:
unique_error_message = unique_error_message % {
'model_name': model_field.model._meta.verbose_name,
'field_label': model_field.verbose_name
"model_name": model_field.model._meta.verbose_name,
"field_label": model_field.verbose_name,
}
validator = UniqueValidator(
queryset=model_field.model._default_manager,
message=unique_error_message)
queryset=model_field.model._default_manager, message=unique_error_message
)
validator_kwarg.append(validator)
if validator_kwarg:
kwargs['validators'] = validator_kwarg
kwargs["validators"] = validator_kwarg
return kwargs
@ -234,65 +266,65 @@ def get_relation_kwargs(field_name, relation_info):
"""
Creates a default instance of a flat relational field.
"""
model_field, related_model, to_many, to_field, has_through_model, reverse = relation_info
model_field, related_model, to_many, to_field, has_through_model, reverse = (
relation_info
)
kwargs = {
'queryset': related_model._default_manager,
'view_name': get_detail_view_name(related_model)
"queryset": related_model._default_manager,
"view_name": get_detail_view_name(related_model),
}
if to_many:
kwargs['many'] = True
kwargs["many"] = True
if to_field:
kwargs['to_field'] = to_field
kwargs["to_field"] = to_field
limit_choices_to = model_field and model_field.get_limit_choices_to()
if limit_choices_to:
if not isinstance(limit_choices_to, models.Q):
limit_choices_to = models.Q(**limit_choices_to)
kwargs['queryset'] = kwargs['queryset'].filter(limit_choices_to)
kwargs["queryset"] = kwargs["queryset"].filter(limit_choices_to)
if has_through_model:
kwargs['read_only'] = True
kwargs.pop('queryset', None)
kwargs["read_only"] = True
kwargs.pop("queryset", None)
if model_field:
if model_field.verbose_name and needs_label(model_field, field_name):
kwargs['label'] = capfirst(model_field.verbose_name)
kwargs["label"] = capfirst(model_field.verbose_name)
help_text = model_field.help_text
if help_text:
kwargs['help_text'] = help_text
kwargs["help_text"] = help_text
if not model_field.editable:
kwargs['read_only'] = True
kwargs.pop('queryset', None)
if kwargs.get('read_only', False):
kwargs["read_only"] = True
kwargs.pop("queryset", None)
if kwargs.get("read_only", False):
# If this field is read-only, then return early.
# No further keyword arguments are valid.
return kwargs
if model_field.has_default() or model_field.blank or model_field.null:
kwargs['required'] = False
kwargs["required"] = False
if model_field.null:
kwargs['allow_null'] = True
kwargs["allow_null"] = True
if model_field.validators:
kwargs['validators'] = model_field.validators
if getattr(model_field, 'unique', False):
kwargs["validators"] = model_field.validators
if getattr(model_field, "unique", False):
validator = UniqueValidator(queryset=model_field.model._default_manager)
kwargs['validators'] = kwargs.get('validators', []) + [validator]
kwargs["validators"] = kwargs.get("validators", []) + [validator]
if to_many and not model_field.blank:
kwargs['allow_empty'] = False
kwargs["allow_empty"] = False
return kwargs
def get_nested_relation_kwargs(relation_info):
kwargs = {'read_only': True}
kwargs = {"read_only": True}
if relation_info.to_many:
kwargs['many'] = True
kwargs["many"] = True
return kwargs
def get_url_kwargs(model_field):
return {
'view_name': get_detail_view_name(model_field)
}
return {"view_name": get_detail_view_name(model_field)}

View File

@ -18,7 +18,7 @@ def remove_trailing_string(content, trailing):
Used when generating names from view classes.
"""
if content.endswith(trailing) and content != trailing:
return content[:-len(trailing)]
return content[: -len(trailing)]
return content
@ -36,14 +36,14 @@ def dedent(content):
# unindent the content if needed
if lines:
whitespace_counts = min([len(line) - len(line.lstrip(' ')) for line in lines])
tab_counts = min([len(line) - len(line.lstrip('\t')) for line in lines])
whitespace_counts = min([len(line) - len(line.lstrip(" ")) for line in lines])
tab_counts = min([len(line) - len(line.lstrip("\t")) for line in lines])
if whitespace_counts:
whitespace_pattern = '^' + (' ' * whitespace_counts)
content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content)
whitespace_pattern = "^" + (" " * whitespace_counts)
content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), "", content)
elif tab_counts:
whitespace_pattern = '^' + ('\t' * tab_counts)
content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), '', content)
whitespace_pattern = "^" + ("\t" * tab_counts)
content = re.sub(re.compile(whitespace_pattern, re.MULTILINE), "", content)
return content.strip()
@ -52,9 +52,9 @@ def camelcase_to_spaces(content):
Translate 'CamelCaseNames' to 'Camel Case Names'.
Used when generating names from view classes.
"""
camelcase_boundary = '(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))'
content = re.sub(camelcase_boundary, ' \\1', content).strip()
return ' '.join(content.split('_')).title()
camelcase_boundary = "(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))"
content = re.sub(camelcase_boundary, " \\1", content).strip()
return " ".join(content.split("_")).title()
def markup_description(description):
@ -64,6 +64,6 @@ def markup_description(description):
if apply_markdown:
description = apply_markdown(description)
else:
description = escape(description).replace('\n', '<br />')
description = '<p>' + description + '</p>'
description = escape(description).replace("\n", "<br />")
description = "<p>" + description + "</p>"
return mark_safe(description)

View File

@ -9,10 +9,10 @@ from django.utils.datastructures import MultiValueDict
def is_html_input(dictionary):
# MultiDict type datastructures are used to represent HTML form input,
# which may have more than one value for each key.
return hasattr(dictionary, 'getlist')
return hasattr(dictionary, "getlist")
def parse_html_list(dictionary, prefix='', default=None):
def parse_html_list(dictionary, prefix="", default=None):
"""
Used to support list values in HTML forms.
Supports lists of primitives and/or dictionaries.
@ -48,7 +48,7 @@ def parse_html_list(dictionary, prefix='', default=None):
:returns a list of objects, or the value specified in ``default`` if the list is empty
"""
ret = {}
regex = re.compile(r'^%s\[([0-9]+)\](.*)$' % re.escape(prefix))
regex = re.compile(r"^%s\[([0-9]+)\](.*)$" % re.escape(prefix))
for field, value in dictionary.items():
match = regex.match(field)
if not match:
@ -66,7 +66,7 @@ def parse_html_list(dictionary, prefix='', default=None):
return [ret[item] for item in sorted(ret)] if ret else default
def parse_html_dict(dictionary, prefix=''):
def parse_html_dict(dictionary, prefix=""):
"""
Used to support dictionary values in HTML forms.
@ -83,7 +83,7 @@ def parse_html_dict(dictionary, prefix=''):
}
"""
ret = MultiValueDict()
regex = re.compile(r'^%s\.(.+)$' % re.escape(prefix))
regex = re.compile(r"^%s\.(.+)$" % re.escape(prefix))
for field in dictionary:
match = regex.match(field)
if not match:

View File

@ -5,20 +5,19 @@ from rest_framework import ISO_8601
def datetime_formats(formats):
format = ', '.join(formats).replace(
ISO_8601,
'YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]'
format = ", ".join(formats).replace(
ISO_8601, "YYYY-MM-DDThh:mm[:ss[.uuuuuu]][+HH:MM|-HH:MM|Z]"
)
return humanize_strptime(format)
def date_formats(formats):
format = ', '.join(formats).replace(ISO_8601, 'YYYY-MM-DD')
format = ", ".join(formats).replace(ISO_8601, "YYYY-MM-DD")
return humanize_strptime(format)
def time_formats(formats):
format = ', '.join(formats).replace(ISO_8601, 'hh:mm[:ss[.uuuuuu]]')
format = ", ".join(formats).replace(ISO_8601, "hh:mm[:ss[.uuuuuu]]")
return humanize_strptime(format)
@ -40,7 +39,7 @@ def humanize_strptime(format_string):
"%a": "[Mon-Sun]",
"%A": "[Monday-Sunday]",
"%p": "[AM|PM]",
"%z": "[+HHMM|-HHMM]"
"%z": "[+HHMM|-HHMM]",
}
for key, val in mapping.items():
format_string = format_string.replace(key, val)

View File

@ -13,28 +13,28 @@ import json # noqa
def strict_constant(o):
raise ValueError('Out of range float values are not JSON compliant: ' + repr(o))
raise ValueError("Out of range float values are not JSON compliant: " + repr(o))
@functools.wraps(json.dump)
def dump(*args, **kwargs):
kwargs.setdefault('allow_nan', False)
kwargs.setdefault("allow_nan", False)
return json.dump(*args, **kwargs)
@functools.wraps(json.dumps)
def dumps(*args, **kwargs):
kwargs.setdefault('allow_nan', False)
kwargs.setdefault("allow_nan", False)
return json.dumps(*args, **kwargs)
@functools.wraps(json.load)
def load(*args, **kwargs):
kwargs.setdefault('parse_constant', strict_constant)
kwargs.setdefault("parse_constant", strict_constant)
return json.load(*args, **kwargs)
@functools.wraps(json.loads)
def loads(*args, **kwargs):
kwargs.setdefault('parse_constant', strict_constant)
kwargs.setdefault("parse_constant", strict_constant)
return json.loads(*args, **kwargs)

View File

@ -49,20 +49,30 @@ def order_by_precedence(media_type_lst):
@python_2_unicode_compatible
class _MediaType(object):
def __init__(self, media_type_str):
self.orig = '' if (media_type_str is None) else media_type_str
self.full_type, self.params = parse_header(self.orig.encode(HTTP_HEADER_ENCODING))
self.main_type, sep, self.sub_type = self.full_type.partition('/')
self.orig = "" if (media_type_str is None) else media_type_str
self.full_type, self.params = parse_header(
self.orig.encode(HTTP_HEADER_ENCODING)
)
self.main_type, sep, self.sub_type = self.full_type.partition("/")
def match(self, other):
"""Return true if this MediaType satisfies the given MediaType."""
for key in self.params:
if key != 'q' and other.params.get(key, None) != self.params.get(key, None):
if key != "q" and other.params.get(key, None) != self.params.get(key, None):
return False
if self.sub_type != '*' and other.sub_type != '*' and other.sub_type != self.sub_type:
if (
self.sub_type != "*"
and other.sub_type != "*"
and other.sub_type != self.sub_type
):
return False
if self.main_type != '*' and other.main_type != '*' and other.main_type != self.main_type:
if (
self.main_type != "*"
and other.main_type != "*"
and other.main_type != self.main_type
):
return False
return True
@ -72,16 +82,16 @@ class _MediaType(object):
"""
Return a precedence level from 0-3 for the media type given how specific it is.
"""
if self.main_type == '*':
if self.main_type == "*":
return 0
elif self.sub_type == '*':
elif self.sub_type == "*":
return 1
elif not self.params or list(self.params) == ['q']:
elif not self.params or list(self.params) == ["q"]:
return 2
return 3
def __str__(self):
ret = "%s/%s" % (self.main_type, self.sub_type)
for key, val in self.params.items():
ret += "; %s=%s" % (key, val.decode('ascii'))
ret += "; %s=%s" % (key, val.decode("ascii"))
return ret

View File

@ -7,23 +7,30 @@ Usage: `get_field_info(model)` returns a `FieldInfo` instance.
"""
from collections import OrderedDict, namedtuple
FieldInfo = namedtuple('FieldResult', [
'pk', # Model field instance
'fields', # Dict of field name -> model field instance
'forward_relations', # Dict of field name -> RelationInfo
'reverse_relations', # Dict of field name -> RelationInfo
'fields_and_pk', # Shortcut for 'pk' + 'fields'
'relations' # Shortcut for 'forward_relations' + 'reverse_relations'
])
RelationInfo = namedtuple('RelationInfo', [
'model_field',
'related_model',
'to_many',
'to_field',
'has_through_model',
'reverse'
])
FieldInfo = namedtuple(
"FieldResult",
[
"pk", # Model field instance
"fields", # Dict of field name -> model field instance
"forward_relations", # Dict of field name -> RelationInfo
"reverse_relations", # Dict of field name -> RelationInfo
"fields_and_pk", # Shortcut for 'pk' + 'fields'
"relations", # Shortcut for 'forward_relations' + 'reverse_relations'
],
)
RelationInfo = namedtuple(
"RelationInfo",
[
"model_field",
"related_model",
"to_many",
"to_field",
"has_through_model",
"reverse",
],
)
def get_field_info(model):
@ -41,8 +48,9 @@ def get_field_info(model):
fields_and_pk = _merge_fields_and_pk(pk, fields)
relationships = _merge_relationships(forward_relations, reverse_relations)
return FieldInfo(pk, fields, forward_relations, reverse_relations,
fields_and_pk, relationships)
return FieldInfo(
pk, fields, forward_relations, reverse_relations, fields_and_pk, relationships
)
def _get_pk(opts):
@ -59,14 +67,16 @@ def _get_pk(opts):
def _get_fields(opts):
fields = OrderedDict()
for field in [field for field in opts.fields if field.serialize and not field.remote_field]:
for field in [
field for field in opts.fields if field.serialize and not field.remote_field
]:
fields[field.name] = field
return fields
def _get_to_field(field):
return getattr(field, 'to_fields', None) and field.to_fields[0]
return getattr(field, "to_fields", None) and field.to_fields[0]
def _get_forward_relationships(opts):
@ -74,14 +84,16 @@ def _get_forward_relationships(opts):
Returns an `OrderedDict` of field names to `RelationInfo`.
"""
forward_relations = OrderedDict()
for field in [field for field in opts.fields if field.serialize and field.remote_field]:
for field in [
field for field in opts.fields if field.serialize and field.remote_field
]:
forward_relations[field.name] = RelationInfo(
model_field=field,
related_model=field.remote_field.model,
to_many=False,
to_field=_get_to_field(field),
has_through_model=False,
reverse=False
reverse=False,
)
# Deal with forward many-to-many relationships.
@ -92,10 +104,8 @@ def _get_forward_relationships(opts):
to_many=True,
# manytomany do not have to_fields
to_field=None,
has_through_model=(
not field.remote_field.through._meta.auto_created
),
reverse=False
has_through_model=(not field.remote_field.through._meta.auto_created),
reverse=False,
)
return forward_relations
@ -115,11 +125,13 @@ def _get_reverse_relationships(opts):
to_many=relation.field.remote_field.multiple,
to_field=_get_to_field(relation.field),
has_through_model=False,
reverse=True
reverse=True,
)
# Deal with reverse many-to-many relationships.
all_related_many_to_many_objects = [r for r in opts.related_objects if r.field.many_to_many]
all_related_many_to_many_objects = [
r for r in opts.related_objects if r.field.many_to_many
]
for relation in all_related_many_to_many_objects:
accessor_name = relation.get_accessor_name()
reverse_relations[accessor_name] = RelationInfo(
@ -129,10 +141,10 @@ def _get_reverse_relationships(opts):
# manytomany do not have to_fields
to_field=None,
has_through_model=(
(getattr(relation.field.remote_field, 'through', None) is not None) and
not relation.field.remote_field.through._meta.auto_created
(getattr(relation.field.remote_field, "through", None) is not None)
and not relation.field.remote_field.through._meta.auto_created
),
reverse=True
reverse=True,
)
return reverse_relations
@ -140,7 +152,7 @@ def _get_reverse_relationships(opts):
def _merge_fields_and_pk(pk, fields):
fields_and_pk = OrderedDict()
fields_and_pk['pk'] = pk
fields_and_pk["pk"] = pk
fields_and_pk[pk.name] = pk
fields_and_pk.update(fields)
@ -149,8 +161,7 @@ def _merge_fields_and_pk(pk, fields):
def _merge_relationships(forward_relations, reverse_relations):
return OrderedDict(
list(forward_relations.items()) +
list(reverse_relations.items())
list(forward_relations.items()) + list(reverse_relations.items())
)
@ -158,4 +169,8 @@ def is_abstract_model(model):
"""
Given a model class, returns a boolean True if it is abstract and False if it is not.
"""
return hasattr(model, '_meta') and hasattr(model._meta, 'abstract') and model._meta.abstract
return (
hasattr(model, "_meta")
and hasattr(model._meta, "abstract")
and model._meta.abstract
)

View File

@ -16,14 +16,10 @@ from rest_framework.compat import unicode_repr
def manager_repr(value):
model = value.model
opts = model._meta
names_and_managers = [
(manager.name, manager)
for manager
in opts.managers
]
names_and_managers = [(manager.name, manager) for manager in opts.managers]
for manager_name, manager_instance in names_and_managers:
if manager_instance == value:
return '%s.%s.all()' % (model._meta.object_name, manager_name)
return "%s.%s.all()" % (model._meta.object_name, manager_name)
return repr(value)
@ -45,7 +41,7 @@ def smart_repr(value):
# <django.core.validators.RegexValidator object at 0x1047af050>
# Should be presented as
# <django.core.validators.RegexValidator object>
value = re.sub(' at 0x[0-9A-Fa-f]{4,32}>', '>', value)
value = re.sub(" at 0x[0-9A-Fa-f]{4,32}>", ">", value)
return value
@ -54,16 +50,15 @@ def field_repr(field, force_many=False):
kwargs = field._kwargs
if force_many:
kwargs = kwargs.copy()
kwargs['many'] = True
kwargs.pop('child', None)
kwargs["many"] = True
kwargs.pop("child", None)
arg_string = ', '.join([smart_repr(val) for val in field._args])
kwarg_string = ', '.join([
'%s=%s' % (key, smart_repr(val))
for key, val in sorted(kwargs.items())
])
arg_string = ", ".join([smart_repr(val) for val in field._args])
kwarg_string = ", ".join(
["%s=%s" % (key, smart_repr(val)) for key, val in sorted(kwargs.items())]
)
if arg_string and kwarg_string:
arg_string += ', '
arg_string += ", "
if force_many:
class_name = force_many.__class__.__name__
@ -74,8 +69,8 @@ def field_repr(field, force_many=False):
def serializer_repr(serializer, indent, force_many=None):
ret = field_repr(serializer, force_many) + ':'
indent_str = ' ' * indent
ret = field_repr(serializer, force_many) + ":"
indent_str = " " * indent
if force_many:
fields = force_many.fields
@ -83,25 +78,27 @@ def serializer_repr(serializer, indent, force_many=None):
fields = serializer.fields
for field_name, field in fields.items():
ret += '\n' + indent_str + field_name + ' = '
if hasattr(field, 'fields'):
ret += "\n" + indent_str + field_name + " = "
if hasattr(field, "fields"):
ret += serializer_repr(field, indent + 1)
elif hasattr(field, 'child'):
elif hasattr(field, "child"):
ret += list_repr(field, indent + 1)
elif hasattr(field, 'child_relation'):
elif hasattr(field, "child_relation"):
ret += field_repr(field.child_relation, force_many=field.child_relation)
else:
ret += field_repr(field)
if serializer.validators:
ret += '\n' + indent_str + 'class Meta:'
ret += '\n' + indent_str + ' validators = ' + smart_repr(serializer.validators)
ret += "\n" + indent_str + "class Meta:"
ret += (
"\n" + indent_str + " validators = " + smart_repr(serializer.validators)
)
return ret
def list_repr(serializer, indent):
child = serializer.child
if hasattr(child, 'fields'):
if hasattr(child, "fields"):
return serializer_repr(serializer, indent, force_many=child)
return field_repr(serializer)

View File

@ -16,7 +16,7 @@ class ReturnDict(OrderedDict):
"""
def __init__(self, *args, **kwargs):
self.serializer = kwargs.pop('serializer')
self.serializer = kwargs.pop("serializer")
super(ReturnDict, self).__init__(*args, **kwargs)
def copy(self):
@ -39,7 +39,7 @@ class ReturnList(list):
"""
def __init__(self, *args, **kwargs):
self.serializer = kwargs.pop('serializer')
self.serializer = kwargs.pop("serializer")
super(ReturnList, self).__init__(*args, **kwargs)
def __repr__(self):
@ -58,7 +58,7 @@ class BoundField(object):
providing an API similar to Django forms and form fields.
"""
def __init__(self, field, value, errors, prefix=''):
def __init__(self, field, value, errors, prefix=""):
self._field = field
self._prefix = prefix
self.value = value
@ -73,12 +73,13 @@ class BoundField(object):
return self._field.__class__
def __repr__(self):
return unicode_to_repr('<%s value=%s errors=%s>' % (
self.__class__.__name__, self.value, self.errors
))
return unicode_to_repr(
"<%s value=%s errors=%s>"
% (self.__class__.__name__, self.value, self.errors)
)
def as_form_field(self):
value = '' if (self.value is None or self.value is False) else self.value
value = "" if (self.value is None or self.value is False) else self.value
return self.__class__(self._field, value, self.errors, self._prefix)
@ -87,7 +88,7 @@ class JSONBoundField(BoundField):
value = self.value
# When HTML form input is used and the input is not valid
# value will be a JSONString, rather than a JSON primitive.
if not getattr(value, 'is_json_string', False):
if not getattr(value, "is_json_string", False):
try:
value = json.dumps(self.value, sort_keys=True, indent=4)
except (TypeError, ValueError):
@ -102,8 +103,8 @@ class NestedBoundField(BoundField):
`BoundField` that is used for serializer fields.
"""
def __init__(self, field, value, errors, prefix=''):
if value is None or value is '':
def __init__(self, field, value, errors, prefix=""):
if value is None or value is "":
value = {}
super(NestedBoundField, self).__init__(field, value, errors, prefix)
@ -115,9 +116,9 @@ class NestedBoundField(BoundField):
field = self.fields[key]
value = self.value.get(key) if self.value else None
error = self.errors.get(key) if isinstance(self.errors, dict) else None
if hasattr(field, 'fields'):
return NestedBoundField(field, value, error, prefix=self.name + '.')
return BoundField(field, value, error, prefix=self.name + '.')
if hasattr(field, "fields"):
return NestedBoundField(field, value, error, prefix=self.name + ".")
return BoundField(field, value, error, prefix=self.name + ".")
def as_form_field(self):
values = {}
@ -125,7 +126,9 @@ class NestedBoundField(BoundField):
if isinstance(value, (list, dict)):
values[key] = value
else:
values[key] = '' if (value is None or value is False) else force_text(value)
values[key] = (
"" if (value is None or value is False) else force_text(value)
)
return self.__class__(self._field, values, self.errors, self._prefix)

View File

@ -39,9 +39,10 @@ class UniqueValidator(object):
Should be applied to an individual field on the serializer.
"""
message = _('This field must be unique.')
def __init__(self, queryset, message=None, lookup='exact'):
message = _("This field must be unique.")
def __init__(self, queryset, message=None, lookup="exact"):
self.queryset = queryset
self.serializer_field = None
self.message = message or self.message
@ -56,13 +57,13 @@ class UniqueValidator(object):
# same as the serializer field name if `source=<>` is set.
self.field_name = serializer_field.source_attrs[-1]
# Determine the existing instance, if this is an update operation.
self.instance = getattr(serializer_field.parent, 'instance', None)
self.instance = getattr(serializer_field.parent, "instance", None)
def filter_queryset(self, value, queryset):
"""
Filter the queryset to all instances matching the given attribute.
"""
filter_kwargs = {'%s__%s' % (self.field_name, self.lookup): value}
filter_kwargs = {"%s__%s" % (self.field_name, self.lookup): value}
return qs_filter(queryset, **filter_kwargs)
def exclude_current_instance(self, queryset):
@ -79,13 +80,12 @@ class UniqueValidator(object):
queryset = self.filter_queryset(value, queryset)
queryset = self.exclude_current_instance(queryset)
if qs_exists(queryset):
raise ValidationError(self.message, code='unique')
raise ValidationError(self.message, code="unique")
def __repr__(self):
return unicode_to_repr('<%s(queryset=%s)>' % (
self.__class__.__name__,
smart_repr(self.queryset)
))
return unicode_to_repr(
"<%s(queryset=%s)>" % (self.__class__.__name__, smart_repr(self.queryset))
)
class UniqueTogetherValidator(object):
@ -94,8 +94,9 @@ class UniqueTogetherValidator(object):
Should be applied to the serializer class, not to an individual field.
"""
message = _('The fields {field_names} must make a unique set.')
missing_message = _('This field is required.')
message = _("The fields {field_names} must make a unique set.")
missing_message = _("This field is required.")
def __init__(self, queryset, fields, message=None):
self.queryset = queryset
@ -109,7 +110,7 @@ class UniqueTogetherValidator(object):
prior to the validation call being made.
"""
# Determine the existing instance, if this is an update operation.
self.instance = getattr(serializer, 'instance', None)
self.instance = getattr(serializer, "instance", None)
def enforce_required_fields(self, attrs):
"""
@ -125,7 +126,7 @@ class UniqueTogetherValidator(object):
if field_name not in attrs
}
if missing_items:
raise ValidationError(missing_items, code='required')
raise ValidationError(missing_items, code="required")
def filter_queryset(self, attrs, queryset):
"""
@ -139,10 +140,7 @@ class UniqueTogetherValidator(object):
attrs[field_name] = getattr(self.instance, field_name)
# Determine the filter keyword arguments and filter the queryset.
filter_kwargs = {
field_name: attrs[field_name]
for field_name in self.fields
}
filter_kwargs = {field_name: attrs[field_name] for field_name in self.fields}
return qs_filter(queryset, **filter_kwargs)
def exclude_current_instance(self, attrs, queryset):
@ -165,21 +163,24 @@ class UniqueTogetherValidator(object):
value for field, value in attrs.items() if field in self.fields
]
if None not in checked_values and qs_exists(queryset):
field_names = ', '.join(self.fields)
field_names = ", ".join(self.fields)
message = self.message.format(field_names=field_names)
raise ValidationError(message, code='unique')
raise ValidationError(message, code="unique")
def __repr__(self):
return unicode_to_repr('<%s(queryset=%s, fields=%s)>' % (
self.__class__.__name__,
smart_repr(self.queryset),
smart_repr(self.fields)
))
return unicode_to_repr(
"<%s(queryset=%s, fields=%s)>"
% (
self.__class__.__name__,
smart_repr(self.queryset),
smart_repr(self.fields),
)
)
class BaseUniqueForValidator(object):
message = None
missing_message = _('This field is required.')
missing_message = _("This field is required.")
def __init__(self, queryset, field, date_field, message=None):
self.queryset = queryset
@ -197,7 +198,7 @@ class BaseUniqueForValidator(object):
self.field_name = serializer.fields[self.field].source_attrs[-1]
self.date_field_name = serializer.fields[self.date_field].source_attrs[-1]
# Determine the existing instance, if this is an update operation.
self.instance = getattr(serializer, 'instance', None)
self.instance = getattr(serializer, "instance", None)
def enforce_required_fields(self, attrs):
"""
@ -210,10 +211,10 @@ class BaseUniqueForValidator(object):
if field_name not in attrs
}
if missing_items:
raise ValidationError(missing_items, code='required')
raise ValidationError(missing_items, code="required")
def filter_queryset(self, attrs, queryset):
raise NotImplementedError('`filter_queryset` must be implemented.')
raise NotImplementedError("`filter_queryset` must be implemented.")
def exclude_current_instance(self, attrs, queryset):
"""
@ -231,17 +232,18 @@ class BaseUniqueForValidator(object):
queryset = self.exclude_current_instance(attrs, queryset)
if qs_exists(queryset):
message = self.message.format(date_field=self.date_field)
raise ValidationError({
self.field: message
}, code='unique')
raise ValidationError({self.field: message}, code="unique")
def __repr__(self):
return unicode_to_repr('<%s(queryset=%s, field=%s, date_field=%s)>' % (
self.__class__.__name__,
smart_repr(self.queryset),
smart_repr(self.field),
smart_repr(self.date_field)
))
return unicode_to_repr(
"<%s(queryset=%s, field=%s, date_field=%s)>"
% (
self.__class__.__name__,
smart_repr(self.queryset),
smart_repr(self.field),
smart_repr(self.date_field),
)
)
class UniqueForDateValidator(BaseUniqueForValidator):
@ -253,9 +255,9 @@ class UniqueForDateValidator(BaseUniqueForValidator):
filter_kwargs = {}
filter_kwargs[self.field_name] = value
filter_kwargs['%s__day' % self.date_field_name] = date.day
filter_kwargs['%s__month' % self.date_field_name] = date.month
filter_kwargs['%s__year' % self.date_field_name] = date.year
filter_kwargs["%s__day" % self.date_field_name] = date.day
filter_kwargs["%s__month" % self.date_field_name] = date.month
filter_kwargs["%s__year" % self.date_field_name] = date.year
return qs_filter(queryset, **filter_kwargs)
@ -268,7 +270,7 @@ class UniqueForMonthValidator(BaseUniqueForValidator):
filter_kwargs = {}
filter_kwargs[self.field_name] = value
filter_kwargs['%s__month' % self.date_field_name] = date.month
filter_kwargs["%s__month" % self.date_field_name] = date.month
return qs_filter(queryset, **filter_kwargs)
@ -281,5 +283,5 @@ class UniqueForYearValidator(BaseUniqueForValidator):
filter_kwargs = {}
filter_kwargs[self.field_name] = value
filter_kwargs['%s__year' % self.date_field_name] = date.year
filter_kwargs["%s__year" % self.date_field_name] = date.year
return qs_filter(queryset, **filter_kwargs)

View File

@ -19,19 +19,20 @@ class BaseVersioning(object):
version_param = api_settings.VERSION_PARAM
def determine_version(self, request, *args, **kwargs):
msg = '{cls}.determine_version() must be implemented.'
raise NotImplementedError(msg.format(
cls=self.__class__.__name__
))
msg = "{cls}.determine_version() must be implemented."
raise NotImplementedError(msg.format(cls=self.__class__.__name__))
def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
def reverse(
self, viewname, args=None, kwargs=None, request=None, format=None, **extra
):
return _reverse(viewname, args, kwargs, request, format, **extra)
def is_allowed_version(self, version):
if not self.allowed_versions:
return True
return ((version is not None and version == self.default_version) or
(version in self.allowed_versions))
return (version is not None and version == self.default_version) or (
version in self.allowed_versions
)
class AcceptHeaderVersioning(BaseVersioning):
@ -40,6 +41,7 @@ class AcceptHeaderVersioning(BaseVersioning):
Host: example.com
Accept: application/json; version=1.0
"""
invalid_version_message = _('Invalid version in "Accept" header.')
def determine_version(self, request, *args, **kwargs):
@ -71,7 +73,8 @@ class URLPathVersioning(BaseVersioning):
Host: example.com
Accept: application/json
"""
invalid_version_message = _('Invalid version in URL path.')
invalid_version_message = _("Invalid version in URL path.")
def determine_version(self, request, *args, **kwargs):
version = kwargs.get(self.version_param, self.default_version)
@ -82,7 +85,9 @@ class URLPathVersioning(BaseVersioning):
raise exceptions.NotFound(self.invalid_version_message)
return version
def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
def reverse(
self, viewname, args=None, kwargs=None, request=None, format=None, **extra
):
if request.version is not None:
kwargs = {} if (kwargs is None) else kwargs
kwargs[self.version_param] = request.version
@ -116,21 +121,26 @@ class NamespaceVersioning(BaseVersioning):
Host: example.com
Accept: application/json
"""
invalid_version_message = _('Invalid version in URL path. Does not match any version namespace.')
invalid_version_message = _(
"Invalid version in URL path. Does not match any version namespace."
)
def determine_version(self, request, *args, **kwargs):
resolver_match = getattr(request, 'resolver_match', None)
resolver_match = getattr(request, "resolver_match", None)
if resolver_match is None or not resolver_match.namespace:
return self.default_version
# Allow for possibly nested namespaces.
possible_versions = resolver_match.namespace.split(':')
possible_versions = resolver_match.namespace.split(":")
for version in possible_versions:
if self.is_allowed_version(version):
return version
raise exceptions.NotFound(self.invalid_version_message)
def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
def reverse(
self, viewname, args=None, kwargs=None, request=None, format=None, **extra
):
if request.version is not None:
viewname = self.get_versioned_viewname(viewname, request)
return super(NamespaceVersioning, self).reverse(
@ -138,7 +148,7 @@ class NamespaceVersioning(BaseVersioning):
)
def get_versioned_viewname(self, viewname, request):
return request.version + ':' + viewname
return request.version + ":" + viewname
class HostNameVersioning(BaseVersioning):
@ -147,11 +157,12 @@ class HostNameVersioning(BaseVersioning):
Host: v1.example.com
Accept: application/json
"""
hostname_regex = re.compile(r'^([a-zA-Z0-9]+)\.[a-zA-Z0-9]+\.[a-zA-Z0-9]+$')
invalid_version_message = _('Invalid version in hostname.')
hostname_regex = re.compile(r"^([a-zA-Z0-9]+)\.[a-zA-Z0-9]+\.[a-zA-Z0-9]+$")
invalid_version_message = _("Invalid version in hostname.")
def determine_version(self, request, *args, **kwargs):
hostname, separator, port = request.get_host().partition(':')
hostname, separator, port = request.get_host().partition(":")
match = self.hostname_regex.match(hostname)
if not match:
return self.default_version
@ -170,7 +181,8 @@ class QueryParameterVersioning(BaseVersioning):
Host: example.com
Accept: application/json
"""
invalid_version_message = _('Invalid version in query parameter.')
invalid_version_message = _("Invalid version in query parameter.")
def determine_version(self, request, *args, **kwargs):
version = request.query_params.get(self.version_param, self.default_version)
@ -178,7 +190,9 @@ class QueryParameterVersioning(BaseVersioning):
raise exceptions.NotFound(self.invalid_version_message)
return version
def reverse(self, viewname, args=None, kwargs=None, request=None, format=None, **extra):
def reverse(
self, viewname, args=None, kwargs=None, request=None, format=None, **extra
):
url = super(QueryParameterVersioning, self).reverse(
viewname, args, kwargs, request, format, **extra
)

View File

@ -29,19 +29,19 @@ def get_view_name(view):
This function is the default for the `VIEW_NAME_FUNCTION` setting.
"""
# Name may be set by some Views, such as a ViewSet.
name = getattr(view, 'name', None)
name = getattr(view, "name", None)
if name is not None:
return name
name = view.__class__.__name__
name = formatting.remove_trailing_string(name, 'View')
name = formatting.remove_trailing_string(name, 'ViewSet')
name = formatting.remove_trailing_string(name, "View")
name = formatting.remove_trailing_string(name, "ViewSet")
name = formatting.camelcase_to_spaces(name)
# Suffix may be set by some Views, such as a ViewSet.
suffix = getattr(view, 'suffix', None)
suffix = getattr(view, "suffix", None)
if suffix:
name += ' ' + suffix
name += " " + suffix
return name
@ -54,9 +54,9 @@ def get_view_description(view, html=False):
This function is the default for the `VIEW_DESCRIPTION_FUNCTION` setting.
"""
# Description may be set by some Views, such as a ViewSet.
description = getattr(view, 'description', None)
description = getattr(view, "description", None)
if description is None:
description = view.__class__.__doc__ or ''
description = view.__class__.__doc__ or ""
description = formatting.dedent(smart_text(description))
if html:
@ -65,7 +65,7 @@ def get_view_description(view, html=False):
def set_rollback():
atomic_requests = connection.settings_dict.get('ATOMIC_REQUESTS', False)
atomic_requests = connection.settings_dict.get("ATOMIC_REQUESTS", False)
if atomic_requests and connection.in_atomic_block:
transaction.set_rollback(True)
@ -87,15 +87,15 @@ def exception_handler(exc, context):
if isinstance(exc, exceptions.APIException):
headers = {}
if getattr(exc, 'auth_header', None):
headers['WWW-Authenticate'] = exc.auth_header
if getattr(exc, 'wait', None):
headers['Retry-After'] = '%d' % exc.wait
if getattr(exc, "auth_header", None):
headers["WWW-Authenticate"] = exc.auth_header
if getattr(exc, "wait", None):
headers["Retry-After"] = "%d" % exc.wait
if isinstance(exc.detail, (list, dict)):
data = exc.detail
else:
data = {'detail': exc.detail}
data = {"detail": exc.detail}
set_rollback()
return Response(data, status=exc.status_code, headers=headers)
@ -128,13 +128,15 @@ class APIView(View):
This allows us to discover information about the view when we do URL
reverse lookups. Used for breadcrumb generation.
"""
if isinstance(getattr(cls, 'queryset', None), models.query.QuerySet):
if isinstance(getattr(cls, "queryset", None), models.query.QuerySet):
def force_evaluation():
raise RuntimeError(
'Do not evaluate the `.queryset` attribute directly, '
'as the result will be cached and reused between requests. '
'Use `.all()` or call `.get_queryset()` instead.'
"Do not evaluate the `.queryset` attribute directly, "
"as the result will be cached and reused between requests. "
"Use `.all()` or call `.get_queryset()` instead."
)
cls.queryset._fetch_all = force_evaluation
view = super(APIView, cls).as_view(**initkwargs)
@ -154,11 +156,9 @@ class APIView(View):
@property
def default_response_headers(self):
headers = {
'Allow': ', '.join(self.allowed_methods),
}
headers = {"Allow": ", ".join(self.allowed_methods)}
if len(self.renderer_classes) > 1:
headers['Vary'] = 'Accept'
headers["Vary"] = "Accept"
return headers
def http_method_not_allowed(self, request, *args, **kwargs):
@ -199,9 +199,9 @@ class APIView(View):
# Note: Additionally `request` and `encoding` will also be added
# to the context by the Request object.
return {
'view': self,
'args': getattr(self, 'args', ()),
'kwargs': getattr(self, 'kwargs', {})
"view": self,
"args": getattr(self, "args", ()),
"kwargs": getattr(self, "kwargs", {}),
}
def get_renderer_context(self):
@ -212,10 +212,10 @@ class APIView(View):
# Note: Additionally 'response' will also be added to the context,
# by the Response object.
return {
'view': self,
'args': getattr(self, 'args', ()),
'kwargs': getattr(self, 'kwargs', {}),
'request': getattr(self, 'request', None)
"view": self,
"args": getattr(self, "args", ()),
"kwargs": getattr(self, "kwargs", {}),
"request": getattr(self, "request", None),
}
def get_exception_handler_context(self):
@ -224,10 +224,10 @@ class APIView(View):
as the `context` argument.
"""
return {
'view': self,
'args': getattr(self, 'args', ()),
'kwargs': getattr(self, 'kwargs', {}),
'request': getattr(self, 'request', None)
"view": self,
"args": getattr(self, "args", ()),
"kwargs": getattr(self, "kwargs", {}),
"request": getattr(self, "request", None),
}
def get_view_name(self):
@ -289,7 +289,7 @@ class APIView(View):
"""
Instantiate and return the content negotiation class to use.
"""
if not getattr(self, '_negotiator', None):
if not getattr(self, "_negotiator", None):
self._negotiator = self.content_negotiation_class()
return self._negotiator
@ -333,7 +333,7 @@ class APIView(View):
for permission in self.get_permissions():
if not permission.has_permission(request, self):
self.permission_denied(
request, message=getattr(permission, 'message', None)
request, message=getattr(permission, "message", None)
)
def check_object_permissions(self, request, obj):
@ -344,7 +344,7 @@ class APIView(View):
for permission in self.get_permissions():
if not permission.has_object_permission(request, self, obj):
self.permission_denied(
request, message=getattr(permission, 'message', None)
request, message=getattr(permission, "message", None)
)
def check_throttles(self, request):
@ -379,7 +379,7 @@ class APIView(View):
parsers=self.get_parsers(),
authenticators=self.get_authenticators(),
negotiator=self.get_content_negotiator(),
parser_context=parser_context
parser_context=parser_context,
)
def initial(self, request, *args, **kwargs):
@ -407,13 +407,12 @@ class APIView(View):
"""
# Make the error obvious if a proper response is not returned
assert isinstance(response, HttpResponseBase), (
'Expected a `Response`, `HttpResponse` or `HttpStreamingResponse` '
'to be returned from the view, but received a `%s`'
% type(response)
"Expected a `Response`, `HttpResponse` or `HttpStreamingResponse` "
"to be returned from the view, but received a `%s`" % type(response)
)
if isinstance(response, Response):
if not getattr(request, 'accepted_renderer', None):
if not getattr(request, "accepted_renderer", None):
neg = self.perform_content_negotiation(request, force=True)
request.accepted_renderer, request.accepted_media_type = neg
@ -422,7 +421,7 @@ class APIView(View):
response.renderer_context = self.get_renderer_context()
# Add new vary headers to the response instead of overwriting.
vary_headers = self.headers.pop('Vary', None)
vary_headers = self.headers.pop("Vary", None)
if vary_headers is not None:
patch_vary_headers(response, cc_delim_re.split(vary_headers))
@ -436,8 +435,9 @@ class APIView(View):
Handle any exception that occurs, by returning an appropriate response,
or re-raising the error.
"""
if isinstance(exc, (exceptions.NotAuthenticated,
exceptions.AuthenticationFailed)):
if isinstance(
exc, (exceptions.NotAuthenticated, exceptions.AuthenticationFailed)
):
# WWW-Authenticate header for 401 responses, else coerce to 403
auth_header = self.get_authenticate_header(self.request)
@ -460,8 +460,8 @@ class APIView(View):
def raise_uncaught_exception(self, exc):
if settings.DEBUG:
request = self.request
renderer_format = getattr(request.accepted_renderer, 'format')
use_plaintext_traceback = renderer_format not in ('html', 'api', 'admin')
renderer_format = getattr(request.accepted_renderer, "format")
use_plaintext_traceback = renderer_format not in ("html", "api", "admin")
request.force_plaintext_errors(use_plaintext_traceback)
raise exc
@ -484,8 +484,9 @@ class APIView(View):
# Get the appropriate handler method
if request.method.lower() in self.http_method_names:
handler = getattr(self, request.method.lower(),
self.http_method_not_allowed)
handler = getattr(
self, request.method.lower(), self.http_method_not_allowed
)
else:
handler = self.http_method_not_allowed

View File

@ -31,7 +31,7 @@ from rest_framework.reverse import reverse
def _is_extra_action(attr):
return hasattr(attr, 'mapping')
return hasattr(attr, "mapping")
class ViewSetMixin(object):
@ -73,24 +73,30 @@ class ViewSetMixin(object):
# actions must not be empty
if not actions:
raise TypeError("The `actions` argument must be provided when "
"calling `.as_view()` on a ViewSet. For example "
"`.as_view({'get': 'list'})`")
raise TypeError(
"The `actions` argument must be provided when "
"calling `.as_view()` on a ViewSet. For example "
"`.as_view({'get': 'list'})`"
)
# sanitize keyword arguments
for key in initkwargs:
if key in cls.http_method_names:
raise TypeError("You tried to pass in the %s method name as a "
"keyword argument to %s(). Don't do that."
% (key, cls.__name__))
raise TypeError(
"You tried to pass in the %s method name as a "
"keyword argument to %s(). Don't do that." % (key, cls.__name__)
)
if not hasattr(cls, key):
raise TypeError("%s() received an invalid keyword %r" % (
cls.__name__, key))
raise TypeError(
"%s() received an invalid keyword %r" % (cls.__name__, key)
)
# name and suffix are mutually exclusive
if 'name' in initkwargs and 'suffix' in initkwargs:
raise TypeError("%s() received both `name` and `suffix`, which are "
"mutually exclusive arguments." % (cls.__name__))
if "name" in initkwargs and "suffix" in initkwargs:
raise TypeError(
"%s() received both `name` and `suffix`, which are "
"mutually exclusive arguments." % (cls.__name__)
)
def view(request, *args, **kwargs):
self = cls(**initkwargs)
@ -105,7 +111,7 @@ class ViewSetMixin(object):
handler = getattr(self, action)
setattr(self, method, handler)
if hasattr(self, 'get') and not hasattr(self, 'head'):
if hasattr(self, "get") and not hasattr(self, "head"):
self.head = self.get
self.request = request
@ -136,11 +142,11 @@ class ViewSetMixin(object):
"""
request = super(ViewSetMixin, self).initialize_request(request, *args, **kwargs)
method = request.method.lower()
if method == 'options':
if method == "options":
# This is a special case as we always provide handling for the
# options method in the base `View` class.
# Unlike the other explicitly defined actions, 'metadata' is implicit.
self.action = 'metadata'
self.action = "metadata"
else:
self.action = self.action_map.get(method)
return request
@ -149,8 +155,8 @@ class ViewSetMixin(object):
"""
Reverse the action for the given `url_name`.
"""
url_name = '%s-%s' % (self.basename, url_name)
kwargs.setdefault('request', self.request)
url_name = "%s-%s" % (self.basename, url_name)
kwargs.setdefault("request", self.request)
return reverse(url_name, *args, **kwargs)
@ -175,13 +181,14 @@ class ViewSetMixin(object):
# filter for the relevant extra actions
actions = [
action for action in self.get_extra_actions()
action
for action in self.get_extra_actions()
if action.detail == self.detail
]
for action in actions:
try:
url_name = '%s-%s' % (self.basename, action.url_name)
url_name = "%s-%s" % (self.basename, action.url_name)
url = reverse(url_name, self.args, self.kwargs, request=self.request)
view = self.__class__(**action.kwargs)
action_urls[view.get_view_name()] = url
@ -195,6 +202,7 @@ class ViewSet(ViewSetMixin, views.APIView):
"""
The base ViewSet class does not provide any actions by default.
"""
pass
@ -204,26 +212,31 @@ class GenericViewSet(ViewSetMixin, generics.GenericAPIView):
but does include the base set of generic view behavior, such as
the `get_object` and `get_queryset` methods.
"""
pass
class ReadOnlyModelViewSet(mixins.RetrieveModelMixin,
mixins.ListModelMixin,
GenericViewSet):
class ReadOnlyModelViewSet(
mixins.RetrieveModelMixin, mixins.ListModelMixin, GenericViewSet
):
"""
A viewset that provides default `list()` and `retrieve()` actions.
"""
pass
class ModelViewSet(mixins.CreateModelMixin,
mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
mixins.DestroyModelMixin,
mixins.ListModelMixin,
GenericViewSet):
class ModelViewSet(
mixins.CreateModelMixin,
mixins.RetrieveModelMixin,
mixins.UpdateModelMixin,
mixins.DestroyModelMixin,
mixins.ListModelMixin,
GenericViewSet,
):
"""
A viewset that provides default `create()`, `retrieve()`, `update()`,
`partial_update()`, `destroy()` and `list()` actions.
"""
pass

View File

@ -6,16 +6,23 @@ import sys
import pytest
PYTEST_ARGS = {
'default': [],
'fast': ['-q'],
}
FLAKE8_ARGS = ['rest_framework', 'tests']
PYTEST_ARGS = {"default": [], "fast": ["-q"]}
ISORT_ARGS = ['--recursive', '--check-only', '--diff', '-o' 'uritemplate', '-p', 'tests', 'rest_framework', 'tests']
FLAKE8_ARGS = ["rest_framework", "tests"]
BLACK_ARGS = ['--check', '--verbose']
ISORT_ARGS = [
"--recursive",
"--check-only",
"--diff",
"-o" "uritemplate",
"-p",
"tests",
"rest_framework",
"tests",
]
BLACK_ARGS = ["--check", "--verbose"]
def exit_on_failure(ret, message=None):
@ -24,43 +31,48 @@ def exit_on_failure(ret, message=None):
def flake8_main(args):
print('Running flake8 code linting')
ret = subprocess.call(['flake8'] + args)
print('flake8 failed' if ret else 'flake8 passed')
print("Running flake8 code linting")
ret = subprocess.call(["flake8"] + args)
print("flake8 failed" if ret else "flake8 passed")
return ret
def isort_main(args):
print('Running isort code checking')
ret = subprocess.call(['isort'] + args)
print("Running isort code checking")
ret = subprocess.call(["isort"] + args)
if ret:
print('isort failed: Some modules have incorrectly ordered imports. Fix by running `isort --recursive .`')
print(
"isort failed: Some modules have incorrectly ordered imports. Fix by running `isort --recursive .`"
)
else:
print('isort passed')
print("isort passed")
return ret
def black_main(args):
print('Running black code checking')
ret = subprocess.call(['black', '.'] + args)
print("Running black code checking")
ret = subprocess.call(["black", "."] + args)
if ret:
print('black failed: Some code have incorrectly formatted. Fix by running `black .`')
print(
"black failed: Some code have incorrectly formatted. Fix by running `black .`"
)
else:
print('black passed')
print("black passed")
return ret
def split_class_and_function(string):
class_string, function_string = string.split('.', 1)
class_string, function_string = string.split(".", 1)
return "%s and %s" % (class_string, function_string)
def is_function(string):
# `True` if it looks like a test function is included in the string.
return string.startswith('test_') or '.test_' in string
return string.startswith("test_") or ".test_" in string
def is_class(string):
@ -70,7 +82,7 @@ def is_class(string):
if __name__ == "__main__":
try:
sys.argv.remove('--nolint')
sys.argv.remove("--nolint")
except ValueError:
run_black = True
run_flake8 = True
@ -81,18 +93,18 @@ if __name__ == "__main__":
run_isort = False
try:
sys.argv.remove('--lintonly')
sys.argv.remove("--lintonly")
except ValueError:
run_tests = True
else:
run_tests = False
try:
sys.argv.remove('--fast')
sys.argv.remove("--fast")
except ValueError:
style = 'default'
style = "default"
else:
style = 'fast'
style = "fast"
run_black = False
run_flake8 = False
run_isort = False
@ -102,26 +114,23 @@ if __name__ == "__main__":
first_arg = pytest_args[0]
try:
pytest_args.remove('--coverage')
pytest_args.remove("--coverage")
except ValueError:
pass
else:
pytest_args = [
'--cov', '.',
'--cov-report', 'xml',
] + pytest_args
pytest_args = ["--cov", ".", "--cov-report", "xml"] + pytest_args
if first_arg.startswith('-'):
if first_arg.startswith("-"):
# `runtests.py [flags]`
pytest_args = ['tests'] + pytest_args
pytest_args = ["tests"] + pytest_args
elif is_class(first_arg) and is_function(first_arg):
# `runtests.py TestCase.test_function [flags]`
expression = split_class_and_function(first_arg)
pytest_args = ['tests', '-k', expression] + pytest_args[1:]
pytest_args = ["tests", "-k", expression] + pytest_args[1:]
elif is_class(first_arg) or is_function(first_arg):
# `runtests.py TestCase [flags]`
# `runtests.py test_function [flags]`
pytest_args = ['tests', '-k', pytest_args[0]] + pytest_args[1:]
pytest_args = ["tests", "-k", pytest_args[0]] + pytest_args[1:]
else:
pytest_args = PYTEST_ARGS[style]

View File

@ -9,16 +9,23 @@ addopts=--tb=short --strict -ra
testspath = tests
[flake8]
ignore = E501
max-line-length = 120
ignore = E501, W503, E203
banned-modules = json = use from rest_framework.utils import json!
[isort]
skip=.tox
atomic=true
multi_line_output=5
known_standard_library=types
multi_line_output=3
lines_after_imports = 2
black=types
combine_as_imports = true
known_third_party=pytest,_pytest,django,pytz
known_first_party=rest_framework
known_first_party=rest_framework, tests
include_trailing_comma=true
line_length = 88
balanced_wrapping = true
sections = FUTURE, STDLIB, DJANGO, CMS, THIRDPARTY, FIRSTPARTY, LIB, LOCALFOLDER
[coverage:run]
# NOTE: source is ignored with pytest-cov (but uses the same).

View File

@ -10,21 +10,21 @@ from setuptools import find_packages, setup
def read(f):
return open(f, 'r', encoding='utf-8').read()
return open(f, "r", encoding="utf-8").read()
def get_version(package):
"""
Return package version as listed in `__version__` in `init.py`.
"""
init_py = open(os.path.join(package, '__init__.py')).read()
init_py = open(os.path.join(package, "__init__.py")).read()
return re.search("__version__ = ['\"]([^'\"]+)['\"]", init_py).group(1)
version = get_version('rest_framework')
version = get_version("rest_framework")
if sys.argv[-1] == 'publish':
if sys.argv[-1] == "publish":
if os.system("pip freeze | grep twine"):
print("twine not installed.\nUse `pip install twine`.\nExiting.")
sys.exit()
@ -33,48 +33,48 @@ if sys.argv[-1] == 'publish':
print("You probably want to also tag the version now:")
print(" git tag -a %s -m 'version %s'" % (version, version))
print(" git push --tags")
shutil.rmtree('dist')
shutil.rmtree('build')
shutil.rmtree('djangorestframework.egg-info')
shutil.rmtree("dist")
shutil.rmtree("build")
shutil.rmtree("djangorestframework.egg-info")
sys.exit()
setup(
name='djangorestframework',
name="djangorestframework",
version=version,
url='https://www.django-rest-framework.org/',
license='BSD',
description='Web APIs for Django, made easy.',
long_description=read('README.md'),
long_description_content_type='text/markdown',
author='Tom Christie',
author_email='tom@tomchristie.com', # SEE NOTE BELOW (*)
packages=find_packages(exclude=['tests*']),
url="https://www.django-rest-framework.org/",
license="BSD",
description="Web APIs for Django, made easy.",
long_description=read("README.md"),
long_description_content_type="text/markdown",
author="Tom Christie",
author_email="tom@tomchristie.com", # SEE NOTE BELOW (*)
packages=find_packages(exclude=["tests*"]),
include_package_data=True,
install_requires=[],
python_requires=">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*",
zip_safe=False,
classifiers=[
'Development Status :: 5 - Production/Stable',
'Environment :: Web Environment',
'Framework :: Django',
'Framework :: Django :: 1.11',
'Framework :: Django :: 2.0',
'Framework :: Django :: 2.1',
'Framework :: Django :: 2.2',
'Intended Audience :: Developers',
'License :: OSI Approved :: BSD License',
'Operating System :: OS Independent',
'Programming Language :: Python',
'Programming Language :: Python :: 2',
'Programming Language :: Python :: 2.7',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.4',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
'Topic :: Internet :: WWW/HTTP',
]
"Development Status :: 5 - Production/Stable",
"Environment :: Web Environment",
"Framework :: Django",
"Framework :: Django :: 1.11",
"Framework :: Django :: 2.0",
"Framework :: Django :: 2.1",
"Framework :: Django :: 2.2",
"Intended Audience :: Developers",
"License :: OSI Approved :: BSD License",
"Operating System :: OS Independent",
"Programming Language :: Python",
"Programming Language :: Python :: 2",
"Programming Language :: Python :: 2.7",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.4",
"Programming Language :: Python :: 3.5",
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"Topic :: Internet :: WWW/HTTP",
],
)
# (*) Please direct queries to the discussion group, rather than to me directly

View File

@ -9,16 +9,22 @@ class Migration(migrations.Migration):
initial = True
dependencies = [
migrations.swappable_dependency(settings.AUTH_USER_MODEL),
]
dependencies = [migrations.swappable_dependency(settings.AUTH_USER_MODEL)]
operations = [
migrations.CreateModel(
name='CustomToken',
name="CustomToken",
fields=[
('key', models.CharField(max_length=40, primary_key=True, serialize=False)),
('user', models.OneToOneField(on_delete=models.CASCADE, to=settings.AUTH_USER_MODEL)),
(
"key",
models.CharField(max_length=40, primary_key=True, serialize=False),
),
(
"user",
models.OneToOneField(
on_delete=models.CASCADE, to=settings.AUTH_USER_MODEL
),
),
],
),
)
]

View File

@ -13,11 +13,18 @@ from django.test import TestCase, override_settings
from django.utils import six
from rest_framework import (
HTTP_HEADER_ENCODING, exceptions, permissions, renderers, status
HTTP_HEADER_ENCODING,
exceptions,
permissions,
renderers,
status,
)
from rest_framework.authentication import (
BaseAuthentication, BasicAuthentication, RemoteUserAuthentication,
SessionAuthentication, TokenAuthentication
BaseAuthentication,
BasicAuthentication,
RemoteUserAuthentication,
SessionAuthentication,
TokenAuthentication,
)
from rest_framework.authtoken.models import Token
from rest_framework.authtoken.views import obtain_auth_token
@ -27,6 +34,7 @@ from rest_framework.views import APIView
from .models import CustomToken
factory = APIRequestFactory()
@ -35,92 +43,77 @@ class CustomTokenAuthentication(TokenAuthentication):
class CustomKeywordTokenAuthentication(TokenAuthentication):
keyword = 'Bearer'
keyword = "Bearer"
class MockView(APIView):
permission_classes = (permissions.IsAuthenticated,)
def get(self, request):
return HttpResponse({'a': 1, 'b': 2, 'c': 3})
return HttpResponse({"a": 1, "b": 2, "c": 3})
def post(self, request):
return HttpResponse({'a': 1, 'b': 2, 'c': 3})
return HttpResponse({"a": 1, "b": 2, "c": 3})
def put(self, request):
return HttpResponse({'a': 1, 'b': 2, 'c': 3})
return HttpResponse({"a": 1, "b": 2, "c": 3})
urlpatterns = [
url(
r'^session/$',
MockView.as_view(authentication_classes=[SessionAuthentication])
r"^session/$", MockView.as_view(authentication_classes=[SessionAuthentication])
),
url(r"^basic/$", MockView.as_view(authentication_classes=[BasicAuthentication])),
url(
r"^remote-user/$",
MockView.as_view(authentication_classes=[RemoteUserAuthentication]),
),
url(r"^token/$", MockView.as_view(authentication_classes=[TokenAuthentication])),
url(
r"^customtoken/$",
MockView.as_view(authentication_classes=[CustomTokenAuthentication]),
),
url(
r'^basic/$',
MockView.as_view(authentication_classes=[BasicAuthentication])
r"^customkeywordtoken/$",
MockView.as_view(authentication_classes=[CustomKeywordTokenAuthentication]),
),
url(
r'^remote-user/$',
MockView.as_view(authentication_classes=[RemoteUserAuthentication])
),
url(
r'^token/$',
MockView.as_view(authentication_classes=[TokenAuthentication])
),
url(
r'^customtoken/$',
MockView.as_view(authentication_classes=[CustomTokenAuthentication])
),
url(
r'^customkeywordtoken/$',
MockView.as_view(
authentication_classes=[CustomKeywordTokenAuthentication]
)
),
url(r'^auth-token/$', obtain_auth_token),
url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')),
url(r"^auth-token/$", obtain_auth_token),
url(r"^auth/", include("rest_framework.urls", namespace="rest_framework")),
]
@override_settings(ROOT_URLCONF=__name__)
class BasicAuthTests(TestCase):
"""Basic authentication"""
def setUp(self):
self.csrf_client = APIClient(enforce_csrf_checks=True)
self.username = 'john'
self.email = 'lennon@thebeatles.com'
self.password = 'password'
self.user = User.objects.create_user(
self.username, self.email, self.password
)
self.username = "john"
self.email = "lennon@thebeatles.com"
self.password = "password"
self.user = User.objects.create_user(self.username, self.email, self.password)
def test_post_form_passing_basic_auth(self):
"""Ensure POSTing json over basic auth with correct credentials passes and does not require CSRF"""
credentials = ('%s:%s' % (self.username, self.password))
credentials = "%s:%s" % (self.username, self.password)
base64_credentials = base64.b64encode(
credentials.encode(HTTP_HEADER_ENCODING)
).decode(HTTP_HEADER_ENCODING)
auth = 'Basic %s' % base64_credentials
auth = "Basic %s" % base64_credentials
response = self.csrf_client.post(
'/basic/',
{'example': 'example'},
HTTP_AUTHORIZATION=auth
"/basic/", {"example": "example"}, HTTP_AUTHORIZATION=auth
)
assert response.status_code == status.HTTP_200_OK
def test_post_json_passing_basic_auth(self):
"""Ensure POSTing form over basic auth with correct credentials passes and does not require CSRF"""
credentials = ('%s:%s' % (self.username, self.password))
credentials = "%s:%s" % (self.username, self.password)
base64_credentials = base64.b64encode(
credentials.encode(HTTP_HEADER_ENCODING)
).decode(HTTP_HEADER_ENCODING)
auth = 'Basic %s' % base64_credentials
auth = "Basic %s" % base64_credentials
response = self.csrf_client.post(
'/basic/',
{'example': 'example'},
format='json',
HTTP_AUTHORIZATION=auth
"/basic/", {"example": "example"}, format="json", HTTP_AUTHORIZATION=auth
)
assert response.status_code == status.HTTP_200_OK
@ -128,39 +121,34 @@ class BasicAuthTests(TestCase):
"""Ensure POSTing JSON over basic auth with incorrectly padded Base64 string is handled correctly"""
# regression test for issue in 'rest_framework.authentication.BasicAuthentication.authenticate'
# https://github.com/encode/django-rest-framework/issues/4089
auth = 'Basic =a='
auth = "Basic =a="
response = self.csrf_client.post(
'/basic/',
{'example': 'example'},
format='json',
HTTP_AUTHORIZATION=auth
"/basic/", {"example": "example"}, format="json", HTTP_AUTHORIZATION=auth
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_post_form_failing_basic_auth(self):
"""Ensure POSTing form over basic auth without correct credentials fails"""
response = self.csrf_client.post('/basic/', {'example': 'example'})
response = self.csrf_client.post("/basic/", {"example": "example"})
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_post_json_failing_basic_auth(self):
"""Ensure POSTing json over basic auth without correct credentials fails"""
response = self.csrf_client.post(
'/basic/',
{'example': 'example'},
format='json'
"/basic/", {"example": "example"}, format="json"
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
assert response['WWW-Authenticate'] == 'Basic realm="api"'
assert response["WWW-Authenticate"] == 'Basic realm="api"'
def test_fail_post_if_credentials_are_missing(self):
response = self.csrf_client.post(
'/basic/', {'example': 'example'}, HTTP_AUTHORIZATION='Basic ')
"/basic/", {"example": "example"}, HTTP_AUTHORIZATION="Basic "
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_fail_post_if_credentials_contain_spaces(self):
response = self.csrf_client.post(
'/basic/', {'example': 'example'},
HTTP_AUTHORIZATION='Basic foo bar'
"/basic/", {"example": "example"}, HTTP_AUTHORIZATION="Basic foo bar"
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@ -168,15 +156,14 @@ class BasicAuthTests(TestCase):
@override_settings(ROOT_URLCONF=__name__)
class SessionAuthTests(TestCase):
"""User session authentication"""
def setUp(self):
self.csrf_client = APIClient(enforce_csrf_checks=True)
self.non_csrf_client = APIClient(enforce_csrf_checks=False)
self.username = 'john'
self.email = 'lennon@thebeatles.com'
self.password = 'password'
self.user = User.objects.create_user(
self.username, self.email, self.password
)
self.username = "john"
self.email = "lennon@thebeatles.com"
self.password = "password"
self.user = User.objects.create_user(self.username, self.email, self.password)
def tearDown(self):
self.csrf_client.logout()
@ -187,8 +174,8 @@ class SessionAuthTests(TestCase):
cf. [#1810](https://github.com/encode/django-rest-framework/pull/1810)
"""
response = self.csrf_client.get('/auth/login/')
content = response.content.decode('utf8')
response = self.csrf_client.get("/auth/login/")
content = response.content.decode("utf8")
assert '<label for="id_username">Username:</label>' in content
def test_post_form_session_auth_failing_csrf(self):
@ -196,7 +183,7 @@ class SessionAuthTests(TestCase):
Ensure POSTing form over session authentication without CSRF token fails.
"""
self.csrf_client.login(username=self.username, password=self.password)
response = self.csrf_client.post('/session/', {'example': 'example'})
response = self.csrf_client.post("/session/", {"example": "example"})
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_post_form_session_auth_passing_csrf(self):
@ -213,10 +200,9 @@ class SessionAuthTests(TestCase):
self.csrf_client.cookies[settings.CSRF_COOKIE_NAME] = token
# Post the token matching the cookie value
response = self.csrf_client.post('/session/', {
'example': 'example',
'csrfmiddlewaretoken': token,
})
response = self.csrf_client.post(
"/session/", {"example": "example", "csrfmiddlewaretoken": token}
)
assert response.status_code == status.HTTP_200_OK
def test_post_form_session_auth_passing(self):
@ -224,12 +210,8 @@ class SessionAuthTests(TestCase):
Ensure POSTing form over session authentication with logged in
user and CSRF token passes.
"""
self.non_csrf_client.login(
username=self.username, password=self.password
)
response = self.non_csrf_client.post(
'/session/', {'example': 'example'}
)
self.non_csrf_client.login(username=self.username, password=self.password)
response = self.non_csrf_client.post("/session/", {"example": "example"})
assert response.status_code == status.HTTP_200_OK
def test_put_form_session_auth_passing(self):
@ -237,38 +219,33 @@ class SessionAuthTests(TestCase):
Ensure PUTting form over session authentication with
logged in user and CSRF token passes.
"""
self.non_csrf_client.login(
username=self.username, password=self.password
)
response = self.non_csrf_client.put(
'/session/', {'example': 'example'}
)
self.non_csrf_client.login(username=self.username, password=self.password)
response = self.non_csrf_client.put("/session/", {"example": "example"})
assert response.status_code == status.HTTP_200_OK
def test_post_form_session_auth_failing(self):
"""
Ensure POSTing form over session authentication without logged in user fails.
"""
response = self.csrf_client.post('/session/', {'example': 'example'})
response = self.csrf_client.post("/session/", {"example": "example"})
assert response.status_code == status.HTTP_403_FORBIDDEN
class BaseTokenAuthTests(object):
"""Token authentication"""
model = None
path = None
header_prefix = 'Token '
header_prefix = "Token "
def setUp(self):
self.csrf_client = APIClient(enforce_csrf_checks=True)
self.username = 'john'
self.email = 'lennon@thebeatles.com'
self.password = 'password'
self.user = User.objects.create_user(
self.username, self.email, self.password
)
self.username = "john"
self.email = "lennon@thebeatles.com"
self.password = "password"
self.user = User.objects.create_user(self.username, self.email, self.password)
self.key = 'abcd1234'
self.key = "abcd1234"
self.token = self.model.objects.create(key=self.key, user=self.user)
def test_post_form_passing_token_auth(self):
@ -278,39 +255,41 @@ class BaseTokenAuthTests(object):
"""
auth = self.header_prefix + self.key
response = self.csrf_client.post(
self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth
self.path, {"example": "example"}, HTTP_AUTHORIZATION=auth
)
assert response.status_code == status.HTTP_200_OK
def test_fail_authentication_if_user_is_not_active(self):
user = User.objects.create_user('foo', 'bar', 'baz')
user = User.objects.create_user("foo", "bar", "baz")
user.is_active = False
user.save()
self.model.objects.create(key='foobar_token', user=user)
self.model.objects.create(key="foobar_token", user=user)
response = self.csrf_client.post(
self.path, {'example': 'example'},
HTTP_AUTHORIZATION=self.header_prefix + 'foobar_token'
self.path,
{"example": "example"},
HTTP_AUTHORIZATION=self.header_prefix + "foobar_token",
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_fail_post_form_passing_nonexistent_token_auth(self):
# use a nonexistent token key
auth = self.header_prefix + 'wxyz6789'
auth = self.header_prefix + "wxyz6789"
response = self.csrf_client.post(
self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth
self.path, {"example": "example"}, HTTP_AUTHORIZATION=auth
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_fail_post_if_token_is_missing(self):
response = self.csrf_client.post(
self.path, {'example': 'example'},
HTTP_AUTHORIZATION=self.header_prefix)
self.path, {"example": "example"}, HTTP_AUTHORIZATION=self.header_prefix
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_fail_post_if_token_contains_spaces(self):
response = self.csrf_client.post(
self.path, {'example': 'example'},
HTTP_AUTHORIZATION=self.header_prefix + 'foo bar'
self.path,
{"example": "example"},
HTTP_AUTHORIZATION=self.header_prefix + "foo bar",
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@ -318,7 +297,7 @@ class BaseTokenAuthTests(object):
# add an 'invalid' unicode character
auth = self.header_prefix + self.key + "¸"
response = self.csrf_client.post(
self.path, {'example': 'example'}, HTTP_AUTHORIZATION=auth
self.path, {"example": "example"}, HTTP_AUTHORIZATION=auth
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@ -329,8 +308,7 @@ class BaseTokenAuthTests(object):
"""
auth = self.header_prefix + self.key
response = self.csrf_client.post(
self.path, {'example': 'example'},
format='json', HTTP_AUTHORIZATION=auth
self.path, {"example": "example"}, format="json", HTTP_AUTHORIZATION=auth
)
assert response.status_code == status.HTTP_200_OK
@ -343,8 +321,10 @@ class BaseTokenAuthTests(object):
def func_to_test():
return self.csrf_client.post(
self.path, {'example': 'example'},
format='json', HTTP_AUTHORIZATION=auth
self.path,
{"example": "example"},
format="json",
HTTP_AUTHORIZATION=auth,
)
self.assertNumQueries(1, func_to_test)
@ -353,7 +333,7 @@ class BaseTokenAuthTests(object):
"""
Ensure POSTing form over token auth without correct credentials fails
"""
response = self.csrf_client.post(self.path, {'example': 'example'})
response = self.csrf_client.post(self.path, {"example": "example"})
assert response.status_code == status.HTTP_401_UNAUTHORIZED
def test_post_json_failing_token_auth(self):
@ -361,7 +341,7 @@ class BaseTokenAuthTests(object):
Ensure POSTing json over token auth without correct credentials fails
"""
response = self.csrf_client.post(
self.path, {'example': 'example'}, format='json'
self.path, {"example": "example"}, format="json"
)
assert response.status_code == status.HTTP_401_UNAUTHORIZED
@ -369,7 +349,7 @@ class BaseTokenAuthTests(object):
@override_settings(ROOT_URLCONF=__name__)
class TokenAuthTests(BaseTokenAuthTests, TestCase):
model = Token
path = '/token/'
path = "/token/"
def test_token_has_auto_assigned_key_if_none_provided(self):
"""Ensure creating a token with no key will auto-assign a key"""
@ -387,12 +367,12 @@ class TokenAuthTests(BaseTokenAuthTests, TestCase):
"""Ensure token login view using JSON POST works."""
client = APIClient(enforce_csrf_checks=True)
response = client.post(
'/auth-token/',
{'username': self.username, 'password': self.password},
format='json'
"/auth-token/",
{"username": self.username, "password": self.password},
format="json",
)
assert response.status_code == status.HTTP_200_OK
assert response.data['token'] == self.key
assert response.data["token"] == self.key
def test_token_login_json_bad_creds(self):
"""
@ -401,41 +381,41 @@ class TokenAuthTests(BaseTokenAuthTests, TestCase):
"""
client = APIClient(enforce_csrf_checks=True)
response = client.post(
'/auth-token/',
{'username': self.username, 'password': "badpass"},
format='json'
"/auth-token/",
{"username": self.username, "password": "badpass"},
format="json",
)
assert response.status_code == 400
def test_token_login_json_missing_fields(self):
"""Ensure token login view using JSON POST fails if missing fields."""
client = APIClient(enforce_csrf_checks=True)
response = client.post('/auth-token/',
{'username': self.username}, format='json')
response = client.post(
"/auth-token/", {"username": self.username}, format="json"
)
assert response.status_code == 400
def test_token_login_form(self):
"""Ensure token login view using form POST works."""
client = APIClient(enforce_csrf_checks=True)
response = client.post(
'/auth-token/',
{'username': self.username, 'password': self.password}
"/auth-token/", {"username": self.username, "password": self.password}
)
assert response.status_code == status.HTTP_200_OK
assert response.data['token'] == self.key
assert response.data["token"] == self.key
@override_settings(ROOT_URLCONF=__name__)
class CustomTokenAuthTests(BaseTokenAuthTests, TestCase):
model = CustomToken
path = '/customtoken/'
path = "/customtoken/"
@override_settings(ROOT_URLCONF=__name__)
class CustomKeywordTokenAuthTests(BaseTokenAuthTests, TestCase):
model = Token
path = '/customkeywordtoken/'
header_prefix = 'Bearer '
path = "/customkeywordtoken/"
header_prefix = "Bearer "
class IncorrectCredentialsTests(TestCase):
@ -445,42 +425,42 @@ class IncorrectCredentialsTests(TestCase):
authentication should run and error, even if no permissions
are set on the view.
"""
class IncorrectCredentialsAuth(BaseAuthentication):
def authenticate(self, request):
raise exceptions.AuthenticationFailed('Bad credentials')
raise exceptions.AuthenticationFailed("Bad credentials")
request = factory.get('/')
request = factory.get("/")
view = MockView.as_view(
authentication_classes=(IncorrectCredentialsAuth,),
permission_classes=()
authentication_classes=(IncorrectCredentialsAuth,), permission_classes=()
)
response = view(request)
assert response.status_code == status.HTTP_403_FORBIDDEN
assert response.data == {'detail': 'Bad credentials'}
assert response.data == {"detail": "Bad credentials"}
class FailingAuthAccessedInRenderer(TestCase):
def setUp(self):
class AuthAccessingRenderer(renderers.BaseRenderer):
media_type = 'text/plain'
format = 'txt'
media_type = "text/plain"
format = "txt"
def render(self, data, media_type=None, renderer_context=None):
request = renderer_context['request']
request = renderer_context["request"]
if request.user.is_authenticated:
return b'authenticated'
return b'not authenticated'
return b"authenticated"
return b"not authenticated"
class FailingAuth(BaseAuthentication):
def authenticate(self, request):
raise exceptions.AuthenticationFailed('authentication failed')
raise exceptions.AuthenticationFailed("authentication failed")
class ExampleView(APIView):
authentication_classes = (FailingAuth,)
renderer_classes = (AuthAccessingRenderer,)
def get(self, request):
return Response({'foo': 'bar'})
return Response({"foo": "bar"})
self.view = ExampleView.as_view()
@ -490,10 +470,10 @@ class FailingAuthAccessedInRenderer(TestCase):
`request.user` without raising an exception. Particularly relevant
to HTML responses that might reasonably access `request.user`.
"""
request = factory.get('/')
request = factory.get("/")
response = self.view(request)
content = response.render().content
assert content == b'not authenticated'
assert content == b"not authenticated"
class NoAuthenticationClassesTests(TestCase):
@ -505,23 +485,21 @@ class NoAuthenticationClassesTests(TestCase):
"""
class DummyPermission(permissions.BasePermission):
message = 'Dummy permission message'
message = "Dummy permission message"
def has_permission(self, request, view):
return False
request = factory.get('/')
request = factory.get("/")
view = MockView.as_view(
authentication_classes=(),
permission_classes=(DummyPermission,),
authentication_classes=(), permission_classes=(DummyPermission,)
)
response = view(request)
assert response.status_code == status.HTTP_403_FORBIDDEN
assert response.data == {'detail': 'Dummy permission message'}
assert response.data == {"detail": "Dummy permission message"}
class BasicAuthenticationUnitTests(TestCase):
def test_base_authentication_abstract_method(self):
with pytest.raises(NotImplementedError):
BaseAuthentication().authenticate({})
@ -529,34 +507,34 @@ class BasicAuthenticationUnitTests(TestCase):
def test_basic_authentication_raises_error_if_user_not_found(self):
auth = BasicAuthentication()
with pytest.raises(exceptions.AuthenticationFailed):
auth.authenticate_credentials('invalid id', 'invalid password')
auth.authenticate_credentials("invalid id", "invalid password")
def test_basic_authentication_raises_error_if_user_not_active(self):
from rest_framework import authentication
class MockUser(object):
is_active = False
old_authenticate = authentication.authenticate
authentication.authenticate = lambda **kwargs: MockUser()
auth = authentication.BasicAuthentication()
with pytest.raises(exceptions.AuthenticationFailed) as error:
auth.authenticate_credentials('foo', 'bar')
assert 'User inactive or deleted.' in str(error)
auth.authenticate_credentials("foo", "bar")
assert "User inactive or deleted." in str(error)
authentication.authenticate = old_authenticate
@override_settings(ROOT_URLCONF=__name__,
AUTHENTICATION_BACKENDS=('django.contrib.auth.backends.RemoteUserBackend',))
@override_settings(
ROOT_URLCONF=__name__,
AUTHENTICATION_BACKENDS=("django.contrib.auth.backends.RemoteUserBackend",),
)
class RemoteUserAuthenticationUnitTests(TestCase):
def setUp(self):
self.username = 'john'
self.email = 'lennon@thebeatles.com'
self.password = 'password'
self.user = User.objects.create_user(
self.username, self.email, self.password
)
self.username = "john"
self.email = "lennon@thebeatles.com"
self.password = "password"
self.user = User.objects.create_user(self.username, self.email, self.password)
def test_remote_user_works(self):
response = self.client.post('/remote-user/',
REMOTE_USER=self.username)
response = self.client.post("/remote-user/", REMOTE_USER=self.username)
self.assertEqual(response.status_code, status.HTTP_200_OK)

View File

@ -4,7 +4,8 @@ from django.conf.urls import include, url
from .views import MockView
urlpatterns = [
url(r'^$', MockView.as_view()),
url(r'^auth/', include('rest_framework.urls', namespace='rest_framework')),
url(r"^$", MockView.as_view()),
url(r"^auth/", include("rest_framework.urls", namespace="rest_framework")),
]

View File

@ -4,6 +4,5 @@ from django.conf.urls import url
from .views import MockView
urlpatterns = [
url(r'^$', MockView.as_view()),
]
urlpatterns = [url(r"^$", MockView.as_view())]

View File

@ -6,71 +6,65 @@ from django.test import TestCase, override_settings
from rest_framework.test import APIClient
@override_settings(ROOT_URLCONF='tests.browsable_api.auth_urls')
@override_settings(ROOT_URLCONF="tests.browsable_api.auth_urls")
class DropdownWithAuthTests(TestCase):
"""Tests correct dropdown behaviour with Auth views enabled."""
def setUp(self):
self.client = APIClient(enforce_csrf_checks=True)
self.username = 'john'
self.email = 'lennon@thebeatles.com'
self.password = 'password'
self.user = User.objects.create_user(
self.username,
self.email,
self.password
)
self.username = "john"
self.email = "lennon@thebeatles.com"
self.password = "password"
self.user = User.objects.create_user(self.username, self.email, self.password)
def tearDown(self):
self.client.logout()
def test_name_shown_when_logged_in(self):
self.client.login(username=self.username, password=self.password)
response = self.client.get('/')
content = response.content.decode('utf8')
assert 'john' in content
response = self.client.get("/")
content = response.content.decode("utf8")
assert "john" in content
def test_logout_shown_when_logged_in(self):
self.client.login(username=self.username, password=self.password)
response = self.client.get('/')
content = response.content.decode('utf8')
assert '>Log out<' in content
response = self.client.get("/")
content = response.content.decode("utf8")
assert ">Log out<" in content
def test_login_shown_when_logged_out(self):
response = self.client.get('/')
content = response.content.decode('utf8')
assert '>Log in<' in content
response = self.client.get("/")
content = response.content.decode("utf8")
assert ">Log in<" in content
@override_settings(ROOT_URLCONF='tests.browsable_api.no_auth_urls')
@override_settings(ROOT_URLCONF="tests.browsable_api.no_auth_urls")
class NoDropdownWithoutAuthTests(TestCase):
"""Tests correct dropdown behaviour with Auth views NOT enabled."""
def setUp(self):
self.client = APIClient(enforce_csrf_checks=True)
self.username = 'john'
self.email = 'lennon@thebeatles.com'
self.password = 'password'
self.user = User.objects.create_user(
self.username,
self.email,
self.password
)
self.username = "john"
self.email = "lennon@thebeatles.com"
self.password = "password"
self.user = User.objects.create_user(self.username, self.email, self.password)
def tearDown(self):
self.client.logout()
def test_name_shown_when_logged_in(self):
self.client.login(username=self.username, password=self.password)
response = self.client.get('/')
content = response.content.decode('utf8')
assert 'john' in content
response = self.client.get("/")
content = response.content.decode("utf8")
assert "john" in content
def test_dropdown_not_shown_when_logged_in(self):
self.client.login(username=self.username, password=self.password)
response = self.client.get('/')
content = response.content.decode('utf8')
response = self.client.get("/")
content = response.content.decode("utf8")
assert '<li class="dropdown">' not in content
def test_dropdown_not_shown_when_logged_out(self):
response = self.client.get('/')
content = response.content.decode('utf8')
response = self.client.get("/")
content = response.content.decode("utf8")
assert '<li class="dropdown">' not in content

View File

@ -19,24 +19,22 @@ class NestedSerializerTestSerializer(serializers.Serializer):
class NestedSerializersView(ListCreateAPIView):
renderer_classes = (BrowsableAPIRenderer, )
renderer_classes = (BrowsableAPIRenderer,)
serializer_class = NestedSerializerTestSerializer
queryset = [{'nested': {'one': 1, 'two': 2}}]
queryset = [{"nested": {"one": 1, "two": 2}}]
urlpatterns = [
url(r'^api/$', NestedSerializersView.as_view(), name='api'),
]
urlpatterns = [url(r"^api/$", NestedSerializersView.as_view(), name="api")]
class DropdownWithAuthTests(TestCase):
"""Tests correct dropdown behaviour with Auth views enabled."""
@override_settings(ROOT_URLCONF='tests.browsable_api.test_browsable_nested_api')
@override_settings(ROOT_URLCONF="tests.browsable_api.test_browsable_nested_api")
def test_login(self):
response = self.client.get('/api/')
response = self.client.get("/api/")
assert 200 == response.status_code
content = response.content.decode('utf-8')
content = response.content.decode("utf-8")
assert 'form action="/api/"' in content
assert 'input name="nested.one"' in content
assert 'input name="nested.two"' in content

View File

@ -5,13 +5,14 @@ from rest_framework.response import Response
from rest_framework.test import APIRequestFactory
from tests.models import BasicModel
factory = APIRequestFactory()
class BasicSerializer(serializers.ModelSerializer):
class Meta:
model = BasicModel
fields = '__all__'
fields = "__all__"
class StandardPostView(generics.CreateAPIView):
@ -39,19 +40,19 @@ class TestPostingListData(TestCase):
def test_json_response(self):
# sanity check for non-browsable API responses
view = StandardPostView.as_view()
request = factory.post('/', [{}], format='json')
request = factory.post("/", [{}], format="json")
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertTrue('non_field_errors' in response.data)
self.assertTrue("non_field_errors" in response.data)
def test_browsable_api(self):
view = StandardPostView.as_view()
request = factory.post('/?format=api', [{}], format='json')
request = factory.post("/?format=api", [{}], format="json")
response = view(request).render()
self.assertEqual(response.status_code, status.HTTP_400_BAD_REQUEST)
self.assertTrue('non_field_errors' in response.data)
self.assertTrue("non_field_errors" in response.data)
class TestManyPostView(TestCase):
@ -59,14 +60,11 @@ class TestManyPostView(TestCase):
"""
Create 3 BasicModel instances.
"""
items = ['foo', 'bar', 'baz']
items = ["foo", "bar", "baz"]
for item in items:
BasicModel(text=item).save()
self.objects = BasicModel.objects
self.data = [
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
self.data = [{"id": obj.id, "text": obj.text} for obj in self.objects.all()]
self.view = ManyPostView.as_view()
def test_post_many_post_view(self):
@ -77,7 +75,7 @@ class TestManyPostView(TestCase):
Regression test for https://github.com/encode/django-rest-framework/pull/3164
"""
data = {}
request = factory.post('/', data, format='json')
request = factory.post("/", data, format="json")
with self.assertNumQueries(1):
response = self.view(request).render()
assert response.status_code == status.HTTP_200_OK

View File

@ -10,4 +10,4 @@ class MockView(APIView):
renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer)
def get(self, request):
return Response({'a': 1, 'b': 2, 'c': 3})
return Response({"a": 1, "b": 2, "c": 3})

View File

@ -6,13 +6,21 @@ from django.core import management
def pytest_addoption(parser):
parser.addoption('--no-pkgroot', action='store_true', default=False,
help='Remove package root directory from sys.path, ensuring that '
'rest_framework is imported from the installed site-packages. '
'Used for testing the distribution.')
parser.addoption('--staticfiles', action='store_true', default=False,
help='Run tests with static files collection, using manifest '
'staticfiles storage. Used for testing the distribution.')
parser.addoption(
"--no-pkgroot",
action="store_true",
default=False,
help="Remove package root directory from sys.path, ensuring that "
"rest_framework is imported from the installed site-packages. "
"Used for testing the distribution.",
)
parser.addoption(
"--staticfiles",
action="store_true",
default=False,
help="Run tests with static files collection, using manifest "
"staticfiles storage. Used for testing the distribution.",
)
def pytest_configure(config):
@ -21,49 +29,42 @@ def pytest_configure(config):
settings.configure(
DEBUG_PROPAGATE_EXCEPTIONS=True,
DATABASES={
'default': {
'ENGINE': 'django.db.backends.sqlite3',
'NAME': ':memory:'
}
"default": {"ENGINE": "django.db.backends.sqlite3", "NAME": ":memory:"}
},
SITE_ID=1,
SECRET_KEY='not very secret in tests',
SECRET_KEY="not very secret in tests",
USE_I18N=True,
USE_L10N=True,
STATIC_URL='/static/',
ROOT_URLCONF='tests.urls',
STATIC_URL="/static/",
ROOT_URLCONF="tests.urls",
TEMPLATES=[
{
'BACKEND': 'django.template.backends.django.DjangoTemplates',
'APP_DIRS': True,
'OPTIONS': {
"debug": True, # We want template errors to raise
}
},
"BACKEND": "django.template.backends.django.DjangoTemplates",
"APP_DIRS": True,
"OPTIONS": {"debug": True}, # We want template errors to raise
}
],
MIDDLEWARE=(
'django.middleware.common.CommonMiddleware',
'django.contrib.sessions.middleware.SessionMiddleware',
'django.contrib.auth.middleware.AuthenticationMiddleware',
'django.contrib.messages.middleware.MessageMiddleware',
"django.middleware.common.CommonMiddleware",
"django.contrib.sessions.middleware.SessionMiddleware",
"django.contrib.auth.middleware.AuthenticationMiddleware",
"django.contrib.messages.middleware.MessageMiddleware",
),
INSTALLED_APPS=(
'django.contrib.admin',
'django.contrib.auth',
'django.contrib.contenttypes',
'django.contrib.sessions',
'django.contrib.sites',
'django.contrib.staticfiles',
'rest_framework',
'rest_framework.authtoken',
'tests.authentication',
'tests.generic_relations',
'tests.importable',
'tests',
),
PASSWORD_HASHERS=(
'django.contrib.auth.hashers.MD5PasswordHasher',
"django.contrib.admin",
"django.contrib.auth",
"django.contrib.contenttypes",
"django.contrib.sessions",
"django.contrib.sites",
"django.contrib.staticfiles",
"rest_framework",
"rest_framework.authtoken",
"tests.authentication",
"tests.generic_relations",
"tests.importable",
"tests",
),
PASSWORD_HASHERS=("django.contrib.auth.hashers.MD5PasswordHasher",),
)
# guardian is optional
@ -74,28 +75,32 @@ def pytest_configure(config):
else:
settings.ANONYMOUS_USER_ID = -1
settings.AUTHENTICATION_BACKENDS = (
'django.contrib.auth.backends.ModelBackend',
'guardian.backends.ObjectPermissionBackend',
)
settings.INSTALLED_APPS += (
'guardian',
"django.contrib.auth.backends.ModelBackend",
"guardian.backends.ObjectPermissionBackend",
)
settings.INSTALLED_APPS += ("guardian",)
if config.getoption('--no-pkgroot'):
if config.getoption("--no-pkgroot"):
sys.path.pop(0)
# import rest_framework before pytest re-adds the package root directory.
import rest_framework
package_dir = os.path.join(os.getcwd(), 'rest_framework')
package_dir = os.path.join(os.getcwd(), "rest_framework")
assert not rest_framework.__file__.startswith(package_dir)
# Manifest storage will raise an exception if static files are not present (ie, a packaging failure).
if config.getoption('--staticfiles'):
if config.getoption("--staticfiles"):
import rest_framework
settings.STATIC_ROOT = os.path.join(os.path.dirname(rest_framework.__file__), 'static-root')
settings.STATICFILES_STORAGE = 'django.contrib.staticfiles.storage.ManifestStaticFilesStorage'
settings.STATIC_ROOT = os.path.join(
os.path.dirname(rest_framework.__file__), "static-root"
)
settings.STATICFILES_STORAGE = (
"django.contrib.staticfiles.storage.ManifestStaticFilesStorage"
)
django.setup()
if config.getoption('--staticfiles'):
management.call_command('collectstatic', verbosity=0, interactive=False)
if config.getoption("--staticfiles"):
management.call_command("collectstatic", verbosity=0, interactive=False)

View File

@ -5,32 +5,59 @@ class Migration(migrations.Migration):
initial = True
dependencies = [
('contenttypes', '0002_remove_content_type_name'),
]
dependencies = [("contenttypes", "0002_remove_content_type_name")]
operations = [
migrations.CreateModel(
name='Bookmark',
name="Bookmark",
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('url', models.URLField()),
(
"id",
models.AutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("url", models.URLField()),
],
),
migrations.CreateModel(
name='Note',
name="Note",
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('text', models.TextField()),
(
"id",
models.AutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("text", models.TextField()),
],
),
migrations.CreateModel(
name='Tag',
name="Tag",
fields=[
('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')),
('tag', models.SlugField()),
('object_id', models.PositiveIntegerField()),
('content_type', models.ForeignKey(on_delete=models.CASCADE, to='contenttypes.ContentType')),
(
"id",
models.AutoField(
auto_created=True,
primary_key=True,
serialize=False,
verbose_name="ID",
),
),
("tag", models.SlugField()),
("object_id", models.PositiveIntegerField()),
(
"content_type",
models.ForeignKey(
on_delete=models.CASCADE, to="contenttypes.ContentType"
),
),
],
),
]

View File

@ -1,8 +1,6 @@
from __future__ import unicode_literals
from django.contrib.contenttypes.fields import (
GenericForeignKey, GenericRelation
)
from django.contrib.contenttypes.fields import GenericForeignKey, GenericRelation
from django.contrib.contenttypes.models import ContentType
from django.db import models
from django.utils.encoding import python_2_unicode_compatible
@ -13,10 +11,11 @@ class Tag(models.Model):
"""
Tags have a descriptive slug, and are attached to an arbitrary object.
"""
tag = models.SlugField()
content_type = models.ForeignKey(ContentType, on_delete=models.CASCADE)
object_id = models.PositiveIntegerField()
tagged_item = GenericForeignKey('content_type', 'object_id')
tagged_item = GenericForeignKey("content_type", "object_id")
def __str__(self):
return self.tag
@ -27,11 +26,12 @@ class Bookmark(models.Model):
"""
A URL bookmark that may have multiple tags attached.
"""
url = models.URLField()
tags = GenericRelation(Tag)
def __str__(self):
return 'Bookmark: %s' % self.url
return "Bookmark: %s" % self.url
@python_2_unicode_compatible
@ -39,8 +39,9 @@ class Note(models.Model):
"""
A textual note that may have multiple tags attached.
"""
text = models.TextField()
tags = GenericRelation(Tag)
def __str__(self):
return 'Note: %s' % self.text
return "Note: %s" % self.text

View File

@ -9,11 +9,11 @@ from .models import Bookmark, Note, Tag
class TestGenericRelations(TestCase):
def setUp(self):
self.bookmark = Bookmark.objects.create(url='https://www.djangoproject.com/')
Tag.objects.create(tagged_item=self.bookmark, tag='django')
Tag.objects.create(tagged_item=self.bookmark, tag='python')
self.note = Note.objects.create(text='Remember the milk')
Tag.objects.create(tagged_item=self.note, tag='reminder')
self.bookmark = Bookmark.objects.create(url="https://www.djangoproject.com/")
Tag.objects.create(tagged_item=self.bookmark, tag="django")
Tag.objects.create(tagged_item=self.bookmark, tag="python")
self.note = Note.objects.create(text="Remember the milk")
Tag.objects.create(tagged_item=self.note, tag="reminder")
def test_generic_relation(self):
"""
@ -26,12 +26,12 @@ class TestGenericRelations(TestCase):
class Meta:
model = Bookmark
fields = ('tags', 'url')
fields = ("tags", "url")
serializer = BookmarkSerializer(self.bookmark)
expected = {
'tags': ['django', 'python'],
'url': 'https://www.djangoproject.com/'
"tags": ["django", "python"],
"url": "https://www.djangoproject.com/",
}
assert serializer.data == expected
@ -46,21 +46,18 @@ class TestGenericRelations(TestCase):
class Meta:
model = Tag
fields = ('tag', 'tagged_item')
fields = ("tag", "tagged_item")
serializer = TagSerializer(Tag.objects.all(), many=True)
expected = [
{
'tag': 'django',
'tagged_item': 'Bookmark: https://www.djangoproject.com/'
"tag": "django",
"tagged_item": "Bookmark: https://www.djangoproject.com/",
},
{
'tag': 'python',
'tagged_item': 'Bookmark: https://www.djangoproject.com/'
"tag": "python",
"tagged_item": "Bookmark: https://www.djangoproject.com/",
},
{
'tag': 'reminder',
'tagged_item': 'Note: Remember the milk'
}
{"tag": "reminder", "tagged_item": "Note: Remember the milk"},
]
assert serializer.data == expected

View File

@ -5,9 +5,9 @@ from tests import importable
def test_installed():
# ensure that apps can freely import rest_framework.compat
assert 'tests.importable' in settings.INSTALLED_APPS
assert "tests.importable" in settings.INSTALLED_APPS
def test_imported():
# ensure that the __init__ hasn't been mucked with
assert hasattr(importable, 'compat')
assert hasattr(importable, "compat")

View File

@ -12,7 +12,7 @@ class RESTFrameworkModel(models.Model):
"""
class Meta:
app_label = 'tests'
app_label = "tests"
abstract = True
@ -20,7 +20,7 @@ class BasicModel(RESTFrameworkModel):
text = models.CharField(
max_length=100,
verbose_name=_("Text comes here"),
help_text=_("Text description.")
help_text=_("Text description."),
)
@ -32,7 +32,7 @@ class ManyToManyTarget(RESTFrameworkModel):
class ManyToManySource(RESTFrameworkModel):
name = models.CharField(max_length=100)
targets = models.ManyToManyField(ManyToManyTarget, related_name='sources')
targets = models.ManyToManyField(ManyToManyTarget, related_name="sources")
# ForeignKey
@ -47,51 +47,74 @@ class UUIDForeignKeyTarget(RESTFrameworkModel):
class ForeignKeySource(RESTFrameworkModel):
name = models.CharField(max_length=100)
target = models.ForeignKey(ForeignKeyTarget, related_name='sources',
help_text='Target', verbose_name='Target',
on_delete=models.CASCADE)
target = models.ForeignKey(
ForeignKeyTarget,
related_name="sources",
help_text="Target",
verbose_name="Target",
on_delete=models.CASCADE,
)
class ForeignKeySourceWithLimitedChoices(RESTFrameworkModel):
target = models.ForeignKey(ForeignKeyTarget, help_text='Target',
verbose_name='Target',
limit_choices_to={"name__startswith": "limited-"},
on_delete=models.CASCADE)
target = models.ForeignKey(
ForeignKeyTarget,
help_text="Target",
verbose_name="Target",
limit_choices_to={"name__startswith": "limited-"},
on_delete=models.CASCADE,
)
class ForeignKeySourceWithQLimitedChoices(RESTFrameworkModel):
target = models.ForeignKey(ForeignKeyTarget, help_text='Target',
verbose_name='Target',
limit_choices_to=models.Q(name__startswith="limited-"),
on_delete=models.CASCADE)
target = models.ForeignKey(
ForeignKeyTarget,
help_text="Target",
verbose_name="Target",
limit_choices_to=models.Q(name__startswith="limited-"),
on_delete=models.CASCADE,
)
# Nullable ForeignKey
class NullableForeignKeySource(RESTFrameworkModel):
name = models.CharField(max_length=100)
target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True,
related_name='nullable_sources',
verbose_name='Optional target object',
on_delete=models.CASCADE)
target = models.ForeignKey(
ForeignKeyTarget,
null=True,
blank=True,
related_name="nullable_sources",
verbose_name="Optional target object",
on_delete=models.CASCADE,
)
class NullableUUIDForeignKeySource(RESTFrameworkModel):
name = models.CharField(max_length=100)
target = models.ForeignKey(ForeignKeyTarget, null=True, blank=True,
related_name='nullable_sources',
verbose_name='Optional target object',
on_delete=models.CASCADE)
target = models.ForeignKey(
ForeignKeyTarget,
null=True,
blank=True,
related_name="nullable_sources",
verbose_name="Optional target object",
on_delete=models.CASCADE,
)
class NestedForeignKeySource(RESTFrameworkModel):
"""
Used for testing FK chain. A -> B -> C.
"""
name = models.CharField(max_length=100)
target = models.ForeignKey(NullableForeignKeySource, null=True, blank=True,
related_name='nested_sources',
verbose_name='Intermediate target object',
on_delete=models.CASCADE)
target = models.ForeignKey(
NullableForeignKeySource,
null=True,
blank=True,
related_name="nested_sources",
verbose_name="Intermediate target object",
on_delete=models.CASCADE,
)
# OneToOne
@ -102,13 +125,21 @@ class OneToOneTarget(RESTFrameworkModel):
class NullableOneToOneSource(RESTFrameworkModel):
name = models.CharField(max_length=100)
target = models.OneToOneField(
OneToOneTarget, null=True, blank=True,
related_name='nullable_source', on_delete=models.CASCADE)
OneToOneTarget,
null=True,
blank=True,
related_name="nullable_source",
on_delete=models.CASCADE,
)
class OneToOnePKSource(RESTFrameworkModel):
""" Test model where the primary key is a OneToOneField with another model. """
name = models.CharField(max_length=100)
target = models.OneToOneField(
OneToOneTarget, primary_key=True,
related_name='required_source', on_delete=models.CASCADE)
OneToOneTarget,
primary_key=True,
related_name="required_source",
on_delete=models.CASCADE,
)

View File

@ -18,52 +18,75 @@ from rest_framework.views import APIView
def get_schema():
return coreapi.Document(
url='https://api.example.com/',
title='Example API',
url="https://api.example.com/",
title="Example API",
content={
'simple_link': coreapi.Link('/example/', description='example link'),
'headers': coreapi.Link('/headers/'),
'location': {
'query': coreapi.Link('/example/', fields=[
coreapi.Field(name='example', schema=coreschema.String(description='example field'))
]),
'form': coreapi.Link('/example/', action='post', fields=[
coreapi.Field(name='example')
]),
'body': coreapi.Link('/example/', action='post', fields=[
coreapi.Field(name='example', location='body')
]),
'path': coreapi.Link('/example/{id}', fields=[
coreapi.Field(name='id', location='path')
])
"simple_link": coreapi.Link("/example/", description="example link"),
"headers": coreapi.Link("/headers/"),
"location": {
"query": coreapi.Link(
"/example/",
fields=[
coreapi.Field(
name="example",
schema=coreschema.String(description="example field"),
)
],
),
"form": coreapi.Link(
"/example/", action="post", fields=[coreapi.Field(name="example")]
),
"body": coreapi.Link(
"/example/",
action="post",
fields=[coreapi.Field(name="example", location="body")],
),
"path": coreapi.Link(
"/example/{id}", fields=[coreapi.Field(name="id", location="path")]
),
},
'encoding': {
'multipart': coreapi.Link('/example/', action='post', encoding='multipart/form-data', fields=[
coreapi.Field(name='example')
]),
'multipart-body': coreapi.Link('/example/', action='post', encoding='multipart/form-data', fields=[
coreapi.Field(name='example', location='body')
]),
'urlencoded': coreapi.Link('/example/', action='post', encoding='application/x-www-form-urlencoded', fields=[
coreapi.Field(name='example')
]),
'urlencoded-body': coreapi.Link('/example/', action='post', encoding='application/x-www-form-urlencoded', fields=[
coreapi.Field(name='example', location='body')
]),
'raw_upload': coreapi.Link('/upload/', action='post', encoding='application/octet-stream', fields=[
coreapi.Field(name='example', location='body')
]),
"encoding": {
"multipart": coreapi.Link(
"/example/",
action="post",
encoding="multipart/form-data",
fields=[coreapi.Field(name="example")],
),
"multipart-body": coreapi.Link(
"/example/",
action="post",
encoding="multipart/form-data",
fields=[coreapi.Field(name="example", location="body")],
),
"urlencoded": coreapi.Link(
"/example/",
action="post",
encoding="application/x-www-form-urlencoded",
fields=[coreapi.Field(name="example")],
),
"urlencoded-body": coreapi.Link(
"/example/",
action="post",
encoding="application/x-www-form-urlencoded",
fields=[coreapi.Field(name="example", location="body")],
),
"raw_upload": coreapi.Link(
"/upload/",
action="post",
encoding="application/octet-stream",
fields=[coreapi.Field(name="example", location="body")],
),
},
'response': {
'download': coreapi.Link('/download/'),
'text': coreapi.Link('/text/')
}
}
"response": {
"download": coreapi.Link("/download/"),
"text": coreapi.Link("/text/"),
},
},
)
def _iterlists(querydict):
if hasattr(querydict, 'iterlists'):
if hasattr(querydict, "iterlists"):
return querydict.iterlists()
return querydict.lists()
@ -73,8 +96,7 @@ def _get_query_params(request):
# than one item is present for a given key.
return {
key: (value[0] if len(value) == 1 else value)
for key, value in
_iterlists(request.query_params)
for key, value in _iterlists(request.query_params)
}
@ -83,7 +105,7 @@ def _get_data(request):
return request.data
# Coerce multidict into regular dict, and remove files to
# make assertions simpler.
if hasattr(request.data, 'iterlists') or hasattr(request.data, 'lists'):
if hasattr(request.data, "iterlists") or hasattr(request.data, "lists"):
# Use a list value if a QueryDict contains multiple items for a key.
return {
key: value[0] if len(value) == 1 else value
@ -91,9 +113,7 @@ def _get_data(request):
if key not in request.FILES
}
return {
key: value
for key, value in request.data.items()
if key not in request.FILES
key: value for key, value in request.data.items() if key not in request.FILES
}
@ -101,7 +121,7 @@ def _get_files(request):
if not request.FILES:
return {}
return {
key: {'name': value.name, 'content': value.read()}
key: {"name": value.name, "content": value.read()}
for key, value in request.FILES.items()
}
@ -116,210 +136,207 @@ class SchemaView(APIView):
class ListView(APIView):
def get(self, request):
return Response({
'method': request.method,
'query_params': _get_query_params(request)
})
return Response(
{"method": request.method, "query_params": _get_query_params(request)}
)
def post(self, request):
if request.content_type:
content_type = request.content_type.split(';')[0]
content_type = request.content_type.split(";")[0]
else:
content_type = None
return Response({
'method': request.method,
'query_params': _get_query_params(request),
'data': _get_data(request),
'files': _get_files(request),
'content_type': content_type
})
return Response(
{
"method": request.method,
"query_params": _get_query_params(request),
"data": _get_data(request),
"files": _get_files(request),
"content_type": content_type,
}
)
class DetailView(APIView):
def get(self, request, id):
return Response({
'id': id,
'method': request.method,
'query_params': _get_query_params(request)
})
return Response(
{
"id": id,
"method": request.method,
"query_params": _get_query_params(request),
}
)
class UploadView(APIView):
parser_classes = [FileUploadParser]
def post(self, request):
return Response({
'method': request.method,
'files': _get_files(request),
'content_type': request.content_type
})
return Response(
{
"method": request.method,
"files": _get_files(request),
"content_type": request.content_type,
}
)
class DownloadView(APIView):
def get(self, request):
return HttpResponse('some file content', content_type='image/png')
return HttpResponse("some file content", content_type="image/png")
class TextView(APIView):
def get(self, request):
return HttpResponse('123', content_type='text/plain')
return HttpResponse("123", content_type="text/plain")
class HeadersView(APIView):
def get(self, request):
headers = {
key[5:].replace('_', '-'): value
key[5:].replace("_", "-"): value
for key, value in request.META.items()
if key.startswith('HTTP_')
if key.startswith("HTTP_")
}
return Response({
'method': request.method,
'headers': headers
})
return Response({"method": request.method, "headers": headers})
urlpatterns = [
url(r'^$', SchemaView.as_view()),
url(r'^example/$', ListView.as_view()),
url(r'^example/(?P<id>[0-9]+)/$', DetailView.as_view()),
url(r'^upload/$', UploadView.as_view()),
url(r'^download/$', DownloadView.as_view()),
url(r'^text/$', TextView.as_view()),
url(r'^headers/$', HeadersView.as_view()),
url(r"^$", SchemaView.as_view()),
url(r"^example/$", ListView.as_view()),
url(r"^example/(?P<id>[0-9]+)/$", DetailView.as_view()),
url(r"^upload/$", UploadView.as_view()),
url(r"^download/$", DownloadView.as_view()),
url(r"^text/$", TextView.as_view()),
url(r"^headers/$", HeadersView.as_view()),
]
@unittest.skipUnless(coreapi, 'coreapi not installed')
@override_settings(ROOT_URLCONF='tests.test_api_client')
@unittest.skipUnless(coreapi, "coreapi not installed")
@override_settings(ROOT_URLCONF="tests.test_api_client")
class APIClientTests(APITestCase):
def test_api_client(self):
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
assert schema.title == 'Example API'
assert schema.url == 'https://api.example.com/'
assert schema['simple_link'].description == 'example link'
assert schema['location']['query'].fields[0].schema.description == 'example field'
data = client.action(schema, ['simple_link'])
expected = {
'method': 'GET',
'query_params': {}
}
schema = client.get("http://api.example.com/")
assert schema.title == "Example API"
assert schema.url == "https://api.example.com/"
assert schema["simple_link"].description == "example link"
assert (
schema["location"]["query"].fields[0].schema.description == "example field"
)
data = client.action(schema, ["simple_link"])
expected = {"method": "GET", "query_params": {}}
assert data == expected
def test_query_params(self):
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['location', 'query'], params={'example': 123})
expected = {
'method': 'GET',
'query_params': {'example': '123'}
}
schema = client.get("http://api.example.com/")
data = client.action(schema, ["location", "query"], params={"example": 123})
expected = {"method": "GET", "query_params": {"example": "123"}}
assert data == expected
def test_session_headers(self):
client = CoreAPIClient()
client.session.headers.update({'X-Custom-Header': 'foo'})
schema = client.get('http://api.example.com/')
data = client.action(schema, ['headers'])
assert data['headers']['X-CUSTOM-HEADER'] == 'foo'
client.session.headers.update({"X-Custom-Header": "foo"})
schema = client.get("http://api.example.com/")
data = client.action(schema, ["headers"])
assert data["headers"]["X-CUSTOM-HEADER"] == "foo"
def test_query_params_with_multiple_values(self):
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['location', 'query'], params={'example': [1, 2, 3]})
expected = {
'method': 'GET',
'query_params': {'example': ['1', '2', '3']}
}
schema = client.get("http://api.example.com/")
data = client.action(
schema, ["location", "query"], params={"example": [1, 2, 3]}
)
expected = {"method": "GET", "query_params": {"example": ["1", "2", "3"]}}
assert data == expected
def test_form_params(self):
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['location', 'form'], params={'example': 123})
schema = client.get("http://api.example.com/")
data = client.action(schema, ["location", "form"], params={"example": 123})
expected = {
'method': 'POST',
'content_type': 'application/json',
'query_params': {},
'data': {'example': 123},
'files': {}
"method": "POST",
"content_type": "application/json",
"query_params": {},
"data": {"example": 123},
"files": {},
}
assert data == expected
def test_body_params(self):
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['location', 'body'], params={'example': 123})
schema = client.get("http://api.example.com/")
data = client.action(schema, ["location", "body"], params={"example": 123})
expected = {
'method': 'POST',
'content_type': 'application/json',
'query_params': {},
'data': 123,
'files': {}
"method": "POST",
"content_type": "application/json",
"query_params": {},
"data": 123,
"files": {},
}
assert data == expected
def test_path_params(self):
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['location', 'path'], params={'id': 123})
expected = {
'method': 'GET',
'query_params': {},
'id': '123'
}
schema = client.get("http://api.example.com/")
data = client.action(schema, ["location", "path"], params={"id": 123})
expected = {"method": "GET", "query_params": {}, "id": "123"}
assert data == expected
def test_multipart_encoding(self):
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
schema = client.get("http://api.example.com/")
with tempfile.NamedTemporaryFile() as temp:
temp.write(b'example file content')
temp.write(b"example file content")
temp.flush()
temp.seek(0)
name = os.path.basename(temp.name)
data = client.action(schema, ['encoding', 'multipart'], params={'example': temp})
data = client.action(
schema, ["encoding", "multipart"], params={"example": temp}
)
expected = {
'method': 'POST',
'content_type': 'multipart/form-data',
'query_params': {},
'data': {},
'files': {'example': {'name': name, 'content': 'example file content'}}
"method": "POST",
"content_type": "multipart/form-data",
"query_params": {},
"data": {},
"files": {"example": {"name": name, "content": "example file content"}},
}
assert data == expected
def test_multipart_encoding_no_file(self):
# When no file is included, multipart encoding should still be used.
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
schema = client.get("http://api.example.com/")
data = client.action(schema, ['encoding', 'multipart'], params={'example': 123})
data = client.action(schema, ["encoding", "multipart"], params={"example": 123})
expected = {
'method': 'POST',
'content_type': 'multipart/form-data',
'query_params': {},
'data': {'example': '123'},
'files': {}
"method": "POST",
"content_type": "multipart/form-data",
"query_params": {},
"data": {"example": "123"},
"files": {},
}
assert data == expected
def test_multipart_encoding_multiple_values(self):
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
schema = client.get("http://api.example.com/")
data = client.action(schema, ['encoding', 'multipart'], params={'example': [1, 2, 3]})
data = client.action(
schema, ["encoding", "multipart"], params={"example": [1, 2, 3]}
)
expected = {
'method': 'POST',
'content_type': 'multipart/form-data',
'query_params': {},
'data': {'example': ['1', '2', '3']},
'files': {}
"method": "POST",
"content_type": "multipart/form-data",
"query_params": {},
"data": {"example": ["1", "2", "3"]},
"files": {},
}
assert data == expected
@ -328,17 +345,19 @@ class APIClientTests(APITestCase):
from coreapi.utils import File
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
schema = client.get("http://api.example.com/")
example = File(name='example.txt', content='123')
data = client.action(schema, ['encoding', 'multipart'], params={'example': example})
example = File(name="example.txt", content="123")
data = client.action(
schema, ["encoding", "multipart"], params={"example": example}
)
expected = {
'method': 'POST',
'content_type': 'multipart/form-data',
'query_params': {},
'data': {},
'files': {'example': {'name': 'example.txt', 'content': '123'}}
"method": "POST",
"content_type": "multipart/form-data",
"query_params": {},
"data": {},
"files": {"example": {"name": "example.txt", "content": "123"}},
}
assert data == expected
@ -346,17 +365,19 @@ class APIClientTests(APITestCase):
from coreapi.utils import File
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
schema = client.get("http://api.example.com/")
example = {'foo': File(name='example.txt', content='123'), 'bar': 'abc'}
data = client.action(schema, ['encoding', 'multipart-body'], params={'example': example})
example = {"foo": File(name="example.txt", content="123"), "bar": "abc"}
data = client.action(
schema, ["encoding", "multipart-body"], params={"example": example}
)
expected = {
'method': 'POST',
'content_type': 'multipart/form-data',
'query_params': {},
'data': {'bar': 'abc'},
'files': {'foo': {'name': 'example.txt', 'content': '123'}}
"method": "POST",
"content_type": "multipart/form-data",
"query_params": {},
"data": {"bar": "abc"},
"files": {"foo": {"name": "example.txt", "content": "123"}},
}
assert data == expected
@ -364,40 +385,48 @@ class APIClientTests(APITestCase):
def test_urlencoded_encoding(self):
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['encoding', 'urlencoded'], params={'example': 123})
schema = client.get("http://api.example.com/")
data = client.action(
schema, ["encoding", "urlencoded"], params={"example": 123}
)
expected = {
'method': 'POST',
'content_type': 'application/x-www-form-urlencoded',
'query_params': {},
'data': {'example': '123'},
'files': {}
"method": "POST",
"content_type": "application/x-www-form-urlencoded",
"query_params": {},
"data": {"example": "123"},
"files": {},
}
assert data == expected
def test_urlencoded_encoding_multiple_values(self):
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['encoding', 'urlencoded'], params={'example': [1, 2, 3]})
schema = client.get("http://api.example.com/")
data = client.action(
schema, ["encoding", "urlencoded"], params={"example": [1, 2, 3]}
)
expected = {
'method': 'POST',
'content_type': 'application/x-www-form-urlencoded',
'query_params': {},
'data': {'example': ['1', '2', '3']},
'files': {}
"method": "POST",
"content_type": "application/x-www-form-urlencoded",
"query_params": {},
"data": {"example": ["1", "2", "3"]},
"files": {},
}
assert data == expected
def test_urlencoded_encoding_in_body(self):
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
data = client.action(schema, ['encoding', 'urlencoded-body'], params={'example': {'foo': 123, 'bar': True}})
schema = client.get("http://api.example.com/")
data = client.action(
schema,
["encoding", "urlencoded-body"],
params={"example": {"foo": 123, "bar": True}},
)
expected = {
'method': 'POST',
'content_type': 'application/x-www-form-urlencoded',
'query_params': {},
'data': {'foo': '123', 'bar': 'true'},
'files': {}
"method": "POST",
"content_type": "application/x-www-form-urlencoded",
"query_params": {},
"data": {"foo": "123", "bar": "true"},
"files": {},
}
assert data == expected
@ -405,20 +434,22 @@ class APIClientTests(APITestCase):
def test_raw_upload(self):
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
schema = client.get("http://api.example.com/")
with tempfile.NamedTemporaryFile(delete=False) as temp:
temp.write(b'example file content')
temp.write(b"example file content")
temp.flush()
temp.seek(0)
name = os.path.basename(temp.name)
data = client.action(schema, ['encoding', 'raw_upload'], params={'example': temp})
data = client.action(
schema, ["encoding", "raw_upload"], params={"example": temp}
)
expected = {
'method': 'POST',
'files': {'file': {'name': name, 'content': 'example file content'}},
'content_type': 'application/octet-stream'
"method": "POST",
"files": {"file": {"name": name, "content": "example file content"}},
"content_type": "application/octet-stream",
}
assert data == expected
@ -426,15 +457,17 @@ class APIClientTests(APITestCase):
from coreapi.utils import File
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
schema = client.get("http://api.example.com/")
example = File('example.txt', '123')
data = client.action(schema, ['encoding', 'raw_upload'], params={'example': example})
example = File("example.txt", "123")
data = client.action(
schema, ["encoding", "raw_upload"], params={"example": example}
)
expected = {
'method': 'POST',
'files': {'file': {'name': 'example.txt', 'content': '123'}},
'content_type': 'text/plain'
"method": "POST",
"files": {"file": {"name": "example.txt", "content": "123"}},
"content_type": "text/plain",
}
assert data == expected
@ -442,15 +475,17 @@ class APIClientTests(APITestCase):
from coreapi.utils import File
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
schema = client.get("http://api.example.com/")
example = File('example.txt', '123', 'text/html')
data = client.action(schema, ['encoding', 'raw_upload'], params={'example': example})
example = File("example.txt", "123", "text/html")
data = client.action(
schema, ["encoding", "raw_upload"], params={"example": example}
)
expected = {
'method': 'POST',
'files': {'file': {'name': 'example.txt', 'content': '123'}},
'content_type': 'text/html'
"method": "POST",
"files": {"file": {"name": "example.txt", "content": "123"}},
"content_type": "text/html",
}
assert data == expected
@ -458,17 +493,17 @@ class APIClientTests(APITestCase):
def test_text_response(self):
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
schema = client.get("http://api.example.com/")
data = client.action(schema, ['response', 'text'])
data = client.action(schema, ["response", "text"])
expected = '123'
expected = "123"
assert data == expected
def test_download_response(self):
client = CoreAPIClient()
schema = client.get('http://api.example.com/')
schema = client.get("http://api.example.com/")
data = client.action(schema, ['response', 'download'])
assert data.basename == 'download.png'
assert data.read() == b'some file content'
data = client.action(schema, ["response", "download"])
assert data.basename == "download.png"
assert data.read() == b"some file content"

View File

@ -14,13 +14,14 @@ from rest_framework.test import APIRequestFactory
from rest_framework.views import APIView
from tests.models import BasicModel
factory = APIRequestFactory()
class BasicView(APIView):
def post(self, request, *args, **kwargs):
BasicModel.objects.create()
return Response({'method': 'GET'})
return Response({"method": "GET"})
class ErrorView(APIView):
@ -45,25 +46,23 @@ class NonAtomicAPIExceptionView(APIView):
raise Http404
urlpatterns = (
url(r'^$', NonAtomicAPIExceptionView.as_view()),
)
urlpatterns = (url(r"^$", NonAtomicAPIExceptionView.as_view()),)
@unittest.skipUnless(
connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints."
"'atomic' requires transactions and savepoints.",
)
class DBTransactionTests(TestCase):
def setUp(self):
self.view = BasicView.as_view()
connections.databases['default']['ATOMIC_REQUESTS'] = True
connections.databases["default"]["ATOMIC_REQUESTS"] = True
def tearDown(self):
connections.databases['default']['ATOMIC_REQUESTS'] = False
connections.databases["default"]["ATOMIC_REQUESTS"] = False
def test_no_exception_commit_transaction(self):
request = factory.post('/')
request = factory.post("/")
with self.assertNumQueries(1):
response = self.view(request)
@ -74,15 +73,15 @@ class DBTransactionTests(TestCase):
@unittest.skipUnless(
connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints."
"'atomic' requires transactions and savepoints.",
)
class DBTransactionErrorTests(TestCase):
def setUp(self):
self.view = ErrorView.as_view()
connections.databases['default']['ATOMIC_REQUESTS'] = True
connections.databases["default"]["ATOMIC_REQUESTS"] = True
def tearDown(self):
connections.databases['default']['ATOMIC_REQUESTS'] = False
connections.databases["default"]["ATOMIC_REQUESTS"] = False
def test_generic_exception_delegate_transaction_management(self):
"""
@ -91,7 +90,7 @@ class DBTransactionErrorTests(TestCase):
We let django deal with the transaction when it will catch the Exception.
"""
request = factory.post('/')
request = factory.post("/")
with self.assertNumQueries(3):
# 1 - begin savepoint
# 2 - insert
@ -104,21 +103,21 @@ class DBTransactionErrorTests(TestCase):
@unittest.skipUnless(
connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints."
"'atomic' requires transactions and savepoints.",
)
class DBTransactionAPIExceptionTests(TestCase):
def setUp(self):
self.view = APIExceptionView.as_view()
connections.databases['default']['ATOMIC_REQUESTS'] = True
connections.databases["default"]["ATOMIC_REQUESTS"] = True
def tearDown(self):
connections.databases['default']['ATOMIC_REQUESTS'] = False
connections.databases["default"]["ATOMIC_REQUESTS"] = False
def test_api_exception_rollback_transaction(self):
"""
Transaction is rollbacked by our transaction atomic block.
"""
request = factory.post('/')
request = factory.post("/")
num_queries = 4 if connection.features.can_release_savepoints else 3
with self.assertNumQueries(num_queries):
# 1 - begin savepoint
@ -134,18 +133,18 @@ class DBTransactionAPIExceptionTests(TestCase):
@unittest.skipUnless(
connection.features.uses_savepoints,
"'atomic' requires transactions and savepoints."
"'atomic' requires transactions and savepoints.",
)
@override_settings(ROOT_URLCONF='tests.test_atomic_requests')
@override_settings(ROOT_URLCONF="tests.test_atomic_requests")
class NonAtomicDBTransactionAPIExceptionTests(TransactionTestCase):
def setUp(self):
connections.databases['default']['ATOMIC_REQUESTS'] = True
connections.databases["default"]["ATOMIC_REQUESTS"] = True
def tearDown(self):
connections.databases['default']['ATOMIC_REQUESTS'] = False
connections.databases["default"]["ATOMIC_REQUESTS"] = False
def test_api_exception_rollback_transaction_non_atomic_view(self):
response = self.client.get('/')
response = self.client.get("/")
# without checking connection.in_atomic_block view raises 500
# due attempt to rollback without transaction

View File

@ -6,44 +6,43 @@ from django.test import TestCase
from django.utils.six import StringIO
from rest_framework.authtoken.admin import TokenAdmin
from rest_framework.authtoken.management.commands.drf_create_token import \
Command as AuthTokenCommand
from rest_framework.authtoken.management.commands.drf_create_token import (
Command as AuthTokenCommand,
)
from rest_framework.authtoken.models import Token
from rest_framework.authtoken.serializers import AuthTokenSerializer
from rest_framework.exceptions import ValidationError
class AuthTokenTests(TestCase):
def setUp(self):
self.site = site
self.user = User.objects.create_user(username='test_user')
self.token = Token.objects.create(key='test token', user=self.user)
self.user = User.objects.create_user(username="test_user")
self.token = Token.objects.create(key="test token", user=self.user)
def test_model_admin_displayed_fields(self):
mock_request = object()
token_admin = TokenAdmin(self.token, self.site)
assert token_admin.get_fields(mock_request) == ('user',)
assert token_admin.get_fields(mock_request) == ("user",)
def test_token_string_representation(self):
assert str(self.token) == 'test token'
assert str(self.token) == "test token"
def test_validate_raise_error_if_no_credentials_provided(self):
with pytest.raises(ValidationError):
AuthTokenSerializer().validate({})
def test_whitespace_in_password(self):
data = {'username': self.user.username, 'password': 'test pass '}
self.user.set_password(data['password'])
data = {"username": self.user.username, "password": "test pass "}
self.user.set_password(data["password"])
self.user.save()
assert AuthTokenSerializer(data=data).is_valid()
class AuthTokenCommandTests(TestCase):
def setUp(self):
self.site = site
self.user = User.objects.create_user(username='test_user')
self.user = User.objects.create_user(username="test_user")
def test_command_create_user_token(self):
token = AuthTokenCommand().create_user_token(self.user.username, False)
@ -53,7 +52,7 @@ class AuthTokenCommandTests(TestCase):
def test_command_create_user_token_invalid_user(self):
with pytest.raises(User.DoesNotExist):
AuthTokenCommand().create_user_token('not_existing_user', False)
AuthTokenCommand().create_user_token("not_existing_user", False)
def test_command_reset_user_token(self):
AuthTokenCommand().create_user_token(self.user.username, False)
@ -74,12 +73,12 @@ class AuthTokenCommandTests(TestCase):
def test_command_raising_error_for_invalid_user(self):
out = StringIO()
with pytest.raises(CommandError):
call_command('drf_create_token', 'not_existing_user', stdout=out)
call_command("drf_create_token", "not_existing_user", stdout=out)
def test_command_output(self):
out = StringIO()
call_command('drf_create_token', self.user.username, stdout=out)
call_command("drf_create_token", self.user.username, stdout=out)
token_saved = Token.objects.first()
self.assertIn('Generated token', out.getvalue())
self.assertIn("Generated token", out.getvalue())
self.assertIn(self.user.username, out.getvalue())
self.assertIn(token_saved.key, out.getvalue())

View File

@ -11,41 +11,43 @@ class TestSimpleBoundField:
serializer = ExampleSerializer()
assert serializer['text'].value == ''
assert serializer['text'].errors is None
assert serializer['text'].name == 'text'
assert serializer['amount'].value is None
assert serializer['amount'].errors is None
assert serializer['amount'].name == 'amount'
assert serializer["text"].value == ""
assert serializer["text"].errors is None
assert serializer["text"].name == "text"
assert serializer["amount"].value is None
assert serializer["amount"].errors is None
assert serializer["amount"].name == "amount"
def test_populated_bound_field(self):
class ExampleSerializer(serializers.Serializer):
text = serializers.CharField(max_length=100)
amount = serializers.IntegerField()
serializer = ExampleSerializer(data={'text': 'abc', 'amount': 123})
serializer = ExampleSerializer(data={"text": "abc", "amount": 123})
assert serializer.is_valid()
assert serializer['text'].value == 'abc'
assert serializer['text'].errors is None
assert serializer['text'].name == 'text'
assert serializer['amount'].value is 123
assert serializer['amount'].errors is None
assert serializer['amount'].name == 'amount'
assert serializer["text"].value == "abc"
assert serializer["text"].errors is None
assert serializer["text"].name == "text"
assert serializer["amount"].value is 123
assert serializer["amount"].errors is None
assert serializer["amount"].name == "amount"
def test_error_bound_field(self):
class ExampleSerializer(serializers.Serializer):
text = serializers.CharField(max_length=100)
amount = serializers.IntegerField()
serializer = ExampleSerializer(data={'text': 'x' * 1000, 'amount': 123})
serializer = ExampleSerializer(data={"text": "x" * 1000, "amount": 123})
serializer.is_valid()
assert serializer['text'].value == 'x' * 1000
assert serializer['text'].errors == ['Ensure this field has no more than 100 characters.']
assert serializer['text'].name == 'text'
assert serializer['amount'].value is 123
assert serializer['amount'].errors is None
assert serializer['amount'].name == 'amount'
assert serializer["text"].value == "x" * 1000
assert serializer["text"].errors == [
"Ensure this field has no more than 100 characters."
]
assert serializer["text"].name == "text"
assert serializer["amount"].value is 123
assert serializer["amount"].errors is None
assert serializer["amount"].name == "amount"
def test_delete_field(self):
class ExampleSerializer(serializers.Serializer):
@ -53,41 +55,45 @@ class TestSimpleBoundField:
amount = serializers.IntegerField()
serializer = ExampleSerializer()
del serializer.fields['text']
assert 'text' not in serializer.fields
del serializer.fields["text"]
assert "text" not in serializer.fields
def test_as_form_fields(self):
class ExampleSerializer(serializers.Serializer):
bool_field = serializers.BooleanField()
null_field = serializers.IntegerField(allow_null=True)
serializer = ExampleSerializer(data={'bool_field': False, 'null_field': None})
serializer = ExampleSerializer(data={"bool_field": False, "null_field": None})
assert serializer.is_valid()
assert serializer['bool_field'].as_form_field().value == ''
assert serializer['null_field'].as_form_field().value == ''
assert serializer["bool_field"].as_form_field().value == ""
assert serializer["null_field"].as_form_field().value == ""
def test_rendering_boolean_field(self):
from rest_framework.renderers import HTMLFormRenderer
class ExampleSerializer(serializers.Serializer):
bool_field = serializers.BooleanField(
style={'base_template': 'checkbox.html', 'template_pack': 'rest_framework/vertical'})
style={
"base_template": "checkbox.html",
"template_pack": "rest_framework/vertical",
}
)
serializer = ExampleSerializer(data={'bool_field': True})
serializer = ExampleSerializer(data={"bool_field": True})
assert serializer.is_valid()
renderer = HTMLFormRenderer()
rendered = renderer.render_field(serializer['bool_field'], {})
rendered = renderer.render_field(serializer["bool_field"], {})
expected_packed = (
'<divclass="form-group">'
'<divclass="checkbox">'
'<label>'
"<label>"
'<inputtype="checkbox"name="bool_field"value="true"checked>'
'Boolfield'
'</label>'
'</div>'
'</div>'
"Boolfield"
"</label>"
"</div>"
"</div>"
)
rendered_packed = ''.join(rendered.split())
rendered_packed = "".join(rendered.split())
assert rendered_packed == expected_packed
@ -103,15 +109,15 @@ class TestNestedBoundField:
serializer = ExampleSerializer()
assert serializer['text'].value == ''
assert serializer['text'].errors is None
assert serializer['text'].name == 'text'
assert serializer['nested']['more_text'].value == ''
assert serializer['nested']['more_text'].errors is None
assert serializer['nested']['more_text'].name == 'nested.more_text'
assert serializer['nested']['amount'].value is None
assert serializer['nested']['amount'].errors is None
assert serializer['nested']['amount'].name == 'nested.amount'
assert serializer["text"].value == ""
assert serializer["text"].errors is None
assert serializer["text"].name == "text"
assert serializer["nested"]["more_text"].value == ""
assert serializer["nested"]["more_text"].errors is None
assert serializer["nested"]["more_text"].name == "nested.more_text"
assert serializer["nested"]["amount"].value is None
assert serializer["nested"]["amount"].errors is None
assert serializer["nested"]["amount"].name == "nested.amount"
def test_as_form_fields(self):
class Nested(serializers.Serializer):
@ -121,10 +127,12 @@ class TestNestedBoundField:
class ExampleSerializer(serializers.Serializer):
nested = Nested()
serializer = ExampleSerializer(data={'nested': {'bool_field': False, 'null_field': None}})
serializer = ExampleSerializer(
data={"nested": {"bool_field": False, "null_field": None}}
)
assert serializer.is_valid()
assert serializer['nested']['bool_field'].as_form_field().value == ''
assert serializer['nested']['null_field'].as_form_field().value == ''
assert serializer["nested"]["bool_field"].as_form_field().value == ""
assert serializer["nested"]["null_field"].as_form_field().value == ""
def test_rendering_nested_fields_with_none_value(self):
from rest_framework.renderers import HTMLFormRenderer
@ -139,28 +147,30 @@ class TestNestedBoundField:
class ExampleSerializer(serializers.Serializer):
nested2 = Nested2()
serializer = ExampleSerializer(data={'nested2': {'nested1': None, 'text_field': 'test'}})
serializer = ExampleSerializer(
data={"nested2": {"nested1": None, "text_field": "test"}}
)
assert serializer.is_valid()
renderer = HTMLFormRenderer()
for field in serializer:
rendered = renderer.render_field(field, {})
expected_packed = (
'<fieldset>'
'<legend>Nested2</legend>'
'<fieldset>'
'<legend>Nested1</legend>'
"<fieldset>"
"<legend>Nested2</legend>"
"<fieldset>"
"<legend>Nested1</legend>"
'<divclass="form-group">'
'<label>Textfield</label>'
"<label>Textfield</label>"
'<inputname="nested2.nested1.text_field"class="form-control"type="text"value="">'
'</div>'
'</fieldset>'
"</div>"
"</fieldset>"
'<divclass="form-group">'
'<label>Textfield</label>'
"<label>Textfield</label>"
'<inputname="nested2.text_field"class="form-control"type="text"value="test">'
'</div>'
'</fieldset>'
"</div>"
"</fieldset>"
)
rendered_packed = ''.join(rendered.split())
rendered_packed = "".join(rendered.split())
assert rendered_packed == expected_packed
@ -170,7 +180,7 @@ class TestJSONBoundField:
json_field = serializers.JSONField()
data = QueryDict(mutable=True)
data.update({'json_field': '{"some": ["json"}'})
data.update({"json_field": '{"some": ["json"}'})
serializer = TestSerializer(data=data)
assert serializer.is_valid() is False
assert serializer['json_field'].as_form_field().value == '{"some": ["json"}'
assert serializer["json_field"].as_form_field().value == '{"some": ["json"}'

View File

@ -6,9 +6,16 @@ from django.test import TestCase
from rest_framework import RemovedInDRF310Warning, status
from rest_framework.authentication import BasicAuthentication
from rest_framework.decorators import (
action, api_view, authentication_classes, detail_route, list_route,
parser_classes, permission_classes, renderer_classes, schema,
throttle_classes
action,
api_view,
authentication_classes,
detail_route,
list_route,
parser_classes,
permission_classes,
renderer_classes,
schema,
throttle_classes,
)
from rest_framework.parsers import JSONParser
from rest_framework.permissions import IsAuthenticated
@ -21,7 +28,6 @@ from rest_framework.views import APIView
class DecoratorTestCase(TestCase):
def setUp(self):
self.factory = APIRequestFactory()
@ -38,7 +44,7 @@ class DecoratorTestCase(TestCase):
def view(request):
return Response()
request = self.factory.get('/')
request = self.factory.get("/")
self.assertRaises(AssertionError, view, request)
def test_api_view_incorrect_arguments(self):
@ -47,108 +53,102 @@ class DecoratorTestCase(TestCase):
"""
with self.assertRaises(AssertionError):
@api_view('GET')
@api_view("GET")
def view(request):
return Response()
def test_calling_method(self):
@api_view(['GET'])
@api_view(["GET"])
def view(request):
return Response({})
request = self.factory.get('/')
request = self.factory.get("/")
response = view(request)
assert response.status_code == status.HTTP_200_OK
request = self.factory.post('/')
request = self.factory.post("/")
response = view(request)
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
def test_calling_put_method(self):
@api_view(['GET', 'PUT'])
@api_view(["GET", "PUT"])
def view(request):
return Response({})
request = self.factory.put('/')
request = self.factory.put("/")
response = view(request)
assert response.status_code == status.HTTP_200_OK
request = self.factory.post('/')
request = self.factory.post("/")
response = view(request)
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
def test_calling_patch_method(self):
@api_view(['GET', 'PATCH'])
@api_view(["GET", "PATCH"])
def view(request):
return Response({})
request = self.factory.patch('/')
request = self.factory.patch("/")
response = view(request)
assert response.status_code == status.HTTP_200_OK
request = self.factory.post('/')
request = self.factory.post("/")
response = view(request)
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
def test_renderer_classes(self):
@api_view(['GET'])
@api_view(["GET"])
@renderer_classes([JSONRenderer])
def view(request):
return Response({})
request = self.factory.get('/')
request = self.factory.get("/")
response = view(request)
assert isinstance(response.accepted_renderer, JSONRenderer)
def test_parser_classes(self):
@api_view(['GET'])
@api_view(["GET"])
@parser_classes([JSONParser])
def view(request):
assert len(request.parsers) == 1
assert isinstance(request.parsers[0], JSONParser)
return Response({})
request = self.factory.get('/')
request = self.factory.get("/")
view(request)
def test_authentication_classes(self):
@api_view(['GET'])
@api_view(["GET"])
@authentication_classes([BasicAuthentication])
def view(request):
assert len(request.authenticators) == 1
assert isinstance(request.authenticators[0], BasicAuthentication)
return Response({})
request = self.factory.get('/')
request = self.factory.get("/")
view(request)
def test_permission_classes(self):
@api_view(['GET'])
@api_view(["GET"])
@permission_classes([IsAuthenticated])
def view(request):
return Response({})
request = self.factory.get('/')
request = self.factory.get("/")
response = view(request)
assert response.status_code == status.HTTP_403_FORBIDDEN
def test_throttle_classes(self):
class OncePerDayUserThrottle(UserRateThrottle):
rate = '1/day'
rate = "1/day"
@api_view(['GET'])
@api_view(["GET"])
@throttle_classes([OncePerDayUserThrottle])
def view(request):
return Response({})
request = self.factory.get('/')
request = self.factory.get("/")
response = view(request)
assert response.status_code == status.HTTP_200_OK
@ -159,10 +159,11 @@ class DecoratorTestCase(TestCase):
"""
Checks CustomSchema class is set on view
"""
class CustomSchema(AutoSchema):
pass
@api_view(['GET'])
@api_view(["GET"])
@schema(CustomSchema())
def view(request):
return Response({})
@ -171,23 +172,23 @@ class DecoratorTestCase(TestCase):
class ActionDecoratorTestCase(TestCase):
def test_defaults(self):
@action(detail=True)
def test_action(request):
"""Description"""
assert test_action.mapping == {'get': 'test_action'}
assert test_action.mapping == {"get": "test_action"}
assert test_action.detail is True
assert test_action.url_path == 'test_action'
assert test_action.url_name == 'test-action'
assert test_action.url_path == "test_action"
assert test_action.url_name == "test-action"
assert test_action.kwargs == {
'name': 'Test action',
'description': 'Description',
"name": "Test action",
"description": "Description",
}
def test_detail_required(self):
with pytest.raises(AssertionError) as excinfo:
@action()
def test_action(request):
raise NotImplementedError
@ -201,6 +202,7 @@ class ActionDecoratorTestCase(TestCase):
raise NotImplementedError
for name in APIView.http_method_names:
def method():
raise NotImplementedError
@ -222,36 +224,30 @@ class ActionDecoratorTestCase(TestCase):
def test_action(request):
raise NotImplementedError
assert test_action.kwargs == {
'description': None,
'name': 'Test action',
}
assert test_action.kwargs == {"description": None, "name": "Test action"}
# name kwarg supersedes name generation
@action(detail=True, name='test name')
@action(detail=True, name="test name")
def test_action(request):
raise NotImplementedError
assert test_action.kwargs == {
'description': None,
'name': 'test name',
}
assert test_action.kwargs == {"description": None, "name": "test name"}
# suffix kwarg supersedes name generation
@action(detail=True, suffix='Suffix')
@action(detail=True, suffix="Suffix")
def test_action(request):
raise NotImplementedError
assert test_action.kwargs == {
'description': None,
'suffix': 'Suffix',
}
assert test_action.kwargs == {"description": None, "suffix": "Suffix"}
# name + suffix is a conflict.
with pytest.raises(TypeError) as excinfo:
action(detail=True, name='test name', suffix='Suffix')
action(detail=True, name="test name", suffix="Suffix")
assert str(excinfo.value) == "`name` and `suffix` are mutually exclusive arguments."
assert (
str(excinfo.value)
== "`name` and `suffix` are mutually exclusive arguments."
)
def test_method_mapping(self):
@action(detail=False)
@ -263,7 +259,7 @@ class ActionDecoratorTestCase(TestCase):
raise NotImplementedError
# The secondary handler methods should not have the action attributes
for name in ['mapping', 'detail', 'url_path', 'url_name', 'kwargs']:
for name in ["mapping", "detail", "url_path", "url_name", "kwargs"]:
assert hasattr(test_action, name) and not hasattr(test_action_post, name)
def test_method_mapping_already_mapped(self):
@ -273,6 +269,7 @@ class ActionDecoratorTestCase(TestCase):
msg = "Method 'get' has already been mapped to '.test_action'."
with self.assertRaisesMessage(AssertionError, msg):
@test_action.mapping.get
def test_action_get(request):
raise NotImplementedError
@ -282,15 +279,19 @@ class ActionDecoratorTestCase(TestCase):
def test_action():
raise NotImplementedError
msg = ("Method mapping does not behave like the property decorator. You "
"cannot use the same method name for each mapping declaration.")
msg = (
"Method mapping does not behave like the property decorator. You "
"cannot use the same method name for each mapping declaration."
)
with self.assertRaisesMessage(AssertionError, msg):
@test_action.mapping.post
def test_action():
raise NotImplementedError
def test_detail_route_deprecation(self):
with pytest.warns(RemovedInDRF310Warning) as record:
@detail_route()
def view(request):
raise NotImplementedError
@ -304,6 +305,7 @@ class ActionDecoratorTestCase(TestCase):
def test_list_route_deprecation(self):
with pytest.warns(RemovedInDRF310Warning) as record:
@list_route()
def view(request):
raise NotImplementedError
@ -318,9 +320,10 @@ class ActionDecoratorTestCase(TestCase):
def test_route_url_name_from_path(self):
# pre-3.8 behavior was to base the `url_name` off of the `url_path`
with pytest.warns(RemovedInDRF310Warning):
@list_route(url_path='foo_bar')
@list_route(url_path="foo_bar")
def view(request):
raise NotImplementedError
assert view.url_path == 'foo_bar'
assert view.url_name == 'foo-bar'
assert view.url_path == "foo_bar"
assert view.url_name == "foo-bar"

View File

@ -9,6 +9,7 @@ from rest_framework.compat import apply_markdown
from rest_framework.utils.formatting import dedent
from rest_framework.views import APIView
# We check that docstrings get nicely un-indented.
DESCRIPTION = """an example docstring
====================
@ -81,28 +82,34 @@ class TestViewNamesAndDescriptions(TestCase):
"""
Ensure view names are based on the class name.
"""
class MockView(APIView):
pass
assert MockView().get_view_name() == 'Mock'
assert MockView().get_view_name() == "Mock"
def test_view_name_uses_name_attribute(self):
class MockView(APIView):
name = 'Foo'
assert MockView().get_view_name() == 'Foo'
name = "Foo"
assert MockView().get_view_name() == "Foo"
def test_view_name_uses_suffix_attribute(self):
class MockView(APIView):
suffix = 'List'
assert MockView().get_view_name() == 'Mock List'
suffix = "List"
assert MockView().get_view_name() == "Mock List"
def test_view_name_preferences_name_over_suffix(self):
class MockView(APIView):
name = 'Foo'
suffix = 'List'
assert MockView().get_view_name() == 'Foo'
name = "Foo"
suffix = "List"
assert MockView().get_view_name() == "Foo"
def test_view_description_uses_docstring(self):
"""Ensure view descriptions are based on the docstring."""
class MockView(APIView):
"""an example docstring
====================
@ -130,23 +137,28 @@ class TestViewNamesAndDescriptions(TestCase):
def test_view_description_uses_description_attribute(self):
class MockView(APIView):
description = 'Foo'
assert MockView().get_view_description() == 'Foo'
description = "Foo"
assert MockView().get_view_description() == "Foo"
def test_view_description_allows_empty_description(self):
class MockView(APIView):
"""Description."""
description = ''
assert MockView().get_view_description() == ''
description = ""
assert MockView().get_view_description() == ""
def test_view_description_can_be_empty(self):
"""
Ensure that if a view has no docstring,
then it's description is the empty string.
"""
class MockView(APIView):
pass
assert MockView().get_view_description() == ''
assert MockView().get_view_description() == ""
def test_view_description_can_be_promise(self):
"""
@ -168,7 +180,7 @@ class TestViewNamesAndDescriptions(TestCase):
class MockView(APIView):
__doc__ = MockLazyStr("a gettext string")
assert MockView().get_view_description() == 'a gettext string'
assert MockView().get_view_description() == "a gettext string"
def test_markdown(self):
"""
@ -176,21 +188,17 @@ class TestViewNamesAndDescriptions(TestCase):
"""
if apply_markdown:
md_applied = apply_markdown(DESCRIPTION)
gte_21_match = (
md_applied == (
MARKED_DOWN_gte_21 % MARKED_DOWN_HILITE) or
md_applied == (
MARKED_DOWN_gte_21 % MARKED_DOWN_NOT_HILITE))
lt_21_match = (
md_applied == (
MARKED_DOWN_lt_21 % MARKED_DOWN_HILITE) or
md_applied == (
MARKED_DOWN_lt_21 % MARKED_DOWN_NOT_HILITE))
gte_21_match = md_applied == (
MARKED_DOWN_gte_21 % MARKED_DOWN_HILITE
) or md_applied == (MARKED_DOWN_gte_21 % MARKED_DOWN_NOT_HILITE)
lt_21_match = md_applied == (
MARKED_DOWN_lt_21 % MARKED_DOWN_HILITE
) or md_applied == (MARKED_DOWN_lt_21 % MARKED_DOWN_NOT_HILITE)
assert gte_21_match or lt_21_match
def test_dedent_tabs():
result = 'first string\n\nsecond string'
result = "first string\n\nsecond string"
assert dedent(" first string\n\n second string") == result
assert dedent("first string\n\n second string") == result
assert dedent("\tfirst string\n\n\tsecond string") == result

View File

@ -37,7 +37,7 @@ class JSONEncoderTests(TestCase):
current_time = datetime.now()
assert self.encoder.default(current_time) == current_time.isoformat()
current_time_utc = current_time.replace(tzinfo=utc)
assert self.encoder.default(current_time_utc) == current_time.isoformat() + 'Z'
assert self.encoder.default(current_time_utc) == current_time.isoformat() + "Z"
def test_encode_time(self):
"""
@ -76,7 +76,7 @@ class JSONEncoderTests(TestCase):
unique_id = uuid4()
assert self.encoder.default(unique_id) == str(unique_id)
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
@pytest.mark.skipif(not coreapi, reason="coreapi is not installed")
def test_encode_coreapi_raises_error(self):
"""
Tests encoding a coreapi objects raises proper error

View File

@ -6,13 +6,16 @@ from django.utils import six, translation
from django.utils.translation import ugettext_lazy as _
from rest_framework.exceptions import (
APIException, ErrorDetail, Throttled, _get_error_details, bad_request,
server_error
APIException,
ErrorDetail,
Throttled,
_get_error_details,
bad_request,
server_error,
)
class ExceptionTestCase(TestCase):
def test_get_error_details(self):
example = "string"
@ -20,91 +23,96 @@ class ExceptionTestCase(TestCase):
assert _get_error_details(lazy_example) == example
assert isinstance(
_get_error_details(lazy_example),
ErrorDetail
)
assert isinstance(_get_error_details(lazy_example), ErrorDetail)
assert _get_error_details({'nested': lazy_example})['nested'] == example
assert _get_error_details({"nested": lazy_example})["nested"] == example
assert isinstance(
_get_error_details({'nested': lazy_example})['nested'],
ErrorDetail
_get_error_details({"nested": lazy_example})["nested"], ErrorDetail
)
assert _get_error_details([[lazy_example]])[0][0] == example
assert isinstance(
_get_error_details([[lazy_example]])[0][0],
ErrorDetail
)
assert isinstance(_get_error_details([[lazy_example]])[0][0], ErrorDetail)
def test_get_full_details_with_throttling(self):
exception = Throttled()
assert exception.get_full_details() == {
'message': 'Request was throttled.', 'code': 'throttled'}
"message": "Request was throttled.",
"code": "throttled",
}
exception = Throttled(wait=2)
assert exception.get_full_details() == {
'message': 'Request was throttled. Expected available in {} seconds.'.format(2 if six.PY3 else 2.),
'code': 'throttled'}
"message": "Request was throttled. Expected available in {} seconds.".format(
2 if six.PY3 else 2.0
),
"code": "throttled",
}
exception = Throttled(wait=2, detail='Slow down!')
exception = Throttled(wait=2, detail="Slow down!")
assert exception.get_full_details() == {
'message': 'Slow down! Expected available in {} seconds.'.format(2 if six.PY3 else 2.),
'code': 'throttled'}
"message": "Slow down! Expected available in {} seconds.".format(
2 if six.PY3 else 2.0
),
"code": "throttled",
}
class ErrorDetailTests(TestCase):
def test_eq(self):
assert ErrorDetail('msg') == ErrorDetail('msg')
assert ErrorDetail('msg', 'code') == ErrorDetail('msg', code='code')
assert ErrorDetail("msg") == ErrorDetail("msg")
assert ErrorDetail("msg", "code") == ErrorDetail("msg", code="code")
assert ErrorDetail('msg') == 'msg'
assert ErrorDetail('msg', 'code') == 'msg'
assert ErrorDetail("msg") == "msg"
assert ErrorDetail("msg", "code") == "msg"
def test_ne(self):
assert ErrorDetail('msg1') != ErrorDetail('msg2')
assert ErrorDetail('msg') != ErrorDetail('msg', code='invalid')
assert ErrorDetail("msg1") != ErrorDetail("msg2")
assert ErrorDetail("msg") != ErrorDetail("msg", code="invalid")
assert ErrorDetail('msg1') != 'msg2'
assert ErrorDetail('msg1', 'code') != 'msg2'
assert ErrorDetail("msg1") != "msg2"
assert ErrorDetail("msg1", "code") != "msg2"
def test_repr(self):
assert repr(ErrorDetail('msg1')) == \
'ErrorDetail(string={!r}, code=None)'.format('msg1')
assert repr(ErrorDetail('msg1', 'code')) == \
'ErrorDetail(string={!r}, code={!r})'.format('msg1', 'code')
assert repr(
ErrorDetail("msg1")
) == "ErrorDetail(string={!r}, code=None)".format("msg1")
assert repr(
ErrorDetail("msg1", "code")
) == "ErrorDetail(string={!r}, code={!r})".format("msg1", "code")
def test_str(self):
assert str(ErrorDetail('msg1')) == 'msg1'
assert str(ErrorDetail('msg1', 'code')) == 'msg1'
assert str(ErrorDetail("msg1")) == "msg1"
assert str(ErrorDetail("msg1", "code")) == "msg1"
def test_hash(self):
assert hash(ErrorDetail('msg')) == hash('msg')
assert hash(ErrorDetail('msg', 'code')) == hash('msg')
assert hash(ErrorDetail("msg")) == hash("msg")
assert hash(ErrorDetail("msg", "code")) == hash("msg")
class TranslationTests(TestCase):
@translation.override('fr')
@translation.override("fr")
def test_message(self):
# this test largely acts as a sanity test to ensure the translation files are present.
self.assertEqual(_('A server error occurred.'), 'Une erreur du serveur est survenue.')
self.assertEqual(six.text_type(APIException()), 'Une erreur du serveur est survenue.')
self.assertEqual(
_("A server error occurred."), "Une erreur du serveur est survenue."
)
self.assertEqual(
six.text_type(APIException()), "Une erreur du serveur est survenue."
)
def test_server_error():
request = RequestFactory().get('/')
request = RequestFactory().get("/")
response = server_error(request)
assert response.status_code == 500
assert response["content-type"] == 'application/json'
assert response["content-type"] == "application/json"
def test_bad_request():
request = RequestFactory().get('/')
exception = Exception('Something went wrong — Not used')
request = RequestFactory().get("/")
exception = Exception("Something went wrong — Not used")
response = bad_request(request, exception)
assert response.status_code == 400
assert response["content-type"] == 'application/json'
assert response["content-type"] == "application/json"

File diff suppressed because it is too large Load Diff

View File

@ -14,6 +14,7 @@ from rest_framework import filters, generics, serializers
from rest_framework.compat import coreschema
from rest_framework.test import APIRequestFactory
factory = APIRequestFactory()
@ -30,7 +31,7 @@ class BaseFilterTests(TestCase):
with pytest.raises(NotImplementedError):
self.filter_backend.filter_queryset(None, None, None)
@pytest.mark.skipif(not coreschema, reason='coreschema is not installed')
@pytest.mark.skipif(not coreschema, reason="coreschema is not installed")
def test_get_schema_fields_checks_for_coreapi(self):
filters.coreapi = None
with pytest.raises(AssertionError):
@ -47,7 +48,7 @@ class SearchFilterModel(models.Model):
class SearchFilterSerializer(serializers.ModelSerializer):
class Meta:
model = SearchFilterModel
fields = '__all__'
fields = "__all__"
class SearchFilterTests(TestCase):
@ -59,12 +60,8 @@ class SearchFilterTests(TestCase):
# zzz cde
# ...
for idx in range(10):
title = 'z' * (idx + 1)
text = (
chr(idx + ord('a')) +
chr(idx + ord('b')) +
chr(idx + ord('c'))
)
title = "z" * (idx + 1)
text = chr(idx + ord("a")) + chr(idx + ord("b")) + chr(idx + ord("c"))
SearchFilterModel(title=title, text=text).save()
def test_search(self):
@ -72,14 +69,14 @@ class SearchFilterTests(TestCase):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('title', 'text')
search_fields = ("title", "text")
view = SearchListView.as_view()
request = factory.get('/', {'search': 'b'})
request = factory.get("/", {"search": "b"})
response = view(request)
assert response.data == [
{'id': 1, 'title': 'z', 'text': 'abc'},
{'id': 2, 'title': 'zz', 'text': 'bcd'}
{"id": 1, "title": "z", "text": "abc"},
{"id": 2, "title": "zz", "text": "bcd"},
]
def test_search_returns_same_queryset_if_no_search_fields_or_terms_provided(self):
@ -89,10 +86,11 @@ class SearchFilterTests(TestCase):
filter_backends = (filters.SearchFilter,)
view = SearchListView.as_view()
request = factory.get('/')
request = factory.get("/")
response = view(request)
expected = SearchFilterSerializer(SearchFilterModel.objects.all(),
many=True).data
expected = SearchFilterSerializer(
SearchFilterModel.objects.all(), many=True
).data
assert response.data == expected
def test_exact_search(self):
@ -100,59 +98,53 @@ class SearchFilterTests(TestCase):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('=title', 'text')
search_fields = ("=title", "text")
view = SearchListView.as_view()
request = factory.get('/', {'search': 'zzz'})
request = factory.get("/", {"search": "zzz"})
response = view(request)
assert response.data == [
{'id': 3, 'title': 'zzz', 'text': 'cde'}
]
assert response.data == [{"id": 3, "title": "zzz", "text": "cde"}]
def test_startswith_search(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('title', '^text')
search_fields = ("title", "^text")
view = SearchListView.as_view()
request = factory.get('/', {'search': 'b'})
request = factory.get("/", {"search": "b"})
response = view(request)
assert response.data == [
{'id': 2, 'title': 'zz', 'text': 'bcd'}
]
assert response.data == [{"id": 2, "title": "zz", "text": "bcd"}]
def test_regexp_search(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('$title', '$text')
search_fields = ("$title", "$text")
view = SearchListView.as_view()
request = factory.get('/', {'search': 'z{2} ^b'})
request = factory.get("/", {"search": "z{2} ^b"})
response = view(request)
assert response.data == [
{'id': 2, 'title': 'zz', 'text': 'bcd'}
]
assert response.data == [{"id": 2, "title": "zz", "text": "bcd"}]
def test_search_with_nonstandard_search_param(self):
with override_settings(REST_FRAMEWORK={'SEARCH_PARAM': 'query'}):
with override_settings(REST_FRAMEWORK={"SEARCH_PARAM": "query"}):
reload_module(filters)
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('title', 'text')
search_fields = ("title", "text")
view = SearchListView.as_view()
request = factory.get('/', {'query': 'b'})
request = factory.get("/", {"query": "b"})
response = view(request)
assert response.data == [
{'id': 1, 'title': 'z', 'text': 'abc'},
{'id': 2, 'title': 'zz', 'text': 'bcd'}
{"id": 1, "title": "z", "text": "abc"},
{"id": 2, "title": "zz", "text": "bcd"},
]
reload_module(filters)
@ -161,26 +153,24 @@ class SearchFilterTests(TestCase):
class CustomSearchFilter(filters.SearchFilter):
# Filter that dynamically changes search fields
def get_search_fields(self, view, request):
if request.query_params.get('title_only'):
return ('$title',)
if request.query_params.get("title_only"):
return ("$title",)
return super(CustomSearchFilter, self).get_search_fields(view, request)
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.all()
serializer_class = SearchFilterSerializer
filter_backends = (CustomSearchFilter,)
search_fields = ('$title', '$text')
search_fields = ("$title", "$text")
view = SearchListView.as_view()
request = factory.get('/', {'search': r'^\w{3}$'})
request = factory.get("/", {"search": r"^\w{3}$"})
response = view(request)
assert len(response.data) == 10
request = factory.get('/', {'search': r'^\w{3}$', 'title_only': 'true'})
request = factory.get("/", {"search": r"^\w{3}$", "title_only": "true"})
response = view(request)
assert response.data == [
{'id': 3, 'title': 'zzz', 'text': 'cde'}
]
assert response.data == [{"id": 3, "title": "zzz", "text": "cde"}]
class AttributeModel(models.Model):
@ -195,33 +185,31 @@ class SearchFilterModelFk(models.Model):
class SearchFilterFkSerializer(serializers.ModelSerializer):
class Meta:
model = SearchFilterModelFk
fields = '__all__'
fields = "__all__"
class SearchFilterFkTests(TestCase):
def test_must_call_distinct(self):
filter_ = filters.SearchFilter()
prefixes = [''] + list(filter_.lookup_prefixes)
prefixes = [""] + list(filter_.lookup_prefixes)
for prefix in prefixes:
assert not filter_.must_call_distinct(
SearchFilterModelFk._meta,
["%stitle" % prefix]
SearchFilterModelFk._meta, ["%stitle" % prefix]
)
assert not filter_.must_call_distinct(
SearchFilterModelFk._meta,
["%stitle" % prefix, "%sattribute__label" % prefix]
["%stitle" % prefix, "%sattribute__label" % prefix],
)
def test_must_call_distinct_restores_meta_for_each_field(self):
# In this test case the attribute of the fk model comes first in the
# list of search fields.
filter_ = filters.SearchFilter()
prefixes = [''] + list(filter_.lookup_prefixes)
prefixes = [""] + list(filter_.lookup_prefixes)
for prefix in prefixes:
assert not filter_.must_call_distinct(
SearchFilterModelFk._meta,
["%sattribute__label" % prefix, "%stitle" % prefix]
["%sattribute__label" % prefix, "%stitle" % prefix],
)
@ -234,7 +222,7 @@ class SearchFilterModelM2M(models.Model):
class SearchFilterM2MSerializer(serializers.ModelSerializer):
class Meta:
model = SearchFilterModelM2M
fields = '__all__'
fields = "__all__"
class SearchFilterM2MTests(TestCase):
@ -246,43 +234,38 @@ class SearchFilterM2MTests(TestCase):
# zzz cde [1, 2, 3]
# ...
for idx in range(3):
label = 'w' * (idx + 1)
label = "w" * (idx + 1)
AttributeModel.objects.create(label=label)
for idx in range(10):
title = 'z' * (idx + 1)
text = (
chr(idx + ord('a')) +
chr(idx + ord('b')) +
chr(idx + ord('c'))
)
title = "z" * (idx + 1)
text = chr(idx + ord("a")) + chr(idx + ord("b")) + chr(idx + ord("c"))
SearchFilterModelM2M(title=title, text=text).save()
SearchFilterModelM2M.objects.get(title='zz').attributes.add(1, 2, 3)
SearchFilterModelM2M.objects.get(title="zz").attributes.add(1, 2, 3)
def test_m2m_search(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModelM2M.objects.all()
serializer_class = SearchFilterM2MSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('=title', 'text', 'attributes__label')
search_fields = ("=title", "text", "attributes__label")
view = SearchListView.as_view()
request = factory.get('/', {'search': 'zz'})
request = factory.get("/", {"search": "zz"})
response = view(request)
assert len(response.data) == 1
def test_must_call_distinct(self):
filter_ = filters.SearchFilter()
prefixes = [''] + list(filter_.lookup_prefixes)
prefixes = [""] + list(filter_.lookup_prefixes)
for prefix in prefixes:
assert not filter_.must_call_distinct(
SearchFilterModelM2M._meta,
["%stitle" % prefix]
SearchFilterModelM2M._meta, ["%stitle" % prefix]
)
assert filter_.must_call_distinct(
SearchFilterModelM2M._meta,
["%stitle" % prefix, "%sattributes__label" % prefix]
["%stitle" % prefix, "%sattributes__label" % prefix],
)
@ -299,33 +282,46 @@ class Entry(models.Model):
class BlogSerializer(serializers.ModelSerializer):
class Meta:
model = Blog
fields = '__all__'
fields = "__all__"
class SearchFilterToManyTests(TestCase):
@classmethod
def setUpTestData(cls):
b1 = Blog.objects.create(name='Blog 1')
b2 = Blog.objects.create(name='Blog 2')
b1 = Blog.objects.create(name="Blog 1")
b2 = Blog.objects.create(name="Blog 2")
# Multiple entries on Lennon published in 1979 - distinct should deduplicate
Entry.objects.create(blog=b1, headline='Something about Lennon', pub_date=datetime.date(1979, 1, 1))
Entry.objects.create(blog=b1, headline='Another thing about Lennon', pub_date=datetime.date(1979, 6, 1))
Entry.objects.create(
blog=b1,
headline="Something about Lennon",
pub_date=datetime.date(1979, 1, 1),
)
Entry.objects.create(
blog=b1,
headline="Another thing about Lennon",
pub_date=datetime.date(1979, 6, 1),
)
# Entry on Lennon *and* a separate entry in 1979 - should not match
Entry.objects.create(blog=b2, headline='Something unrelated', pub_date=datetime.date(1979, 1, 1))
Entry.objects.create(blog=b2, headline='Retrospective on Lennon', pub_date=datetime.date(1990, 6, 1))
Entry.objects.create(
blog=b2, headline="Something unrelated", pub_date=datetime.date(1979, 1, 1)
)
Entry.objects.create(
blog=b2,
headline="Retrospective on Lennon",
pub_date=datetime.date(1990, 6, 1),
)
def test_multiple_filter_conditions(self):
class SearchListView(generics.ListAPIView):
queryset = Blog.objects.all()
serializer_class = BlogSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('=name', 'entry__headline', '=entry__pub_date__year')
search_fields = ("=name", "entry__headline", "=entry__pub_date__year")
view = SearchListView.as_view()
request = factory.get('/', {'search': 'Lennon,1979'})
request = factory.get("/", {"search": "Lennon,1979"})
response = view(request)
assert len(response.data) == 1
@ -335,60 +331,58 @@ class SearchFilterAnnotatedSerializer(serializers.ModelSerializer):
class Meta:
model = SearchFilterModel
fields = ('title', 'text', 'title_text')
fields = ("title", "text", "title_text")
class SearchFilterAnnotatedFieldTests(TestCase):
@classmethod
def setUpTestData(cls):
SearchFilterModel.objects.create(title='abc', text='def')
SearchFilterModel.objects.create(title='ghi', text='jkl')
SearchFilterModel.objects.create(title="abc", text="def")
SearchFilterModel.objects.create(title="ghi", text="jkl")
def test_search_in_annotated_field(self):
class SearchListView(generics.ListAPIView):
queryset = SearchFilterModel.objects.annotate(
title_text=Upper(
Concat(models.F('title'), models.F('text'))
)
title_text=Upper(Concat(models.F("title"), models.F("text")))
).all()
serializer_class = SearchFilterAnnotatedSerializer
filter_backends = (filters.SearchFilter,)
search_fields = ('title_text',)
search_fields = ("title_text",)
view = SearchListView.as_view()
request = factory.get('/', {'search': 'ABCDEF'})
request = factory.get("/", {"search": "ABCDEF"})
response = view(request)
assert len(response.data) == 1
assert response.data[0]['title_text'] == 'ABCDEF'
assert response.data[0]["title_text"] == "ABCDEF"
class OrderingFilterModel(models.Model):
title = models.CharField(max_length=20, verbose_name='verbose title')
title = models.CharField(max_length=20, verbose_name="verbose title")
text = models.CharField(max_length=100)
class OrderingFilterRelatedModel(models.Model):
related_object = models.ForeignKey(OrderingFilterModel, related_name="relateds", on_delete=models.CASCADE)
index = models.SmallIntegerField(help_text="A non-related field to test with", default=0)
related_object = models.ForeignKey(
OrderingFilterModel, related_name="relateds", on_delete=models.CASCADE
)
index = models.SmallIntegerField(
help_text="A non-related field to test with", default=0
)
class OrderingFilterSerializer(serializers.ModelSerializer):
class Meta:
model = OrderingFilterModel
fields = '__all__'
fields = "__all__"
class OrderingDottedRelatedSerializer(serializers.ModelSerializer):
related_text = serializers.CharField(source='related_object.text')
related_title = serializers.CharField(source='related_object.title')
related_text = serializers.CharField(source="related_object.text")
related_title = serializers.CharField(source="related_object.title")
class Meta:
model = OrderingFilterRelatedModel
fields = (
'related_text',
'related_title',
'index',
)
fields = ("related_text", "related_title", "index")
class DjangoFilterOrderingModel(models.Model):
@ -396,13 +390,13 @@ class DjangoFilterOrderingModel(models.Model):
text = models.CharField(max_length=10)
class Meta:
ordering = ['-date']
ordering = ["-date"]
class DjangoFilterOrderingSerializer(serializers.ModelSerializer):
class Meta:
model = DjangoFilterOrderingModel
fields = '__all__'
fields = "__all__"
class OrderingFilterTests(TestCase):
@ -413,16 +407,8 @@ class OrderingFilterTests(TestCase):
# yxw bcd
# xwv cde
for idx in range(3):
title = (
chr(ord('z') - idx) +
chr(ord('y') - idx) +
chr(ord('x') - idx)
)
text = (
chr(idx + ord('a')) +
chr(idx + ord('b')) +
chr(idx + ord('c'))
)
title = chr(ord("z") - idx) + chr(ord("y") - idx) + chr(ord("x") - idx)
text = chr(idx + ord("a")) + chr(idx + ord("b")) + chr(idx + ord("c"))
OrderingFilterModel(title=title, text=text).save()
def test_ordering(self):
@ -430,16 +416,16 @@ class OrderingFilterTests(TestCase):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
ordering_fields = ('text',)
ordering = ("title",)
ordering_fields = ("text",)
view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'text'})
request = factory.get("/", {"ordering": "text"})
response = view(request)
assert response.data == [
{'id': 1, 'title': 'zyx', 'text': 'abc'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{"id": 1, "title": "zyx", "text": "abc"},
{"id": 2, "title": "yxw", "text": "bcd"},
{"id": 3, "title": "xwv", "text": "cde"},
]
def test_reverse_ordering(self):
@ -447,16 +433,16 @@ class OrderingFilterTests(TestCase):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
ordering_fields = ('text',)
ordering = ("title",)
ordering_fields = ("text",)
view = OrderingListView.as_view()
request = factory.get('/', {'ordering': '-text'})
request = factory.get("/", {"ordering": "-text"})
response = view(request)
assert response.data == [
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 1, 'title': 'zyx', 'text': 'abc'},
{"id": 3, "title": "xwv", "text": "cde"},
{"id": 2, "title": "yxw", "text": "bcd"},
{"id": 1, "title": "zyx", "text": "abc"},
]
def test_incorrecturl_extrahyphens_ordering(self):
@ -464,16 +450,16 @@ class OrderingFilterTests(TestCase):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
ordering_fields = ('text',)
ordering = ("title",)
ordering_fields = ("text",)
view = OrderingListView.as_view()
request = factory.get('/', {'ordering': '--text'})
request = factory.get("/", {"ordering": "--text"})
response = view(request)
assert response.data == [
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 1, 'title': 'zyx', 'text': 'abc'},
{"id": 3, "title": "xwv", "text": "cde"},
{"id": 2, "title": "yxw", "text": "bcd"},
{"id": 1, "title": "zyx", "text": "abc"},
]
def test_incorrectfield_ordering(self):
@ -481,16 +467,16 @@ class OrderingFilterTests(TestCase):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
ordering_fields = ('text',)
ordering = ("title",)
ordering_fields = ("text",)
view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'foobar'})
request = factory.get("/", {"ordering": "foobar"})
response = view(request)
assert response.data == [
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 1, 'title': 'zyx', 'text': 'abc'},
{"id": 3, "title": "xwv", "text": "cde"},
{"id": 2, "title": "yxw", "text": "bcd"},
{"id": 1, "title": "zyx", "text": "abc"},
]
def test_default_ordering(self):
@ -498,16 +484,16 @@ class OrderingFilterTests(TestCase):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
ordering_fields = ('text',)
ordering = ("title",)
ordering_fields = ("text",)
view = OrderingListView.as_view()
request = factory.get('')
request = factory.get("")
response = view(request)
assert response.data == [
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 1, 'title': 'zyx', 'text': 'abc'},
{"id": 3, "title": "xwv", "text": "cde"},
{"id": 2, "title": "yxw", "text": "bcd"},
{"id": 1, "title": "zyx", "text": "abc"},
]
def test_default_ordering_using_string(self):
@ -515,53 +501,48 @@ class OrderingFilterTests(TestCase):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = 'title'
ordering_fields = ('text',)
ordering = "title"
ordering_fields = ("text",)
view = OrderingListView.as_view()
request = factory.get('')
request = factory.get("")
response = view(request)
assert response.data == [
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 1, 'title': 'zyx', 'text': 'abc'},
{"id": 3, "title": "xwv", "text": "cde"},
{"id": 2, "title": "yxw", "text": "bcd"},
{"id": 1, "title": "zyx", "text": "abc"},
]
def test_ordering_by_aggregate_field(self):
# create some related models to aggregate order by
num_objs = [2, 5, 3]
for obj, num_relateds in zip(OrderingFilterModel.objects.all(),
num_objs):
for obj, num_relateds in zip(OrderingFilterModel.objects.all(), num_objs):
for _ in range(num_relateds):
new_related = OrderingFilterRelatedModel(
related_object=obj
)
new_related = OrderingFilterRelatedModel(related_object=obj)
new_related.save()
class OrderingListView(generics.ListAPIView):
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = 'title'
ordering_fields = '__all__'
ordering = "title"
ordering_fields = "__all__"
queryset = OrderingFilterModel.objects.all().annotate(
models.Count("relateds"))
models.Count("relateds")
)
view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'relateds__count'})
request = factory.get("/", {"ordering": "relateds__count"})
response = view(request)
assert response.data == [
{'id': 1, 'title': 'zyx', 'text': 'abc'},
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{"id": 1, "title": "zyx", "text": "abc"},
{"id": 3, "title": "xwv", "text": "cde"},
{"id": 2, "title": "yxw", "text": "bcd"},
]
def test_ordering_by_dotted_source(self):
for index, obj in enumerate(OrderingFilterModel.objects.all()):
OrderingFilterRelatedModel.objects.create(
related_object=obj,
index=index
)
OrderingFilterRelatedModel.objects.create(related_object=obj, index=index)
class OrderingListView(generics.ListAPIView):
serializer_class = OrderingDottedRelatedSerializer
@ -569,62 +550,62 @@ class OrderingFilterTests(TestCase):
queryset = OrderingFilterRelatedModel.objects.all()
view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'related_object__text'})
request = factory.get("/", {"ordering": "related_object__text"})
response = view(request)
assert response.data == [
{'related_title': 'zyx', 'related_text': 'abc', 'index': 0},
{'related_title': 'yxw', 'related_text': 'bcd', 'index': 1},
{'related_title': 'xwv', 'related_text': 'cde', 'index': 2},
{"related_title": "zyx", "related_text": "abc", "index": 0},
{"related_title": "yxw", "related_text": "bcd", "index": 1},
{"related_title": "xwv", "related_text": "cde", "index": 2},
]
request = factory.get('/', {'ordering': '-index'})
request = factory.get("/", {"ordering": "-index"})
response = view(request)
assert response.data == [
{'related_title': 'xwv', 'related_text': 'cde', 'index': 2},
{'related_title': 'yxw', 'related_text': 'bcd', 'index': 1},
{'related_title': 'zyx', 'related_text': 'abc', 'index': 0},
{"related_title": "xwv", "related_text": "cde", "index": 2},
{"related_title": "yxw", "related_text": "bcd", "index": 1},
{"related_title": "zyx", "related_text": "abc", "index": 0},
]
def test_ordering_with_nonstandard_ordering_param(self):
with override_settings(REST_FRAMEWORK={'ORDERING_PARAM': 'order'}):
with override_settings(REST_FRAMEWORK={"ORDERING_PARAM": "order"}):
reload_module(filters)
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
serializer_class = OrderingFilterSerializer
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
ordering_fields = ('text',)
ordering = ("title",)
ordering_fields = ("text",)
view = OrderingListView.as_view()
request = factory.get('/', {'order': 'text'})
request = factory.get("/", {"order": "text"})
response = view(request)
assert response.data == [
{'id': 1, 'title': 'zyx', 'text': 'abc'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{"id": 1, "title": "zyx", "text": "abc"},
{"id": 2, "title": "yxw", "text": "bcd"},
{"id": 3, "title": "xwv", "text": "cde"},
]
reload_module(filters)
def test_get_template_context(self):
class OrderingListView(generics.ListAPIView):
ordering_fields = '__all__'
ordering_fields = "__all__"
serializer_class = OrderingFilterSerializer
queryset = OrderingFilterModel.objects.all()
filter_backends = (filters.OrderingFilter,)
request = factory.get('/', {'ordering': 'title'}, HTTP_ACCEPT='text/html')
request = factory.get("/", {"ordering": "title"}, HTTP_ACCEPT="text/html")
view = OrderingListView.as_view()
response = view(request)
self.assertContains(response, 'verbose title')
self.assertContains(response, "verbose title")
def test_ordering_with_overridden_get_serializer_class(self):
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
ordering = ("title",)
# note: no ordering_fields and serializer_class specified
@ -632,24 +613,24 @@ class OrderingFilterTests(TestCase):
return OrderingFilterSerializer
view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'text'})
request = factory.get("/", {"ordering": "text"})
response = view(request)
assert response.data == [
{'id': 1, 'title': 'zyx', 'text': 'abc'},
{'id': 2, 'title': 'yxw', 'text': 'bcd'},
{'id': 3, 'title': 'xwv', 'text': 'cde'},
{"id": 1, "title": "zyx", "text": "abc"},
{"id": 2, "title": "yxw", "text": "bcd"},
{"id": 3, "title": "xwv", "text": "cde"},
]
def test_ordering_with_improper_configuration(self):
class OrderingListView(generics.ListAPIView):
queryset = OrderingFilterModel.objects.all()
filter_backends = (filters.OrderingFilter,)
ordering = ('title',)
ordering = ("title",)
# note: no ordering_fields and serializer_class
# or get_serializer_class specified
view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'text'})
request = factory.get("/", {"ordering": "text"})
with self.assertRaises(ImproperlyConfigured):
view(request)
@ -666,7 +647,7 @@ class SensitiveDataSerializer1(serializers.ModelSerializer):
class Meta:
model = SensitiveOrderingFilterModel
fields = ('id', 'username')
fields = ("id", "username")
class SensitiveDataSerializer2(serializers.ModelSerializer):
@ -675,74 +656,80 @@ class SensitiveDataSerializer2(serializers.ModelSerializer):
class Meta:
model = SensitiveOrderingFilterModel
fields = ('id', 'username', 'password')
fields = ("id", "username", "password")
class SensitiveDataSerializer3(serializers.ModelSerializer):
user = serializers.CharField(source='username')
user = serializers.CharField(source="username")
class Meta:
model = SensitiveOrderingFilterModel
fields = ('id', 'user')
fields = ("id", "user")
class SensitiveOrderingFilterTests(TestCase):
def setUp(self):
for idx in range(3):
username = {0: 'userA', 1: 'userB', 2: 'userC'}[idx]
password = {0: 'passA', 1: 'passC', 2: 'passB'}[idx]
username = {0: "userA", 1: "userB", 2: "userC"}[idx]
password = {0: "passA", 1: "passC", 2: "passB"}[idx]
SensitiveOrderingFilterModel(username=username, password=password).save()
def test_order_by_serializer_fields(self):
for serializer_cls in [
SensitiveDataSerializer1,
SensitiveDataSerializer2,
SensitiveDataSerializer3
SensitiveDataSerializer3,
]:
class OrderingListView(generics.ListAPIView):
queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
queryset = SensitiveOrderingFilterModel.objects.all().order_by(
"username"
)
filter_backends = (filters.OrderingFilter,)
serializer_class = serializer_cls
view = OrderingListView.as_view()
request = factory.get('/', {'ordering': '-username'})
request = factory.get("/", {"ordering": "-username"})
response = view(request)
if serializer_cls == SensitiveDataSerializer3:
username_field = 'user'
username_field = "user"
else:
username_field = 'username'
username_field = "username"
# Note: Inverse username ordering correctly applied.
assert response.data == [
{'id': 3, username_field: 'userC'},
{'id': 2, username_field: 'userB'},
{'id': 1, username_field: 'userA'},
{"id": 3, username_field: "userC"},
{"id": 2, username_field: "userB"},
{"id": 1, username_field: "userA"},
]
def test_cannot_order_by_non_serializer_fields(self):
for serializer_cls in [
SensitiveDataSerializer1,
SensitiveDataSerializer2,
SensitiveDataSerializer3
SensitiveDataSerializer3,
]:
class OrderingListView(generics.ListAPIView):
queryset = SensitiveOrderingFilterModel.objects.all().order_by('username')
queryset = SensitiveOrderingFilterModel.objects.all().order_by(
"username"
)
filter_backends = (filters.OrderingFilter,)
serializer_class = serializer_cls
view = OrderingListView.as_view()
request = factory.get('/', {'ordering': 'password'})
request = factory.get("/", {"ordering": "password"})
response = view(request)
if serializer_cls == SensitiveDataSerializer3:
username_field = 'user'
username_field = "user"
else:
username_field = 'username'
username_field = "username"
# Note: The passwords are not in order. Default ordering is used.
assert response.data == [
{'id': 1, username_field: 'userA'}, # PassB
{'id': 2, username_field: 'userB'}, # PassC
{'id': 3, username_field: 'userC'}, # PassA
{"id": 1, username_field: "userA"}, # PassB
{"id": 2, username_field: "userB"}, # PassC
{"id": 3, username_field: "userC"}, # PassA
]

View File

@ -17,20 +17,18 @@ class FooView(APIView):
pass
urlpatterns = [
url(r'^$', FooView.as_view())
]
urlpatterns = [url(r"^$", FooView.as_view())]
@override_settings(ROOT_URLCONF='tests.test_generateschema')
@pytest.mark.skipif(not coreapi, reason='coreapi is not installed')
@override_settings(ROOT_URLCONF="tests.test_generateschema")
@pytest.mark.skipif(not coreapi, reason="coreapi is not installed")
class GenerateSchemaTests(TestCase):
"""Tests for management command generateschema."""
def setUp(self):
self.out = six.StringIO()
@pytest.mark.skipif(six.PY2, reason='PyYAML unicode output is malformed on PY2.')
@pytest.mark.skipif(six.PY2, reason="PyYAML unicode output is malformed on PY2.")
def test_renders_default_schema_with_custom_title_url_and_description(self):
expected_out = """info:
description: Sample description
@ -44,45 +42,29 @@ class GenerateSchemaTests(TestCase):
servers:
- url: http://api.sample.com/
"""
call_command('generateschema',
'--title=SampleAPI',
'--url=http://api.sample.com',
'--description=Sample description',
stdout=self.out)
call_command(
"generateschema",
"--title=SampleAPI",
"--url=http://api.sample.com",
"--description=Sample description",
stdout=self.out,
)
self.assertIn(formatting.dedent(expected_out), self.out.getvalue())
def test_renders_openapi_json_schema(self):
expected_out = {
"openapi": "3.0.0",
"info": {
"version": "",
"title": "",
"description": ""
},
"servers": [
{
"url": ""
}
],
"paths": {
"/": {
"get": {
"operationId": "list"
}
}
}
"info": {"version": "", "title": "", "description": ""},
"servers": [{"url": ""}],
"paths": {"/": {"get": {"operationId": "list"}}},
}
call_command('generateschema',
'--format=openapi-json',
stdout=self.out)
call_command("generateschema", "--format=openapi-json", stdout=self.out)
out_json = json.loads(self.out.getvalue())
self.assertDictEqual(out_json, expected_out)
def test_renders_corejson_schema(self):
expected_out = """{"_type":"document","":{"list":{"_type":"link","url":"/","action":"get"}}}"""
call_command('generateschema',
'--format=corejson',
stdout=self.out)
call_command("generateschema", "--format=corejson", stdout=self.out)
self.assertIn(expected_out, self.out.getvalue())

View File

@ -11,10 +11,14 @@ from rest_framework import generics, renderers, serializers, status
from rest_framework.response import Response
from rest_framework.test import APIRequestFactory
from tests.models import (
BasicModel, ForeignKeySource, ForeignKeyTarget, RESTFrameworkModel,
UUIDForeignKeyTarget
BasicModel,
ForeignKeySource,
ForeignKeyTarget,
RESTFrameworkModel,
UUIDForeignKeyTarget,
)
factory = APIRequestFactory()
@ -35,13 +39,13 @@ class Comment(RESTFrameworkModel):
class BasicSerializer(serializers.ModelSerializer):
class Meta:
model = BasicModel
fields = '__all__'
fields = "__all__"
class ForeignKeySerializer(serializers.ModelSerializer):
class Meta:
model = ForeignKeySource
fields = '__all__'
fields = "__all__"
class SlugSerializer(serializers.ModelSerializer):
@ -49,7 +53,7 @@ class SlugSerializer(serializers.ModelSerializer):
class Meta:
model = SlugBasedModel
fields = ('text', 'slug')
fields = ("text", "slug")
# Views
@ -59,7 +63,7 @@ class RootView(generics.ListCreateAPIView):
class InstanceView(generics.RetrieveUpdateDestroyAPIView):
queryset = BasicModel.objects.exclude(text='filtered out')
queryset = BasicModel.objects.exclude(text="filtered out")
serializer_class = BasicSerializer
@ -72,9 +76,10 @@ class SlugBasedInstanceView(InstanceView):
"""
A model with a slug-field.
"""
queryset = SlugBasedModel.objects.all()
serializer_class = SlugSerializer
lookup_field = 'slug'
lookup_field = "slug"
# Tests
@ -83,21 +88,18 @@ class TestRootView(TestCase):
"""
Create 3 BasicModel instances.
"""
items = ['foo', 'bar', 'baz']
items = ["foo", "bar", "baz"]
for item in items:
BasicModel(text=item).save()
self.objects = BasicModel.objects
self.data = [
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
self.data = [{"id": obj.id, "text": obj.text} for obj in self.objects.all()]
self.view = RootView.as_view()
def test_get_root_view(self):
"""
GET requests to ListCreateAPIView should return list of objects.
"""
request = factory.get('/')
request = factory.get("/")
with self.assertNumQueries(1):
response = self.view(request).render()
assert response.status_code == status.HTTP_200_OK
@ -107,7 +109,7 @@ class TestRootView(TestCase):
"""
HEAD requests to ListCreateAPIView should return 200.
"""
request = factory.head('/')
request = factory.head("/")
with self.assertNumQueries(1):
response = self.view(request).render()
assert response.status_code == status.HTTP_200_OK
@ -116,21 +118,21 @@ class TestRootView(TestCase):
"""
POST requests to ListCreateAPIView should create a new object.
"""
data = {'text': 'foobar'}
request = factory.post('/', data, format='json')
data = {"text": "foobar"}
request = factory.post("/", data, format="json")
with self.assertNumQueries(1):
response = self.view(request).render()
assert response.status_code == status.HTTP_201_CREATED
assert response.data == {'id': 4, 'text': 'foobar'}
assert response.data == {"id": 4, "text": "foobar"}
created = self.objects.get(id=4)
assert created.text == 'foobar'
assert created.text == "foobar"
def test_put_root_view(self):
"""
PUT requests to ListCreateAPIView should not be allowed
"""
data = {'text': 'foobar'}
request = factory.put('/', data, format='json')
data = {"text": "foobar"}
request = factory.put("/", data, format="json")
with self.assertNumQueries(0):
response = self.view(request).render()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
@ -140,7 +142,7 @@ class TestRootView(TestCase):
"""
DELETE requests to ListCreateAPIView should not be allowed
"""
request = factory.delete('/')
request = factory.delete("/")
with self.assertNumQueries(0):
response = self.view(request).render()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
@ -150,24 +152,24 @@ class TestRootView(TestCase):
"""
POST requests to create a new object should not be able to set the id.
"""
data = {'id': 999, 'text': 'foobar'}
request = factory.post('/', data, format='json')
data = {"id": 999, "text": "foobar"}
request = factory.post("/", data, format="json")
with self.assertNumQueries(1):
response = self.view(request).render()
assert response.status_code == status.HTTP_201_CREATED
assert response.data == {'id': 4, 'text': 'foobar'}
assert response.data == {"id": 4, "text": "foobar"}
created = self.objects.get(id=4)
assert created.text == 'foobar'
assert created.text == "foobar"
def test_post_error_root_view(self):
"""
POST requests to ListCreateAPIView in HTML should include a form error.
"""
data = {'text': 'foobar' * 100}
request = factory.post('/', data, HTTP_ACCEPT='text/html')
data = {"text": "foobar" * 100}
request = factory.post("/", data, HTTP_ACCEPT="text/html")
response = self.view(request).render()
expected_error = '<span class="help-block">Ensure this field has no more than 100 characters.</span>'
assert expected_error in response.rendered_content.decode('utf-8')
assert expected_error in response.rendered_content.decode("utf-8")
EXPECTED_QUERIES_FOR_PUT = 2
@ -178,14 +180,11 @@ class TestInstanceView(TestCase):
"""
Create 3 BasicModel instances.
"""
items = ['foo', 'bar', 'baz', 'filtered out']
items = ["foo", "bar", "baz", "filtered out"]
for item in items:
BasicModel(text=item).save()
self.objects = BasicModel.objects.exclude(text='filtered out')
self.data = [
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
self.objects = BasicModel.objects.exclude(text="filtered out")
self.data = [{"id": obj.id, "text": obj.text} for obj in self.objects.all()]
self.view = InstanceView.as_view()
self.slug_based_view = SlugBasedInstanceView.as_view()
@ -193,7 +192,7 @@ class TestInstanceView(TestCase):
"""
GET requests to RetrieveUpdateDestroyAPIView should return a single object.
"""
request = factory.get('/1')
request = factory.get("/1")
with self.assertNumQueries(1):
response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK
@ -203,8 +202,8 @@ class TestInstanceView(TestCase):
"""
POST requests to RetrieveUpdateDestroyAPIView should not be allowed
"""
data = {'text': 'foobar'}
request = factory.post('/', data, format='json')
data = {"text": "foobar"}
request = factory.post("/", data, format="json")
with self.assertNumQueries(0):
response = self.view(request).render()
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
@ -214,38 +213,38 @@ class TestInstanceView(TestCase):
"""
PUT requests to RetrieveUpdateDestroyAPIView should update an object.
"""
data = {'text': 'foobar'}
request = factory.put('/1', data, format='json')
data = {"text": "foobar"}
request = factory.put("/1", data, format="json")
with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
response = self.view(request, pk='1').render()
response = self.view(request, pk="1").render()
assert response.status_code == status.HTTP_200_OK
assert dict(response.data) == {'id': 1, 'text': 'foobar'}
assert dict(response.data) == {"id": 1, "text": "foobar"}
updated = self.objects.get(id=1)
assert updated.text == 'foobar'
assert updated.text == "foobar"
def test_patch_instance_view(self):
"""
PATCH requests to RetrieveUpdateDestroyAPIView should update an object.
"""
data = {'text': 'foobar'}
request = factory.patch('/1', data, format='json')
data = {"text": "foobar"}
request = factory.patch("/1", data, format="json")
with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK
assert response.data == {'id': 1, 'text': 'foobar'}
assert response.data == {"id": 1, "text": "foobar"}
updated = self.objects.get(id=1)
assert updated.text == 'foobar'
assert updated.text == "foobar"
def test_delete_instance_view(self):
"""
DELETE requests to RetrieveUpdateDestroyAPIView should delete an object.
"""
request = factory.delete('/1')
request = factory.delete("/1")
with self.assertNumQueries(2):
response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_204_NO_CONTENT
assert response.content == six.b('')
assert response.content == six.b("")
ids = [obj.id for obj in self.objects.all()]
assert ids == [2, 3]
@ -254,23 +253,23 @@ class TestInstanceView(TestCase):
GET requests with an incorrect pk type, should raise 404, not 500.
Regression test for #890.
"""
request = factory.get('/a')
request = factory.get("/a")
with self.assertNumQueries(0):
response = self.view(request, pk='a').render()
response = self.view(request, pk="a").render()
assert response.status_code == status.HTTP_404_NOT_FOUND
def test_put_cannot_set_id(self):
"""
PUT requests to create a new object should not be able to set the id.
"""
data = {'id': 999, 'text': 'foobar'}
request = factory.put('/1', data, format='json')
data = {"id": 999, "text": "foobar"}
request = factory.put("/1", data, format="json")
with self.assertNumQueries(EXPECTED_QUERIES_FOR_PUT):
response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK
assert response.data == {'id': 1, 'text': 'foobar'}
assert response.data == {"id": 1, "text": "foobar"}
updated = self.objects.get(id=1)
assert updated.text == 'foobar'
assert updated.text == "foobar"
def test_put_to_deleted_instance(self):
"""
@ -278,8 +277,8 @@ class TestInstanceView(TestCase):
an object does not currently exist.
"""
self.objects.get(id=1).delete()
data = {'text': 'foobar'}
request = factory.put('/1', data, format='json')
data = {"text": "foobar"}
request = factory.put("/1", data, format="json")
with self.assertNumQueries(1):
response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_404_NOT_FOUND
@ -289,9 +288,9 @@ class TestInstanceView(TestCase):
PUT requests to an URL of instance which is filtered out should not be
able to create new objects.
"""
data = {'text': 'foo'}
filtered_out_pk = BasicModel.objects.filter(text='filtered out')[0].pk
request = factory.put('/{0}'.format(filtered_out_pk), data, format='json')
data = {"text": "foo"}
filtered_out_pk = BasicModel.objects.filter(text="filtered out")[0].pk
request = factory.put("/{0}".format(filtered_out_pk), data, format="json")
response = self.view(request, pk=filtered_out_pk).render()
assert response.status_code == status.HTTP_404_NOT_FOUND
@ -299,8 +298,8 @@ class TestInstanceView(TestCase):
"""
PATCH requests should not be able to create objects.
"""
data = {'text': 'foobar'}
request = factory.patch('/999', data, format='json')
data = {"text": "foobar"}
request = factory.patch("/999", data, format="json")
with self.assertNumQueries(1):
response = self.view(request, pk=999).render()
assert response.status_code == status.HTTP_404_NOT_FOUND
@ -310,11 +309,11 @@ class TestInstanceView(TestCase):
"""
Incorrect PUT requests in HTML should include a form error.
"""
data = {'text': 'foobar' * 100}
request = factory.put('/', data, HTTP_ACCEPT='text/html')
data = {"text": "foobar" * 100}
request = factory.put("/", data, HTTP_ACCEPT="text/html")
response = self.view(request, pk=1).render()
expected_error = '<span class="help-block">Ensure this field has no more than 100 characters.</span>'
assert expected_error in response.rendered_content.decode('utf-8')
assert expected_error in response.rendered_content.decode("utf-8")
class TestFKInstanceView(TestCase):
@ -322,17 +321,14 @@ class TestFKInstanceView(TestCase):
"""
Create 3 BasicModel instances.
"""
items = ['foo', 'bar', 'baz']
items = ["foo", "bar", "baz"]
for item in items:
t = ForeignKeyTarget(name=item)
t.save()
ForeignKeySource(name='source_' + item, target=t).save()
ForeignKeySource(name="source_" + item, target=t).save()
self.objects = ForeignKeySource.objects
self.data = [
{'id': obj.id, 'name': obj.name}
for obj in self.objects.all()
]
self.data = [{"id": obj.id, "name": obj.name} for obj in self.objects.all()]
self.view = FKInstanceView.as_view()
@ -346,23 +342,21 @@ class TestOverriddenGetObject(TestCase):
"""
Create 3 BasicModel instances.
"""
items = ['foo', 'bar', 'baz']
items = ["foo", "bar", "baz"]
for item in items:
BasicModel(text=item).save()
self.objects = BasicModel.objects
self.data = [
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
self.data = [{"id": obj.id, "text": obj.text} for obj in self.objects.all()]
class OverriddenGetObjectView(generics.RetrieveUpdateDestroyAPIView):
"""
Example detail view for override of get_object().
"""
serializer_class = BasicSerializer
def get_object(self):
pk = int(self.kwargs['pk'])
pk = int(self.kwargs["pk"])
return get_object_or_404(BasicModel.objects.all(), id=pk)
self.view = OverriddenGetObjectView.as_view()
@ -371,7 +365,7 @@ class TestOverriddenGetObject(TestCase):
"""
GET requests to RetrieveUpdateDestroyAPIView should return a single object.
"""
request = factory.get('/1')
request = factory.get("/1")
with self.assertNumQueries(1):
response = self.view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK
@ -380,10 +374,11 @@ class TestOverriddenGetObject(TestCase):
# Regression test for #285
class CommentSerializer(serializers.ModelSerializer):
class Meta:
model = Comment
exclude = ('created',)
exclude = ("created",)
class CommentView(generics.ListCreateAPIView):
@ -402,12 +397,12 @@ class TestCreateModelWithAutoNowAddField(TestCase):
https://github.com/encode/django-rest-framework/issues/285
"""
data = {'email': 'foobar@example.com', 'content': 'foobar'}
request = factory.post('/', data, format='json')
data = {"email": "foobar@example.com", "content": "foobar"}
request = factory.post("/", data, format="json")
response = self.view(request).render()
assert response.status_code == status.HTTP_201_CREATED
created = self.objects.get(id=1)
assert created.content == 'foobar'
assert created.content == "foobar"
# Test for particularly ugly regression with m2m in browsable API
@ -427,7 +422,7 @@ class ClassASerializer(serializers.ModelSerializer):
class Meta:
model = ClassA
fields = '__all__'
fields = "__all__"
class ExampleView(generics.ListCreateAPIView):
@ -440,7 +435,7 @@ class TestM2MBrowsableAPI(TestCase):
"""
Test for particularly ugly regression with m2m in browsable API
"""
request = factory.get('/', HTTP_ACCEPT='text/html')
request = factory.get("/", HTTP_ACCEPT="text/html")
view = ExampleView().as_view()
response = view(request).render()
assert response.status_code == status.HTTP_200_OK
@ -448,12 +443,12 @@ class TestM2MBrowsableAPI(TestCase):
class InclusiveFilterBackend(object):
def filter_queryset(self, request, queryset, view):
return queryset.filter(text='foo')
return queryset.filter(text="foo")
class ExclusiveFilterBackend(object):
def filter_queryset(self, request, queryset, view):
return queryset.filter(text='other')
return queryset.filter(text="other")
class TwoFieldModel(models.Model):
@ -466,16 +461,20 @@ class DynamicSerializerView(generics.ListCreateAPIView):
renderer_classes = (renderers.BrowsableAPIRenderer, renderers.JSONRenderer)
def get_serializer_class(self):
if self.request.method == 'POST':
if self.request.method == "POST":
class DynamicSerializer(serializers.ModelSerializer):
class Meta:
model = TwoFieldModel
fields = ('field_b',)
fields = ("field_b",)
else:
class DynamicSerializer(serializers.ModelSerializer):
class Meta:
model = TwoFieldModel
fields = '__all__'
fields = "__all__"
return DynamicSerializer
@ -484,32 +483,29 @@ class TestFilterBackendAppliedToViews(TestCase):
"""
Create 3 BasicModel instances to filter on.
"""
items = ['foo', 'bar', 'baz']
items = ["foo", "bar", "baz"]
for item in items:
BasicModel(text=item).save()
self.objects = BasicModel.objects
self.data = [
{'id': obj.id, 'text': obj.text}
for obj in self.objects.all()
]
self.data = [{"id": obj.id, "text": obj.text} for obj in self.objects.all()]
def test_get_root_view_filters_by_name_with_filter_backend(self):
"""
GET requests to ListCreateAPIView should return filtered list.
"""
root_view = RootView.as_view(filter_backends=(InclusiveFilterBackend,))
request = factory.get('/')
request = factory.get("/")
response = root_view(request).render()
assert response.status_code == status.HTTP_200_OK
assert len(response.data) == 1
assert response.data == [{'id': 1, 'text': 'foo'}]
assert response.data == [{"id": 1, "text": "foo"}]
def test_get_root_view_filters_out_all_models_with_exclusive_filter_backend(self):
"""
GET requests to ListCreateAPIView should return empty list when all models are filtered out.
"""
root_view = RootView.as_view(filter_backends=(ExclusiveFilterBackend,))
request = factory.get('/')
request = factory.get("/")
response = root_view(request).render()
assert response.status_code == status.HTTP_200_OK
assert response.data == []
@ -519,31 +515,33 @@ class TestFilterBackendAppliedToViews(TestCase):
GET requests to RetrieveUpdateDestroyAPIView should raise 404 when model filtered out.
"""
instance_view = InstanceView.as_view(filter_backends=(ExclusiveFilterBackend,))
request = factory.get('/1')
request = factory.get("/1")
response = instance_view(request, pk=1).render()
assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.data == {'detail': 'Not found.'}
assert response.data == {"detail": "Not found."}
def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(self):
def test_get_instance_view_will_return_single_object_when_filter_does_not_exclude_it(
self
):
"""
GET requests to RetrieveUpdateDestroyAPIView should return a single object when not excluded
"""
instance_view = InstanceView.as_view(filter_backends=(InclusiveFilterBackend,))
request = factory.get('/1')
request = factory.get("/1")
response = instance_view(request, pk=1).render()
assert response.status_code == status.HTTP_200_OK
assert response.data == {'id': 1, 'text': 'foo'}
assert response.data == {"id": 1, "text": "foo"}
def test_dynamic_serializer_form_in_browsable_api(self):
"""
GET requests to ListCreateAPIView should return filtered list.
"""
view = DynamicSerializerView.as_view()
request = factory.get('/')
request = factory.get("/")
response = view(request).render()
content = response.content.decode('utf8')
assert 'field_b' in content
assert 'field_a' not in content
content = response.content.decode("utf8")
assert "field_b" in content
assert "field_a" not in content
class TestGuardedQueryset(TestCase):
@ -555,21 +553,21 @@ class TestGuardedQueryset(TestCase):
return Response(list(self.queryset))
view = QuerysetAccessError.as_view()
request = factory.get('/')
request = factory.get("/")
with pytest.raises(RuntimeError):
view(request).render()
class ApiViewsTests(TestCase):
def test_create_api_view_post(self):
class MockCreateApiView(generics.CreateAPIView):
def create(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
view = MockCreateApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'})
view.post('test request', 'test arg', test_kwarg='test')
data = ("test request", ("test arg",), {"test_kwarg": "test"})
view.post("test request", "test arg", test_kwarg="test")
assert view.called is True
assert view.call_args == data
@ -578,9 +576,10 @@ class ApiViewsTests(TestCase):
def destroy(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
view = MockDestroyApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'})
view.delete('test request', 'test arg', test_kwarg='test')
data = ("test request", ("test arg",), {"test_kwarg": "test"})
view.delete("test request", "test arg", test_kwarg="test")
assert view.called is True
assert view.call_args == data
@ -589,9 +588,10 @@ class ApiViewsTests(TestCase):
def partial_update(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
view = MockUpdateApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'})
view.patch('test request', 'test arg', test_kwarg='test')
data = ("test request", ("test arg",), {"test_kwarg": "test"})
view.patch("test request", "test arg", test_kwarg="test")
assert view.called is True
assert view.call_args == data
@ -600,9 +600,10 @@ class ApiViewsTests(TestCase):
def retrieve(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
view = MockRetrieveUpdateApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'})
view.get('test request', 'test arg', test_kwarg='test')
data = ("test request", ("test arg",), {"test_kwarg": "test"})
view.get("test request", "test arg", test_kwarg="test")
assert view.called is True
assert view.call_args == data
@ -611,9 +612,10 @@ class ApiViewsTests(TestCase):
def update(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
view = MockRetrieveUpdateApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'})
view.put('test request', 'test arg', test_kwarg='test')
data = ("test request", ("test arg",), {"test_kwarg": "test"})
view.put("test request", "test arg", test_kwarg="test")
assert view.called is True
assert view.call_args == data
@ -622,9 +624,10 @@ class ApiViewsTests(TestCase):
def partial_update(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
view = MockRetrieveUpdateApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'})
view.patch('test request', 'test arg', test_kwarg='test')
data = ("test request", ("test arg",), {"test_kwarg": "test"})
view.patch("test request", "test arg", test_kwarg="test")
assert view.called is True
assert view.call_args == data
@ -633,9 +636,10 @@ class ApiViewsTests(TestCase):
def retrieve(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
view = MockRetrieveDestroyUApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'})
view.get('test request', 'test arg', test_kwarg='test')
data = ("test request", ("test arg",), {"test_kwarg": "test"})
view.get("test request", "test arg", test_kwarg="test")
assert view.called is True
assert view.call_args == data
@ -644,9 +648,10 @@ class ApiViewsTests(TestCase):
def destroy(self, request, *args, **kwargs):
self.called = True
self.call_args = (request, args, kwargs)
view = MockRetrieveDestroyUApiView()
data = ('test request', ('test arg',), {'test_kwarg': 'test'})
view.delete('test request', 'test arg', test_kwarg='test')
data = ("test request", ("test arg",), {"test_kwarg": "test"})
view.delete("test request", "test arg", test_kwarg="test")
assert view.called is True
assert view.call_args == data
@ -654,14 +659,12 @@ class ApiViewsTests(TestCase):
class GetObjectOr404Tests(TestCase):
def setUp(self):
super(GetObjectOr404Tests, self).setUp()
self.uuid_object = UUIDForeignKeyTarget.objects.create(name='bar')
self.uuid_object = UUIDForeignKeyTarget.objects.create(name="bar")
def test_get_object_or_404_with_valid_uuid(self):
obj = generics.get_object_or_404(
UUIDForeignKeyTarget, pk=self.uuid_object.pk
)
obj = generics.get_object_or_404(UUIDForeignKeyTarget, pk=self.uuid_object.pk)
assert obj == self.uuid_object
def test_get_object_or_404_with_invalid_string_for_uuid(self):
with pytest.raises(Http404):
generics.get_object_or_404(UUIDForeignKeyTarget, pk='not-a-uuid')
generics.get_object_or_404(UUIDForeignKeyTarget, pk="not-a-uuid")

View File

@ -15,40 +15,41 @@ from rest_framework.renderers import TemplateHTMLRenderer
from rest_framework.response import Response
@api_view(('GET',))
@api_view(("GET",))
@renderer_classes((TemplateHTMLRenderer,))
def example(request):
"""
A view that can returns an HTML representation.
"""
data = {'object': 'foobar'}
return Response(data, template_name='example.html')
data = {"object": "foobar"}
return Response(data, template_name="example.html")
@api_view(('GET',))
@api_view(("GET",))
@renderer_classes((TemplateHTMLRenderer,))
def permission_denied(request):
raise PermissionDenied()
@api_view(('GET',))
@api_view(("GET",))
@renderer_classes((TemplateHTMLRenderer,))
def not_found(request):
raise Http404()
urlpatterns = [
url(r'^$', example),
url(r'^permission_denied$', permission_denied),
url(r'^not_found$', not_found),
url(r"^$", example),
url(r"^permission_denied$", permission_denied),
url(r"^not_found$", not_found),
]
@override_settings(ROOT_URLCONF='tests.test_htmlrenderer')
@override_settings(ROOT_URLCONF="tests.test_htmlrenderer")
class TemplateHTMLRendererTests(TestCase):
def setUp(self):
class MockResponse(object):
template_name = None
self.mock_response = MockResponse()
self._monkey_patch_get_template()
@ -59,13 +60,13 @@ class TemplateHTMLRendererTests(TestCase):
self.get_template = django.template.loader.get_template
def get_template(template_name, dirs=None):
if template_name == 'example.html':
return engines['django'].from_string("example: {{ object }}")
if template_name == "example.html":
return engines["django"].from_string("example: {{ object }}")
raise TemplateDoesNotExist(template_name)
def select_template(template_name_list, dirs=None, using=None):
if template_name_list == ['example.html']:
return engines['django'].from_string("example: {{ object }}")
if template_name_list == ["example.html"]:
return engines["django"].from_string("example: {{ object }}")
raise TemplateDoesNotExist(template_name_list[0])
django.template.loader.get_template = get_template
@ -78,29 +79,29 @@ class TemplateHTMLRendererTests(TestCase):
django.template.loader.get_template = self.get_template
def test_simple_html_view(self):
response = self.client.get('/')
response = self.client.get("/")
self.assertContains(response, "example: foobar")
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
self.assertEqual(response["Content-Type"], "text/html; charset=utf-8")
def test_not_found_html_view(self):
response = self.client.get('/not_found')
response = self.client.get("/not_found")
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertEqual(response.content, six.b("404 Not Found"))
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
self.assertEqual(response["Content-Type"], "text/html; charset=utf-8")
def test_permission_denied_html_view(self):
response = self.client.get('/permission_denied')
response = self.client.get("/permission_denied")
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertEqual(response.content, six.b("403 Forbidden"))
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
self.assertEqual(response["Content-Type"], "text/html; charset=utf-8")
# 2 tests below are based on order of if statements in corresponding method
# of TemplateHTMLRenderer
def test_get_template_names_returns_own_template_name(self):
renderer = TemplateHTMLRenderer()
renderer.template_name = 'test_template'
renderer.template_name = "test_template"
template_name = renderer.get_template_names(self.mock_response, view={})
assert template_name == ['test_template']
assert template_name == ["test_template"]
def test_get_template_names_returns_view_template_name(self):
renderer = TemplateHTMLRenderer()
@ -110,18 +111,16 @@ class TemplateHTMLRendererTests(TestCase):
class MockView(object):
def get_template_names(self):
return ['template from get_template_names method']
return ["template from get_template_names method"]
class MockView2(object):
template_name = 'template from template_name attribute'
template_name = "template from template_name attribute"
template_name = renderer.get_template_names(self.mock_response,
MockView())
assert template_name == ['template from get_template_names method']
template_name = renderer.get_template_names(self.mock_response, MockView())
assert template_name == ["template from get_template_names method"]
template_name = renderer.get_template_names(self.mock_response,
MockView2())
assert template_name == ['template from template_name attribute']
template_name = renderer.get_template_names(self.mock_response, MockView2())
assert template_name == ["template from template_name attribute"]
def test_get_template_names_raises_error_if_no_template_found(self):
renderer = TemplateHTMLRenderer()
@ -129,7 +128,7 @@ class TemplateHTMLRendererTests(TestCase):
renderer.get_template_names(self.mock_response, view=object())
@override_settings(ROOT_URLCONF='tests.test_htmlrenderer')
@override_settings(ROOT_URLCONF="tests.test_htmlrenderer")
class TemplateHTMLRendererExceptionTests(TestCase):
def setUp(self):
"""
@ -138,10 +137,10 @@ class TemplateHTMLRendererExceptionTests(TestCase):
self.get_template = django.template.loader.get_template
def get_template(template_name):
if template_name == '404.html':
return engines['django'].from_string("404: {{ detail }}")
if template_name == '403.html':
return engines['django'].from_string("403: {{ detail }}")
if template_name == "404.html":
return engines["django"].from_string("404: {{ detail }}")
if template_name == "403.html":
return engines["django"].from_string("403: {{ detail }}")
raise TemplateDoesNotExist(template_name)
django.template.loader.get_template = get_template
@ -153,15 +152,18 @@ class TemplateHTMLRendererExceptionTests(TestCase):
django.template.loader.get_template = self.get_template
def test_not_found_html_view_with_template(self):
response = self.client.get('/not_found')
response = self.client.get("/not_found")
self.assertEqual(response.status_code, status.HTTP_404_NOT_FOUND)
self.assertTrue(response.content in (
six.b("404: Not found"), six.b("404 Not Found")))
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
self.assertTrue(
response.content in (six.b("404: Not found"), six.b("404 Not Found"))
)
self.assertEqual(response["Content-Type"], "text/html; charset=utf-8")
def test_permission_denied_html_view_with_template(self):
response = self.client.get('/permission_denied')
response = self.client.get("/permission_denied")
self.assertEqual(response.status_code, status.HTTP_403_FORBIDDEN)
self.assertTrue(response.content in (
six.b("403: Permission denied"), six.b("403 Forbidden")))
self.assertEqual(response['Content-Type'], 'text/html; charset=utf-8')
self.assertTrue(
response.content
in (six.b("403: Permission denied"), six.b("403 Forbidden"))
)
self.assertEqual(response["Content-Type"], "text/html; charset=utf-8")

View File

@ -6,6 +6,7 @@ from rest_framework import serializers
from rest_framework.renderers import JSONRenderer
from rest_framework.templatetags.rest_framework import format_value
str_called = False
@ -15,35 +16,33 @@ class Example(models.Model):
def __str__(self):
global str_called
str_called = True
return 'An example'
return "An example"
class ExampleSerializer(serializers.HyperlinkedModelSerializer):
class Meta:
model = Example
fields = ('url', 'id', 'text')
fields = ("url", "id", "text")
def dummy_view(request):
pass
urlpatterns = [
url(r'^example/(?P<pk>[0-9]+)/$', dummy_view, name='example-detail'),
]
urlpatterns = [url(r"^example/(?P<pk>[0-9]+)/$", dummy_view, name="example-detail")]
@override_settings(ROOT_URLCONF='tests.test_lazy_hyperlinks')
@override_settings(ROOT_URLCONF="tests.test_lazy_hyperlinks")
class TestLazyHyperlinkNames(TestCase):
def setUp(self):
self.example = Example.objects.create(text='foo')
self.example = Example.objects.create(text="foo")
def test_lazy_hyperlink_names(self):
global str_called
context = {'request': None}
context = {"request": None}
serializer = ExampleSerializer(self.example, context=context)
JSONRenderer().render(serializer.data)
assert not str_called
hyperlink_string = format_value(serializer.data['url'])
assert hyperlink_string == '<a href=/example/1/>An example</a>'
hyperlink_string = format_value(serializer.data["url"])
assert hyperlink_string == "<a href=/example/1/>An example</a>"
assert str_called

View File

@ -5,19 +5,17 @@ from django.core.validators import MaxValueValidator, MinValueValidator
from django.db import models
from django.test import TestCase
from rest_framework import (
exceptions, metadata, serializers, status, versioning, views
)
from rest_framework import exceptions, metadata, serializers, status, versioning, views
from rest_framework.renderers import BrowsableAPIRenderer
from rest_framework.test import APIRequestFactory
from .models import BasicModel
request = APIRequestFactory().options('/')
request = APIRequestFactory().options("/")
class TestMetadata:
def test_determine_metadata_abstract_method_raises_proper_error(self):
with pytest.raises(NotImplementedError):
metadata.BaseMetadata().determine_metadata(None, None)
@ -26,24 +24,23 @@ class TestMetadata:
"""
OPTIONS requests to views should return a valid 200 response.
"""
class ExampleView(views.APIView):
"""Example view."""
pass
view = ExampleView.as_view()
response = view(request=request)
expected = {
'name': 'Example',
'description': 'Example view.',
'renders': [
'application/json',
'text/html'
"name": "Example",
"description": "Example view.",
"renders": ["application/json", "text/html"],
"parses": [
"application/json",
"application/x-www-form-urlencoded",
"multipart/form-data",
],
'parses': [
'application/json',
'application/x-www-form-urlencoded',
'multipart/form-data'
]
}
assert response.status_code == status.HTTP_200_OK
assert response.data == expected
@ -53,41 +50,40 @@ class TestMetadata:
OPTIONS requests to views where `metadata_class = None` should raise
a MethodNotAllowed exception, which will result in an HTTP 405 response.
"""
class ExampleView(views.APIView):
metadata_class = None
view = ExampleView.as_view()
response = view(request=request)
assert response.status_code == status.HTTP_405_METHOD_NOT_ALLOWED
assert response.data == {'detail': 'Method "OPTIONS" not allowed.'}
assert response.data == {"detail": 'Method "OPTIONS" not allowed.'}
def test_actions(self):
"""
On generic views OPTIONS should return an 'actions' key with metadata
on the fields that may be supplied to PUT and POST requests.
"""
class NestedField(serializers.Serializer):
a = serializers.IntegerField()
b = serializers.IntegerField()
class ExampleSerializer(serializers.Serializer):
choice_field = serializers.ChoiceField(['red', 'green', 'blue'])
integer_field = serializers.IntegerField(
min_value=1, max_value=1000
)
choice_field = serializers.ChoiceField(["red", "green", "blue"])
integer_field = serializers.IntegerField(min_value=1, max_value=1000)
char_field = serializers.CharField(
required=False, min_length=3, max_length=40
)
list_field = serializers.ListField(
child=serializers.ListField(
child=serializers.IntegerField()
)
child=serializers.ListField(child=serializers.IntegerField())
)
nested_field = NestedField()
uuid_field = serializers.UUIDField(label="UUID field")
class ExampleView(views.APIView):
"""Example view."""
def post(self, request):
pass
@ -97,91 +93,87 @@ class TestMetadata:
view = ExampleView.as_view()
response = view(request=request)
expected = {
'name': 'Example',
'description': 'Example view.',
'renders': [
'application/json',
'text/html'
"name": "Example",
"description": "Example view.",
"renders": ["application/json", "text/html"],
"parses": [
"application/json",
"application/x-www-form-urlencoded",
"multipart/form-data",
],
'parses': [
'application/json',
'application/x-www-form-urlencoded',
'multipart/form-data'
],
'actions': {
'POST': {
'choice_field': {
'type': 'choice',
'required': True,
'read_only': False,
'label': 'Choice field',
'choices': [
{'display_name': 'red', 'value': 'red'},
{'display_name': 'green', 'value': 'green'},
{'display_name': 'blue', 'value': 'blue'}
]
"actions": {
"POST": {
"choice_field": {
"type": "choice",
"required": True,
"read_only": False,
"label": "Choice field",
"choices": [
{"display_name": "red", "value": "red"},
{"display_name": "green", "value": "green"},
{"display_name": "blue", "value": "blue"},
],
},
'integer_field': {
'type': 'integer',
'required': True,
'read_only': False,
'label': 'Integer field',
'min_value': 1,
'max_value': 1000,
"integer_field": {
"type": "integer",
"required": True,
"read_only": False,
"label": "Integer field",
"min_value": 1,
"max_value": 1000,
},
'char_field': {
'type': 'string',
'required': False,
'read_only': False,
'label': 'Char field',
'min_length': 3,
'max_length': 40
"char_field": {
"type": "string",
"required": False,
"read_only": False,
"label": "Char field",
"min_length": 3,
"max_length": 40,
},
'list_field': {
'type': 'list',
'required': True,
'read_only': False,
'label': 'List field',
'child': {
'type': 'list',
'required': True,
'read_only': False,
'child': {
'type': 'integer',
'required': True,
'read_only': False
}
}
},
'nested_field': {
'type': 'nested object',
'required': True,
'read_only': False,
'label': 'Nested field',
'children': {
'a': {
'type': 'integer',
'required': True,
'read_only': False,
'label': 'A'
"list_field": {
"type": "list",
"required": True,
"read_only": False,
"label": "List field",
"child": {
"type": "list",
"required": True,
"read_only": False,
"child": {
"type": "integer",
"required": True,
"read_only": False,
},
'b': {
'type': 'integer',
'required': True,
'read_only': False,
'label': 'B'
}
}
},
},
'uuid_field': {
"nested_field": {
"type": "nested object",
"required": True,
"read_only": False,
"label": "Nested field",
"children": {
"a": {
"type": "integer",
"required": True,
"read_only": False,
"label": "A",
},
"b": {
"type": "integer",
"required": True,
"read_only": False,
"label": "B",
},
},
},
"uuid_field": {
"type": "string",
"required": True,
"read_only": False,
"label": "UUID field",
},
}
}
},
}
assert response.status_code == status.HTTP_200_OK
assert response.data == expected
@ -191,13 +183,15 @@ class TestMetadata:
If a user does not have global permissions on an action, then any
metadata associated with it should not be included in OPTION responses.
"""
class ExampleSerializer(serializers.Serializer):
choice_field = serializers.ChoiceField(['red', 'green', 'blue'])
choice_field = serializers.ChoiceField(["red", "green", "blue"])
integer_field = serializers.IntegerField(max_value=10)
char_field = serializers.CharField(required=False)
class ExampleView(views.APIView):
"""Example view."""
def post(self, request):
pass
@ -208,26 +202,28 @@ class TestMetadata:
return ExampleSerializer()
def check_permissions(self, request):
if request.method == 'POST':
if request.method == "POST":
raise exceptions.PermissionDenied()
view = ExampleView.as_view()
response = view(request=request)
assert response.status_code == status.HTTP_200_OK
assert list(response.data['actions']) == ['PUT']
assert list(response.data["actions"]) == ["PUT"]
def test_object_permissions(self):
"""
If a user does not have object permissions on an action, then any
metadata associated with it should not be included in OPTION responses.
"""
class ExampleSerializer(serializers.Serializer):
choice_field = serializers.ChoiceField(['red', 'green', 'blue'])
choice_field = serializers.ChoiceField(["red", "green", "blue"])
integer_field = serializers.IntegerField(max_value=10)
char_field = serializers.CharField(required=False)
class ExampleView(views.APIView):
"""Example view."""
def post(self, request):
pass
@ -238,13 +234,13 @@ class TestMetadata:
return ExampleSerializer()
def get_object(self):
if self.request.method == 'PUT':
if self.request.method == "PUT":
raise exceptions.PermissionDenied()
view = ExampleView.as_view()
response = view(request=request)
assert response.status_code == status.HTTP_200_OK
assert list(response.data['actions'].keys()) == ['POST']
assert list(response.data["actions"].keys()) == ["POST"]
def test_bug_2455_clone_request(self):
class ExampleView(views.APIView):
@ -254,7 +250,7 @@ class TestMetadata:
pass
def get_serializer(self):
assert hasattr(self.request, 'version')
assert hasattr(self.request, "version")
return serializers.Serializer()
view = ExampleView.as_view()
@ -268,7 +264,7 @@ class TestMetadata:
pass
def get_serializer(self):
assert hasattr(self.request, 'versioning_scheme')
assert hasattr(self.request, "versioning_scheme")
return serializers.Serializer()
scheme = versioning.QueryParameterVersioning
@ -279,12 +275,14 @@ class TestMetadata:
"""
HiddenField shouldn't show up in SimpleMetadata at all.
"""
class ExampleSerializer(serializers.Serializer):
integer_field = serializers.IntegerField(max_value=10)
hidden_field = serializers.HiddenField(default=1)
class ExampleView(views.APIView):
"""Example view."""
def post(self, request):
pass
@ -294,9 +292,11 @@ class TestMetadata:
view = ExampleView.as_view()
response = view(request=request)
assert response.status_code == status.HTTP_200_OK
assert set(response.data['actions']['POST'].keys()) == {'integer_field'}
assert set(response.data["actions"]["POST"].keys()) == {"integer_field"}
def test_list_serializer_metadata_returns_info_about_fields_of_child_serializer(self):
def test_list_serializer_metadata_returns_info_about_fields_of_child_serializer(
self
):
class ExampleSerializer(serializers.Serializer):
integer_field = serializers.IntegerField(max_value=10)
char_field = serializers.CharField(required=False)
@ -307,14 +307,16 @@ class TestMetadata:
options = metadata.SimpleMetadata()
child_serializer = ExampleSerializer()
list_serializer = ExampleListSerializer(child=child_serializer)
assert options.get_serializer_info(list_serializer) == options.get_serializer_info(child_serializer)
assert options.get_serializer_info(
list_serializer
) == options.get_serializer_info(child_serializer)
class TestSimpleMetadataFieldInfo(TestCase):
def test_null_boolean_field_info_type(self):
options = metadata.SimpleMetadata()
field_info = options.get_field_info(serializers.NullBooleanField())
assert field_info['type'] == 'boolean'
assert field_info["type"] == "boolean"
def test_related_field_choices(self):
options = metadata.SimpleMetadata()
@ -323,7 +325,7 @@ class TestSimpleMetadataFieldInfo(TestCase):
field_info = options.get_field_info(
serializers.RelatedField(queryset=BasicModel.objects.all())
)
assert 'choices' not in field_info
assert "choices" not in field_info
class TestModelSerializerMetadata(TestCase):
@ -333,9 +335,12 @@ class TestModelSerializerMetadata(TestCase):
on the fields that may be supplied to PUT and POST requests. It should
not fail when a read_only PrimaryKeyRelatedField is present
"""
class Parent(models.Model):
integer_field = models.IntegerField(validators=[MinValueValidator(1), MaxValueValidator(1000)])
children = models.ManyToManyField('Child')
integer_field = models.IntegerField(
validators=[MinValueValidator(1), MaxValueValidator(1000)]
)
children = models.ManyToManyField("Child")
name = models.CharField(max_length=100, blank=True, null=True)
class Child(models.Model):
@ -346,10 +351,11 @@ class TestModelSerializerMetadata(TestCase):
class Meta:
model = Parent
fields = '__all__'
fields = "__all__"
class ExampleView(views.APIView):
"""Example view."""
def post(self, request):
pass
@ -359,48 +365,45 @@ class TestModelSerializerMetadata(TestCase):
view = ExampleView.as_view()
response = view(request=request)
expected = {
'name': 'Example',
'description': 'Example view.',
'renders': [
'application/json',
'text/html'
"name": "Example",
"description": "Example view.",
"renders": ["application/json", "text/html"],
"parses": [
"application/json",
"application/x-www-form-urlencoded",
"multipart/form-data",
],
'parses': [
'application/json',
'application/x-www-form-urlencoded',
'multipart/form-data'
],
'actions': {
'POST': {
'id': {
'type': 'integer',
'required': False,
'read_only': True,
'label': 'ID'
"actions": {
"POST": {
"id": {
"type": "integer",
"required": False,
"read_only": True,
"label": "ID",
},
'children': {
'type': 'field',
'required': False,
'read_only': True,
'label': 'Children'
"children": {
"type": "field",
"required": False,
"read_only": True,
"label": "Children",
},
'integer_field': {
'type': 'integer',
'required': True,
'read_only': False,
'label': 'Integer field',
'min_value': 1,
'max_value': 1000
"integer_field": {
"type": "integer",
"required": True,
"read_only": False,
"label": "Integer field",
"min_value": 1,
"max_value": 1000,
},
"name": {
"type": "string",
"required": False,
"read_only": False,
"label": "Name",
"max_length": 100,
},
'name': {
'type': 'string',
'required': False,
'read_only': False,
'label': 'Name',
'max_length': 100
}
}
}
},
}
assert response.status_code == status.HTTP_200_OK

View File

@ -17,8 +17,8 @@ class PostView(APIView):
urlpatterns = [
url(r'^auth$', APIView.as_view(authentication_classes=(TokenAuthentication,))),
url(r'^post$', PostView.as_view()),
url(r"^auth$", APIView.as_view(authentication_classes=(TokenAuthentication,))),
url(r"^post$", PostView.as_view()),
]
@ -28,8 +28,8 @@ class RequestUserMiddleware(object):
def __call__(self, request):
response = self.get_response(request)
assert hasattr(request, 'user'), '`user` is not set on request'
assert request.user.is_authenticated, '`user` is not authenticated'
assert hasattr(request, "user"), "`user` is not set on request"
assert request.user.is_authenticated, "`user` is not authenticated"
return response
@ -49,28 +49,27 @@ class RequestPOSTMiddleware(object):
# Ensure request.POST is set as appropriate
if is_form_media_type(request.content_type):
assert request.POST == {'foo': ['bar']}
assert request.POST == {"foo": ["bar"]}
else:
assert request.POST == {}
return response
@override_settings(ROOT_URLCONF='tests.test_middleware')
@override_settings(ROOT_URLCONF="tests.test_middleware")
class TestMiddleware(APITestCase):
@override_settings(MIDDLEWARE=('tests.test_middleware.RequestUserMiddleware',))
@override_settings(MIDDLEWARE=("tests.test_middleware.RequestUserMiddleware",))
def test_middleware_can_access_user_when_processing_response(self):
user = User.objects.create_user('john', 'john@example.com', 'password')
key = 'abcd1234'
user = User.objects.create_user("john", "john@example.com", "password")
key = "abcd1234"
Token.objects.create(key=key, user=user)
self.client.get('/auth', HTTP_AUTHORIZATION='Token %s' % key)
self.client.get("/auth", HTTP_AUTHORIZATION="Token %s" % key)
@override_settings(MIDDLEWARE=('tests.test_middleware.RequestPOSTMiddleware',))
@override_settings(MIDDLEWARE=("tests.test_middleware.RequestPOSTMiddleware",))
def test_middleware_can_access_request_post_when_processing_response(self):
response = self.client.post('/post', {'foo': 'bar'})
response = self.client.post("/post", {"foo": "bar"})
assert response.status_code == 200
response = self.client.post('/post', {'foo': 'bar'}, format='json')
response = self.client.post("/post", {"foo": "bar"}, format="json")
assert response.status_code == 200

File diff suppressed because it is too large Load Diff

View File

@ -25,45 +25,41 @@ class AssociatedModel(RESTFrameworkModel):
class DerivedModelSerializer(serializers.ModelSerializer):
class Meta:
model = ChildModel
fields = '__all__'
fields = "__all__"
class AssociatedModelSerializer(serializers.ModelSerializer):
class Meta:
model = AssociatedModel
fields = '__all__'
fields = "__all__"
# Tests
class InheritedModelSerializationTests(TestCase):
def test_multitable_inherited_model_fields_as_expected(self):
"""
Assert that the parent pointer field is not included in the fields
serialized fields
"""
child = ChildModel(name1='parent name', name2='child name')
child = ChildModel(name1="parent name", name2="child name")
serializer = DerivedModelSerializer(child)
assert set(serializer.data) == {'name1', 'name2', 'id'}
assert set(serializer.data) == {"name1", "name2", "id"}
def test_onetoone_primary_key_model_fields_as_expected(self):
"""
Assert that a model with a onetoone field that is the primary key is
not treated like a derived model
"""
parent = ParentModel.objects.create(name1='parent name')
associate = AssociatedModel.objects.create(name='hello', ref=parent)
parent = ParentModel.objects.create(name1="parent name")
associate = AssociatedModel.objects.create(name="hello", ref=parent)
serializer = AssociatedModelSerializer(associate)
assert set(serializer.data) == {'name', 'ref'}
assert set(serializer.data) == {"name", "ref"}
def test_data_is_valid_without_parent_ptr(self):
"""
Assert that the pointer to the parent table is not a required field
for input data
"""
data = {
'name1': 'parent name',
'name2': 'child name',
}
data = {"name1": "parent name", "name2": "child name"}
serializer = DerivedModelSerializer(data=data)
assert serializer.is_valid() is True

View File

@ -4,32 +4,31 @@ import pytest
from django.http import Http404
from django.test import TestCase
from rest_framework.negotiation import (
BaseContentNegotiation, DefaultContentNegotiation
)
from rest_framework.negotiation import BaseContentNegotiation, DefaultContentNegotiation
from rest_framework.renderers import BaseRenderer
from rest_framework.request import Request
from rest_framework.test import APIRequestFactory
from rest_framework.utils.mediatypes import _MediaType
factory = APIRequestFactory()
class MockOpenAPIRenderer(BaseRenderer):
media_type = 'application/openapi+json;version=2.0'
format = 'swagger'
media_type = "application/openapi+json;version=2.0"
format = "swagger"
class MockJSONRenderer(BaseRenderer):
media_type = 'application/json'
media_type = "application/json"
class MockHTMLRenderer(BaseRenderer):
media_type = 'text/html'
media_type = "text/html"
class NoCharsetSpecifiedRenderer(BaseRenderer):
media_type = 'my/media'
media_type = "my/media"
class TestAcceptedMediaType(TestCase):
@ -41,54 +40,56 @@ class TestAcceptedMediaType(TestCase):
return self.negotiator.select_renderer(request, self.renderers)
def test_client_without_accept_use_renderer(self):
request = Request(factory.get('/'))
request = Request(factory.get("/"))
accepted_renderer, accepted_media_type = self.select_renderer(request)
assert accepted_media_type == 'application/json'
assert accepted_media_type == "application/json"
def test_client_underspecifies_accept_use_renderer(self):
request = Request(factory.get('/', HTTP_ACCEPT='*/*'))
request = Request(factory.get("/", HTTP_ACCEPT="*/*"))
accepted_renderer, accepted_media_type = self.select_renderer(request)
assert accepted_media_type == 'application/json'
assert accepted_media_type == "application/json"
def test_client_overspecifies_accept_use_client(self):
request = Request(factory.get('/', HTTP_ACCEPT='application/json; indent=8'))
request = Request(factory.get("/", HTTP_ACCEPT="application/json; indent=8"))
accepted_renderer, accepted_media_type = self.select_renderer(request)
assert accepted_media_type == 'application/json; indent=8'
assert accepted_media_type == "application/json; indent=8"
def test_client_specifies_parameter(self):
request = Request(factory.get('/', HTTP_ACCEPT='application/openapi+json;version=2.0'))
request = Request(
factory.get("/", HTTP_ACCEPT="application/openapi+json;version=2.0")
)
accepted_renderer, accepted_media_type = self.select_renderer(request)
assert accepted_media_type == 'application/openapi+json;version=2.0'
assert accepted_renderer.format == 'swagger'
assert accepted_media_type == "application/openapi+json;version=2.0"
assert accepted_renderer.format == "swagger"
def test_match_is_false_if_main_types_not_match(self):
mediatype = _MediaType('test_1')
anoter_mediatype = _MediaType('test_2')
mediatype = _MediaType("test_1")
anoter_mediatype = _MediaType("test_2")
assert mediatype.match(anoter_mediatype) is False
def test_mediatype_match_is_false_if_keys_not_match(self):
mediatype = _MediaType(';test_param=foo')
another_mediatype = _MediaType(';test_param=bar')
mediatype = _MediaType(";test_param=foo")
another_mediatype = _MediaType(";test_param=bar")
assert mediatype.match(another_mediatype) is False
def test_mediatype_precedence_with_wildcard_subtype(self):
mediatype = _MediaType('test/*')
mediatype = _MediaType("test/*")
assert mediatype.precedence == 1
def test_mediatype_string_representation(self):
mediatype = _MediaType('test/*; foo=bar')
assert str(mediatype) == 'test/*; foo=bar'
mediatype = _MediaType("test/*; foo=bar")
assert str(mediatype) == "test/*; foo=bar"
def test_raise_error_if_no_suitable_renderers_found(self):
class MockRenderer(object):
format = 'xml'
format = "xml"
renderers = [MockRenderer()]
with pytest.raises(Http404):
self.negotiator.filter_renderers(renderers, format='json')
self.negotiator.filter_renderers(renderers, format="json")
class BaseContentNegotiationTests(TestCase):
def setUp(self):
self.negotiator = BaseContentNegotiation()

View File

@ -3,9 +3,9 @@ from __future__ import unicode_literals
from django.db import models
from django.test import TestCase
# Models
from rest_framework import serializers
from tests.models import RESTFrameworkModel
# Models
from tests.test_multitable_inheritance import ChildModel
@ -19,25 +19,24 @@ class ChildAssociatedModel(RESTFrameworkModel):
class DerivedModelSerializer(serializers.ModelSerializer):
class Meta:
model = ChildModel
fields = ['id', 'name1', 'name2', 'childassociatedmodel']
fields = ["id", "name1", "name2", "childassociatedmodel"]
class ChildAssociatedModelSerializer(serializers.ModelSerializer):
class Meta:
model = ChildAssociatedModel
fields = ['id', 'child_name']
fields = ["id", "child_name"]
# Tests
class InheritedModelSerializationTests(TestCase):
def test_multitable_inherited_model_fields_as_expected(self):
"""
Assert that the parent pointer field is not included in the fields
serialized fields
"""
child = ChildModel(name1='parent name', name2='child name')
child = ChildModel(name1="parent name", name2="child name")
serializer = DerivedModelSerializer(child)
self.assertEqual(set(serializer.data),
{'name1', 'name2', 'id', 'childassociatedmodel'})
self.assertEqual(
set(serializer.data), {"name1", "name2", "id", "childassociatedmodel"}
)

View File

@ -8,12 +8,18 @@ from django.test import TestCase
from django.utils import six
from rest_framework import (
exceptions, filters, generics, pagination, serializers, status
exceptions,
filters,
generics,
pagination,
serializers,
status,
)
from rest_framework.pagination import PAGE_BREAK, PageLink
from rest_framework.request import Request
from rest_framework.test import APIRequestFactory
factory = APIRequestFactory()
@ -33,39 +39,39 @@ class TestPaginationIntegration:
class BasicPagination(pagination.PageNumberPagination):
page_size = 5
page_size_query_param = 'page_size'
page_size_query_param = "page_size"
max_page_size = 20
self.view = generics.ListAPIView.as_view(
serializer_class=PassThroughSerializer,
queryset=range(1, 101),
filter_backends=[EvenItemsOnly],
pagination_class=BasicPagination
pagination_class=BasicPagination,
)
def test_filtered_items_are_paginated(self):
request = factory.get('/', {'page': 2})
request = factory.get("/", {"page": 2})
response = self.view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {
'results': [12, 14, 16, 18, 20],
'previous': 'http://testserver/',
'next': 'http://testserver/?page=3',
'count': 50
"results": [12, 14, 16, 18, 20],
"previous": "http://testserver/",
"next": "http://testserver/?page=3",
"count": 50,
}
def test_setting_page_size(self):
"""
When 'paginate_by_param' is set, the client may choose a page size.
"""
request = factory.get('/', {'page_size': 10})
request = factory.get("/", {"page_size": 10})
response = self.view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {
'results': [2, 4, 6, 8, 10, 12, 14, 16, 18, 20],
'previous': None,
'next': 'http://testserver/?page=2&page_size=10',
'count': 50
"results": [2, 4, 6, 8, 10, 12, 14, 16, 18, 20],
"previous": None,
"next": "http://testserver/?page=2&page_size=10",
"count": 50,
}
def test_setting_page_size_over_maximum(self):
@ -73,70 +79,84 @@ class TestPaginationIntegration:
When page_size parameter exceeds maximum allowable,
then it should be capped to the maximum.
"""
request = factory.get('/', {'page_size': 1000})
request = factory.get("/", {"page_size": 1000})
response = self.view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {
'results': [
2, 4, 6, 8, 10, 12, 14, 16, 18, 20,
22, 24, 26, 28, 30, 32, 34, 36, 38, 40
"results": [
2,
4,
6,
8,
10,
12,
14,
16,
18,
20,
22,
24,
26,
28,
30,
32,
34,
36,
38,
40,
],
'previous': None,
'next': 'http://testserver/?page=2&page_size=1000',
'count': 50
"previous": None,
"next": "http://testserver/?page=2&page_size=1000",
"count": 50,
}
def test_setting_page_size_to_zero(self):
"""
When page_size parameter is invalid it should return to the default.
"""
request = factory.get('/', {'page_size': 0})
request = factory.get("/", {"page_size": 0})
response = self.view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {
'results': [2, 4, 6, 8, 10],
'previous': None,
'next': 'http://testserver/?page=2&page_size=0',
'count': 50
"results": [2, 4, 6, 8, 10],
"previous": None,
"next": "http://testserver/?page=2&page_size=0",
"count": 50,
}
def test_additional_query_params_are_preserved(self):
request = factory.get('/', {'page': 2, 'filter': 'even'})
request = factory.get("/", {"page": 2, "filter": "even"})
response = self.view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {
'results': [12, 14, 16, 18, 20],
'previous': 'http://testserver/?filter=even',
'next': 'http://testserver/?filter=even&page=3',
'count': 50
"results": [12, 14, 16, 18, 20],
"previous": "http://testserver/?filter=even",
"next": "http://testserver/?filter=even&page=3",
"count": 50,
}
def test_empty_query_params_are_preserved(self):
request = factory.get('/', {'page': 2, 'filter': ''})
request = factory.get("/", {"page": 2, "filter": ""})
response = self.view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == {
'results': [12, 14, 16, 18, 20],
'previous': 'http://testserver/?filter=',
'next': 'http://testserver/?filter=&page=3',
'count': 50
"results": [12, 14, 16, 18, 20],
"previous": "http://testserver/?filter=",
"next": "http://testserver/?filter=&page=3",
"count": 50,
}
def test_404_not_found_for_zero_page(self):
request = factory.get('/', {'page': '0'})
request = factory.get("/", {"page": "0"})
response = self.view(request)
assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.data == {
'detail': 'Invalid page.'
}
assert response.data == {"detail": "Invalid page."}
def test_404_not_found_for_invalid_page(self):
request = factory.get('/', {'page': 'invalid'})
request = factory.get("/", {"page": "invalid"})
response = self.view(request)
assert response.status_code == status.HTTP_404_NOT_FOUND
assert response.data == {
'detail': 'Invalid page.'
}
assert response.data == {"detail": "Invalid page."}
class TestPaginationDisabledIntegration:
@ -152,11 +172,11 @@ class TestPaginationDisabledIntegration:
self.view = generics.ListAPIView.as_view(
serializer_class=PassThroughSerializer,
queryset=range(1, 101),
pagination_class=None
pagination_class=None,
)
def test_unpaginated_list(self):
request = factory.get('/', {'page': 2})
request = factory.get("/", {"page": 2})
response = self.view(request)
assert response.status_code == status.HTTP_200_OK
assert response.data == list(range(1, 101))
@ -185,81 +205,81 @@ class TestPageNumberPagination:
return self.pagination.get_html_context()
def test_no_page_number(self):
request = Request(factory.get('/'))
request = Request(factory.get("/"))
queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset)
context = self.get_html_context()
assert queryset == [1, 2, 3, 4, 5]
assert content == {
'results': [1, 2, 3, 4, 5],
'previous': None,
'next': 'http://testserver/?page=2',
'count': 100
"results": [1, 2, 3, 4, 5],
"previous": None,
"next": "http://testserver/?page=2",
"count": 100,
}
assert context == {
'previous_url': None,
'next_url': 'http://testserver/?page=2',
'page_links': [
PageLink('http://testserver/', 1, True, False),
PageLink('http://testserver/?page=2', 2, False, False),
PageLink('http://testserver/?page=3', 3, False, False),
"previous_url": None,
"next_url": "http://testserver/?page=2",
"page_links": [
PageLink("http://testserver/", 1, True, False),
PageLink("http://testserver/?page=2", 2, False, False),
PageLink("http://testserver/?page=3", 3, False, False),
PAGE_BREAK,
PageLink('http://testserver/?page=20', 20, False, False),
]
PageLink("http://testserver/?page=20", 20, False, False),
],
}
assert self.pagination.display_page_controls
assert isinstance(self.pagination.to_html(), six.text_type)
def test_second_page(self):
request = Request(factory.get('/', {'page': 2}))
request = Request(factory.get("/", {"page": 2}))
queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset)
context = self.get_html_context()
assert queryset == [6, 7, 8, 9, 10]
assert content == {
'results': [6, 7, 8, 9, 10],
'previous': 'http://testserver/',
'next': 'http://testserver/?page=3',
'count': 100
"results": [6, 7, 8, 9, 10],
"previous": "http://testserver/",
"next": "http://testserver/?page=3",
"count": 100,
}
assert context == {
'previous_url': 'http://testserver/',
'next_url': 'http://testserver/?page=3',
'page_links': [
PageLink('http://testserver/', 1, False, False),
PageLink('http://testserver/?page=2', 2, True, False),
PageLink('http://testserver/?page=3', 3, False, False),
"previous_url": "http://testserver/",
"next_url": "http://testserver/?page=3",
"page_links": [
PageLink("http://testserver/", 1, False, False),
PageLink("http://testserver/?page=2", 2, True, False),
PageLink("http://testserver/?page=3", 3, False, False),
PAGE_BREAK,
PageLink('http://testserver/?page=20', 20, False, False),
]
PageLink("http://testserver/?page=20", 20, False, False),
],
}
def test_last_page(self):
request = Request(factory.get('/', {'page': 'last'}))
request = Request(factory.get("/", {"page": "last"}))
queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset)
context = self.get_html_context()
assert queryset == [96, 97, 98, 99, 100]
assert content == {
'results': [96, 97, 98, 99, 100],
'previous': 'http://testserver/?page=19',
'next': None,
'count': 100
"results": [96, 97, 98, 99, 100],
"previous": "http://testserver/?page=19",
"next": None,
"count": 100,
}
assert context == {
'previous_url': 'http://testserver/?page=19',
'next_url': None,
'page_links': [
PageLink('http://testserver/', 1, False, False),
"previous_url": "http://testserver/?page=19",
"next_url": None,
"page_links": [
PageLink("http://testserver/", 1, False, False),
PAGE_BREAK,
PageLink('http://testserver/?page=18', 18, False, False),
PageLink('http://testserver/?page=19', 19, False, False),
PageLink('http://testserver/?page=20', 20, True, False),
]
PageLink("http://testserver/?page=18", 18, False, False),
PageLink("http://testserver/?page=19", 19, False, False),
PageLink("http://testserver/?page=20", 20, True, False),
],
}
def test_invalid_page(self):
request = Request(factory.get('/', {'page': 'invalid'}))
request = Request(factory.get("/", {"page": "invalid"}))
with pytest.raises(exceptions.NotFound):
self.paginate_queryset(request)
@ -295,29 +315,22 @@ class TestPageNumberPaginationOverride:
return self.pagination.get_html_context()
def test_no_page_number(self):
request = Request(factory.get('/'))
request = Request(factory.get("/"))
queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset)
context = self.get_html_context()
assert queryset == [1]
assert content == {
'results': [1, ],
'previous': None,
'next': None,
'count': 1
}
assert content == {"results": [1], "previous": None, "next": None, "count": 1}
assert context == {
'previous_url': None,
'next_url': None,
'page_links': [
PageLink('http://testserver/', 1, True, False),
]
"previous_url": None,
"next_url": None,
"page_links": [PageLink("http://testserver/", 1, True, False)],
}
assert not self.pagination.display_page_controls
assert isinstance(self.pagination.to_html(), six.text_type)
def test_invalid_page(self):
request = Request(factory.get('/', {'page': 'invalid'}))
request = Request(factory.get("/", {"page": "invalid"}))
with pytest.raises(exceptions.NotFound):
self.paginate_queryset(request)
@ -346,27 +359,27 @@ class TestLimitOffset:
return self.pagination.get_html_context()
def test_no_offset(self):
request = Request(factory.get('/', {'limit': 5}))
request = Request(factory.get("/", {"limit": 5}))
queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset)
context = self.get_html_context()
assert queryset == [1, 2, 3, 4, 5]
assert content == {
'results': [1, 2, 3, 4, 5],
'previous': None,
'next': 'http://testserver/?limit=5&offset=5',
'count': 100
"results": [1, 2, 3, 4, 5],
"previous": None,
"next": "http://testserver/?limit=5&offset=5",
"count": 100,
}
assert context == {
'previous_url': None,
'next_url': 'http://testserver/?limit=5&offset=5',
'page_links': [
PageLink('http://testserver/?limit=5', 1, True, False),
PageLink('http://testserver/?limit=5&offset=5', 2, False, False),
PageLink('http://testserver/?limit=5&offset=10', 3, False, False),
"previous_url": None,
"next_url": "http://testserver/?limit=5&offset=5",
"page_links": [
PageLink("http://testserver/?limit=5", 1, True, False),
PageLink("http://testserver/?limit=5&offset=5", 2, False, False),
PageLink("http://testserver/?limit=5&offset=10", 3, False, False),
PAGE_BREAK,
PageLink('http://testserver/?limit=5&offset=95', 20, False, False),
]
PageLink("http://testserver/?limit=5&offset=95", 20, False, False),
],
}
assert self.pagination.display_page_controls
assert isinstance(self.pagination.to_html(), six.text_type)
@ -374,7 +387,8 @@ class TestLimitOffset:
def test_pagination_not_applied_if_limit_or_default_limit_not_set(self):
class MockPagination(pagination.LimitOffsetPagination):
default_limit = None
request = Request(factory.get('/'))
request = Request(factory.get("/"))
queryset = MockPagination().paginate_queryset(self.queryset, request)
assert queryset is None
@ -384,104 +398,104 @@ class TestLimitOffset:
* The first page should still be offset zero.
* We may end up displaying an extra page in the pagination control.
"""
request = Request(factory.get('/', {'limit': 5, 'offset': 1}))
request = Request(factory.get("/", {"limit": 5, "offset": 1}))
queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset)
context = self.get_html_context()
assert queryset == [2, 3, 4, 5, 6]
assert content == {
'results': [2, 3, 4, 5, 6],
'previous': 'http://testserver/?limit=5',
'next': 'http://testserver/?limit=5&offset=6',
'count': 100
"results": [2, 3, 4, 5, 6],
"previous": "http://testserver/?limit=5",
"next": "http://testserver/?limit=5&offset=6",
"count": 100,
}
assert context == {
'previous_url': 'http://testserver/?limit=5',
'next_url': 'http://testserver/?limit=5&offset=6',
'page_links': [
PageLink('http://testserver/?limit=5', 1, False, False),
PageLink('http://testserver/?limit=5&offset=1', 2, True, False),
PageLink('http://testserver/?limit=5&offset=6', 3, False, False),
"previous_url": "http://testserver/?limit=5",
"next_url": "http://testserver/?limit=5&offset=6",
"page_links": [
PageLink("http://testserver/?limit=5", 1, False, False),
PageLink("http://testserver/?limit=5&offset=1", 2, True, False),
PageLink("http://testserver/?limit=5&offset=6", 3, False, False),
PAGE_BREAK,
PageLink('http://testserver/?limit=5&offset=96', 21, False, False),
]
PageLink("http://testserver/?limit=5&offset=96", 21, False, False),
],
}
def test_first_offset(self):
request = Request(factory.get('/', {'limit': 5, 'offset': 5}))
request = Request(factory.get("/", {"limit": 5, "offset": 5}))
queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset)
context = self.get_html_context()
assert queryset == [6, 7, 8, 9, 10]
assert content == {
'results': [6, 7, 8, 9, 10],
'previous': 'http://testserver/?limit=5',
'next': 'http://testserver/?limit=5&offset=10',
'count': 100
"results": [6, 7, 8, 9, 10],
"previous": "http://testserver/?limit=5",
"next": "http://testserver/?limit=5&offset=10",
"count": 100,
}
assert context == {
'previous_url': 'http://testserver/?limit=5',
'next_url': 'http://testserver/?limit=5&offset=10',
'page_links': [
PageLink('http://testserver/?limit=5', 1, False, False),
PageLink('http://testserver/?limit=5&offset=5', 2, True, False),
PageLink('http://testserver/?limit=5&offset=10', 3, False, False),
"previous_url": "http://testserver/?limit=5",
"next_url": "http://testserver/?limit=5&offset=10",
"page_links": [
PageLink("http://testserver/?limit=5", 1, False, False),
PageLink("http://testserver/?limit=5&offset=5", 2, True, False),
PageLink("http://testserver/?limit=5&offset=10", 3, False, False),
PAGE_BREAK,
PageLink('http://testserver/?limit=5&offset=95', 20, False, False),
]
PageLink("http://testserver/?limit=5&offset=95", 20, False, False),
],
}
def test_middle_offset(self):
request = Request(factory.get('/', {'limit': 5, 'offset': 10}))
request = Request(factory.get("/", {"limit": 5, "offset": 10}))
queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset)
context = self.get_html_context()
assert queryset == [11, 12, 13, 14, 15]
assert content == {
'results': [11, 12, 13, 14, 15],
'previous': 'http://testserver/?limit=5&offset=5',
'next': 'http://testserver/?limit=5&offset=15',
'count': 100
"results": [11, 12, 13, 14, 15],
"previous": "http://testserver/?limit=5&offset=5",
"next": "http://testserver/?limit=5&offset=15",
"count": 100,
}
assert context == {
'previous_url': 'http://testserver/?limit=5&offset=5',
'next_url': 'http://testserver/?limit=5&offset=15',
'page_links': [
PageLink('http://testserver/?limit=5', 1, False, False),
PageLink('http://testserver/?limit=5&offset=5', 2, False, False),
PageLink('http://testserver/?limit=5&offset=10', 3, True, False),
PageLink('http://testserver/?limit=5&offset=15', 4, False, False),
"previous_url": "http://testserver/?limit=5&offset=5",
"next_url": "http://testserver/?limit=5&offset=15",
"page_links": [
PageLink("http://testserver/?limit=5", 1, False, False),
PageLink("http://testserver/?limit=5&offset=5", 2, False, False),
PageLink("http://testserver/?limit=5&offset=10", 3, True, False),
PageLink("http://testserver/?limit=5&offset=15", 4, False, False),
PAGE_BREAK,
PageLink('http://testserver/?limit=5&offset=95', 20, False, False),
]
PageLink("http://testserver/?limit=5&offset=95", 20, False, False),
],
}
def test_ending_offset(self):
request = Request(factory.get('/', {'limit': 5, 'offset': 95}))
request = Request(factory.get("/", {"limit": 5, "offset": 95}))
queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset)
context = self.get_html_context()
assert queryset == [96, 97, 98, 99, 100]
assert content == {
'results': [96, 97, 98, 99, 100],
'previous': 'http://testserver/?limit=5&offset=90',
'next': None,
'count': 100
"results": [96, 97, 98, 99, 100],
"previous": "http://testserver/?limit=5&offset=90",
"next": None,
"count": 100,
}
assert context == {
'previous_url': 'http://testserver/?limit=5&offset=90',
'next_url': None,
'page_links': [
PageLink('http://testserver/?limit=5', 1, False, False),
"previous_url": "http://testserver/?limit=5&offset=90",
"next_url": None,
"page_links": [
PageLink("http://testserver/?limit=5", 1, False, False),
PAGE_BREAK,
PageLink('http://testserver/?limit=5&offset=85', 18, False, False),
PageLink('http://testserver/?limit=5&offset=90', 19, False, False),
PageLink('http://testserver/?limit=5&offset=95', 20, True, False),
]
PageLink("http://testserver/?limit=5&offset=85", 18, False, False),
PageLink("http://testserver/?limit=5&offset=90", 19, False, False),
PageLink("http://testserver/?limit=5&offset=95", 20, True, False),
],
}
def test_erronous_offset(self):
request = Request(factory.get('/', {'limit': 5, 'offset': 1000}))
request = Request(factory.get("/", {"limit": 5, "offset": 1000}))
queryset = self.paginate_queryset(request)
self.get_paginated_content(queryset)
self.get_html_context()
@ -490,7 +504,7 @@ class TestLimitOffset:
"""
An invalid offset query param should be treated as 0.
"""
request = Request(factory.get('/', {'limit': 5, 'offset': 'invalid'}))
request = Request(factory.get("/", {"limit": 5, "offset": "invalid"}))
queryset = self.paginate_queryset(request)
assert queryset == [1, 2, 3, 4, 5]
@ -498,27 +512,31 @@ class TestLimitOffset:
"""
An invalid limit query param should be ignored in favor of the default.
"""
request = Request(factory.get('/', {'limit': 'invalid', 'offset': 0}))
request = Request(factory.get("/", {"limit": "invalid", "offset": 0}))
queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset)
next_limit = self.pagination.default_limit
next_offset = self.pagination.default_limit
next_url = 'http://testserver/?limit={0}&offset={1}'.format(next_limit, next_offset)
next_url = "http://testserver/?limit={0}&offset={1}".format(
next_limit, next_offset
)
assert queryset == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
assert content.get('next') == next_url
assert content.get("next") == next_url
def test_zero_limit(self):
"""
An zero limit query param should be ignored in favor of the default.
"""
request = Request(factory.get('/', {'limit': 0, 'offset': 0}))
request = Request(factory.get("/", {"limit": 0, "offset": 0}))
queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset)
next_limit = self.pagination.default_limit
next_offset = self.pagination.default_limit
next_url = 'http://testserver/?limit={0}&offset={1}'.format(next_limit, next_offset)
next_url = "http://testserver/?limit={0}&offset={1}".format(
next_limit, next_offset
)
assert queryset == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
assert content.get('next') == next_url
assert content.get("next") == next_url
def test_max_limit(self):
"""
@ -526,47 +544,46 @@ class TestLimitOffset:
requested limit is greater than the max_limit
"""
offset = 50
request = Request(factory.get('/', {'limit': '11235', 'offset': offset}))
request = Request(factory.get("/", {"limit": "11235", "offset": offset}))
queryset = self.paginate_queryset(request)
content = self.get_paginated_content(queryset)
max_limit = self.pagination.max_limit
next_offset = offset + max_limit
prev_offset = offset - max_limit
base_url = 'http://testserver/?limit={0}'.format(max_limit)
next_url = base_url + '&offset={0}'.format(next_offset)
prev_url = base_url + '&offset={0}'.format(prev_offset)
base_url = "http://testserver/?limit={0}".format(max_limit)
next_url = base_url + "&offset={0}".format(next_offset)
prev_url = base_url + "&offset={0}".format(prev_offset)
assert queryset == list(range(51, 66))
assert content.get('next') == next_url
assert content.get('previous') == prev_url
assert content.get("next") == next_url
assert content.get("previous") == prev_url
class CursorPaginationTestsMixin:
def test_invalid_cursor(self):
request = Request(factory.get('/', {'cursor': '123'}))
request = Request(factory.get("/", {"cursor": "123"}))
with pytest.raises(exceptions.NotFound):
self.pagination.paginate_queryset(self.queryset, request)
def test_use_with_ordering_filter(self):
class MockView:
filter_backends = (filters.OrderingFilter,)
ordering_fields = ['username', 'created']
ordering = 'created'
ordering_fields = ["username", "created"]
ordering = "created"
request = Request(factory.get('/', {'ordering': 'username'}))
request = Request(factory.get("/", {"ordering": "username"}))
ordering = self.pagination.get_ordering(request, [], MockView())
assert ordering == ('username',)
assert ordering == ("username",)
request = Request(factory.get('/', {'ordering': '-username'}))
request = Request(factory.get("/", {"ordering": "-username"}))
ordering = self.pagination.get_ordering(request, [], MockView())
assert ordering == ('-username',)
assert ordering == ("-username",)
request = Request(factory.get('/', {'ordering': 'invalid'}))
request = Request(factory.get("/", {"ordering": "invalid"}))
ordering = self.pagination.get_ordering(request, [], MockView())
assert ordering == ('created',)
assert ordering == ("created",)
def test_cursor_pagination(self):
(previous, current, next, previous_url, next_url) = self.get_pages('/')
(previous, current, next, previous_url, next_url) = self.get_pages("/")
assert previous is None
assert current == [1, 1, 1, 1, 1]
@ -635,7 +652,9 @@ class CursorPaginationTestsMixin:
assert isinstance(self.pagination.to_html(), six.text_type)
def test_cursor_pagination_with_page_size(self):
(previous, current, next, previous_url, next_url) = self.get_pages('/?page_size=20')
(previous, current, next, previous_url, next_url) = self.get_pages(
"/?page_size=20"
)
assert previous is None
assert current == [1, 1, 1, 1, 1, 1, 2, 3, 4, 4, 4, 4, 5, 6, 7, 7, 7, 7, 7, 7]
@ -647,7 +666,9 @@ class CursorPaginationTestsMixin:
assert next is None
def test_cursor_pagination_with_page_size_over_limit(self):
(previous, current, next, previous_url, next_url) = self.get_pages('/?page_size=30')
(previous, current, next, previous_url, next_url) = self.get_pages(
"/?page_size=30"
)
assert previous is None
assert current == [1, 1, 1, 1, 1, 1, 2, 3, 4, 4, 4, 4, 5, 6, 7, 7, 7, 7, 7, 7]
@ -659,7 +680,9 @@ class CursorPaginationTestsMixin:
assert next is None
def test_cursor_pagination_with_page_size_zero(self):
(previous, current, next, previous_url, next_url) = self.get_pages('/?page_size=0')
(previous, current, next, previous_url, next_url) = self.get_pages(
"/?page_size=0"
)
assert previous is None
assert current == [1, 1, 1, 1, 1]
@ -726,7 +749,9 @@ class CursorPaginationTestsMixin:
assert next == [1, 2, 3, 4, 4]
def test_cursor_pagination_with_page_size_negative(self):
(previous, current, next, previous_url, next_url) = self.get_pages('/?page_size=-5')
(previous, current, next, previous_url, next_url) = self.get_pages(
"/?page_size=-5"
)
assert previous is None
assert current == [1, 1, 1, 1, 1]
@ -809,19 +834,17 @@ class TestCursorPagination(CursorPaginationTestsMixin):
def filter(self, created__gt=None, created__lt=None):
if created__gt is not None:
return MockQuerySet([
item for item in self.items
if item.created > int(created__gt)
])
return MockQuerySet(
[item for item in self.items if item.created > int(created__gt)]
)
assert created__lt is not None
return MockQuerySet([
item for item in self.items
if item.created < int(created__lt)
])
return MockQuerySet(
[item for item in self.items if item.created < int(created__lt)]
)
def order_by(self, *ordering):
if ordering[0].startswith('-'):
if ordering[0].startswith("-"):
return MockQuerySet(list(reversed(self.items)))
return self
@ -830,21 +853,48 @@ class TestCursorPagination(CursorPaginationTestsMixin):
class ExamplePagination(pagination.CursorPagination):
page_size = 5
page_size_query_param = 'page_size'
page_size_query_param = "page_size"
max_page_size = 20
ordering = 'created'
ordering = "created"
self.pagination = ExamplePagination()
self.queryset = MockQuerySet([
MockObject(idx) for idx in [
1, 1, 1, 1, 1,
1, 2, 3, 4, 4,
4, 4, 5, 6, 7,
7, 7, 7, 7, 7,
7, 7, 7, 8, 9,
9, 9, 9, 9, 9
self.queryset = MockQuerySet(
[
MockObject(idx)
for idx in [
1,
1,
1,
1,
1,
1,
2,
3,
4,
4,
4,
4,
5,
6,
7,
7,
7,
7,
7,
7,
7,
7,
7,
8,
9,
9,
9,
9,
9,
9,
]
]
])
)
def get_pages(self, url):
"""
@ -888,18 +938,42 @@ class TestCursorPaginationWithValueQueryset(CursorPaginationTestsMixin, TestCase
def setUp(self):
class ExamplePagination(pagination.CursorPagination):
page_size = 5
page_size_query_param = 'page_size'
page_size_query_param = "page_size"
max_page_size = 20
ordering = 'created'
ordering = "created"
self.pagination = ExamplePagination()
data = [
1, 1, 1, 1, 1,
1, 2, 3, 4, 4,
4, 4, 5, 6, 7,
7, 7, 7, 7, 7,
7, 7, 7, 8, 9,
9, 9, 9, 9, 9
1,
1,
1,
1,
1,
1,
2,
3,
4,
4,
4,
4,
5,
6,
7,
7,
7,
7,
7,
7,
7,
7,
7,
8,
9,
9,
9,
9,
9,
9,
]
for idx in data:
CursorPaginationModel.objects.create(created=idx)
@ -914,7 +988,7 @@ class TestCursorPaginationWithValueQueryset(CursorPaginationTestsMixin, TestCase
"""
request = Request(factory.get(url))
queryset = self.pagination.paginate_queryset(self.queryset, request)
current = [item['created'] for item in queryset]
current = [item["created"] for item in queryset]
next_url = self.pagination.get_next_link()
previous_url = self.pagination.get_previous_link()
@ -922,14 +996,14 @@ class TestCursorPaginationWithValueQueryset(CursorPaginationTestsMixin, TestCase
if next_url is not None:
request = Request(factory.get(next_url))
queryset = self.pagination.paginate_queryset(self.queryset, request)
next = [item['created'] for item in queryset]
next = [item["created"] for item in queryset]
else:
next = None
if previous_url is not None:
request = Request(factory.get(previous_url))
queryset = self.pagination.paginate_queryset(self.queryset, request)
previous = [item['created'] for item in queryset]
previous = [item["created"] for item in queryset]
else:
previous = None

View File

@ -7,7 +7,8 @@ import math
import pytest
from django import forms
from django.core.files.uploadhandler import (
MemoryFileUploadHandler, TemporaryFileUploadHandler
MemoryFileUploadHandler,
TemporaryFileUploadHandler,
)
from django.http.request import RawPostDataException
from django.test import TestCase
@ -15,7 +16,10 @@ from django.utils.six import StringIO
from rest_framework.exceptions import ParseError
from rest_framework.parsers import (
FileUploadParser, FormParser, JSONParser, MultiPartParser
FileUploadParser,
FormParser,
JSONParser,
MultiPartParser,
)
from rest_framework.request import Request
from rest_framework.test import APIRequestFactory
@ -44,16 +48,15 @@ class TestFileUploadParser(TestCase):
def setUp(self):
class MockRequest(object):
pass
self.stream = io.BytesIO(
"Test text file".encode('utf-8')
)
self.stream = io.BytesIO("Test text file".encode("utf-8"))
request = MockRequest()
request.upload_handlers = (MemoryFileUploadHandler(),)
request.META = {
'HTTP_CONTENT_DISPOSITION': 'Content-Disposition: inline; filename=file.txt',
'HTTP_CONTENT_LENGTH': 14,
"HTTP_CONTENT_DISPOSITION": "Content-Disposition: inline; filename=file.txt",
"HTTP_CONTENT_LENGTH": 14,
}
self.parser_context = {'request': request, 'kwargs': {}}
self.parser_context = {"request": request, "kwargs": {}}
def test_parse(self):
"""
@ -62,7 +65,7 @@ class TestFileUploadParser(TestCase):
parser = FileUploadParser()
self.stream.seek(0)
data_and_files = parser.parse(self.stream, None, self.parser_context)
file_obj = data_and_files.files['file']
file_obj = data_and_files.files["file"]
assert file_obj.size == 14
def test_parse_missing_filename(self):
@ -71,10 +74,13 @@ class TestFileUploadParser(TestCase):
"""
parser = FileUploadParser()
self.stream.seek(0)
self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = ''
self.parser_context["request"].META["HTTP_CONTENT_DISPOSITION"] = ""
with pytest.raises(ParseError) as excinfo:
parser.parse(self.stream, None, self.parser_context)
assert str(excinfo.value) == 'Missing filename. Request should include a Content-Disposition header with a filename parameter.'
assert (
str(excinfo.value)
== "Missing filename. Request should include a Content-Disposition header with a filename parameter."
)
def test_parse_missing_filename_multiple_upload_handlers(self):
"""
@ -83,14 +89,17 @@ class TestFileUploadParser(TestCase):
"""
parser = FileUploadParser()
self.stream.seek(0)
self.parser_context['request'].upload_handlers = (
self.parser_context["request"].upload_handlers = (
MemoryFileUploadHandler(),
MemoryFileUploadHandler(),
MemoryFileUploadHandler()
)
self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = ''
self.parser_context["request"].META["HTTP_CONTENT_DISPOSITION"] = ""
with pytest.raises(ParseError) as excinfo:
parser.parse(self.stream, None, self.parser_context)
assert str(excinfo.value) == 'Missing filename. Request should include a Content-Disposition header with a filename parameter.'
assert (
str(excinfo.value)
== "Missing filename. Request should include a Content-Disposition header with a filename parameter."
)
def test_parse_missing_filename_large_file(self):
"""
@ -98,54 +107,59 @@ class TestFileUploadParser(TestCase):
"""
parser = FileUploadParser()
self.stream.seek(0)
self.parser_context['request'].upload_handlers = (
TemporaryFileUploadHandler(),
)
self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = ''
self.parser_context["request"].upload_handlers = (TemporaryFileUploadHandler(),)
self.parser_context["request"].META["HTTP_CONTENT_DISPOSITION"] = ""
with pytest.raises(ParseError) as excinfo:
parser.parse(self.stream, None, self.parser_context)
assert str(excinfo.value) == 'Missing filename. Request should include a Content-Disposition header with a filename parameter.'
assert (
str(excinfo.value)
== "Missing filename. Request should include a Content-Disposition header with a filename parameter."
)
def test_get_filename(self):
parser = FileUploadParser()
filename = parser.get_filename(self.stream, None, self.parser_context)
assert filename == 'file.txt'
assert filename == "file.txt"
def test_get_encoded_filename(self):
parser = FileUploadParser()
self.__replace_content_disposition('inline; filename*=utf-8\'\'ÀĥƦ.txt')
self.__replace_content_disposition("inline; filename*=utf-8''ÀĥƦ.txt")
filename = parser.get_filename(self.stream, None, self.parser_context)
assert filename == 'ÀĥƦ.txt'
assert filename == "ÀĥƦ.txt"
self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'\'ÀĥƦ.txt')
self.__replace_content_disposition(
"inline; filename=fallback.txt; filename*=utf-8''ÀĥƦ.txt"
)
filename = parser.get_filename(self.stream, None, self.parser_context)
assert filename == 'ÀĥƦ.txt'
assert filename == "ÀĥƦ.txt"
self.__replace_content_disposition('inline; filename=fallback.txt; filename*=utf-8\'en-us\'ÀĥƦ.txt')
self.__replace_content_disposition(
"inline; filename=fallback.txt; filename*=utf-8'en-us'ÀĥƦ.txt"
)
filename = parser.get_filename(self.stream, None, self.parser_context)
assert filename == 'ÀĥƦ.txt'
assert filename == "ÀĥƦ.txt"
def __replace_content_disposition(self, disposition):
self.parser_context['request'].META['HTTP_CONTENT_DISPOSITION'] = disposition
self.parser_context["request"].META["HTTP_CONTENT_DISPOSITION"] = disposition
class TestJSONParser(TestCase):
def bytes(self, value):
return io.BytesIO(value.encode('utf-8'))
return io.BytesIO(value.encode("utf-8"))
def test_float_strictness(self):
parser = JSONParser()
# Default to strict
for value in ['Infinity', '-Infinity', 'NaN']:
for value in ["Infinity", "-Infinity", "NaN"]:
with pytest.raises(ParseError):
parser.parse(self.bytes(value))
parser.strict = False
assert parser.parse(self.bytes('Infinity')) == float('inf')
assert parser.parse(self.bytes('-Infinity')) == float('-inf')
assert math.isnan(parser.parse(self.bytes('NaN')))
assert parser.parse(self.bytes("Infinity")) == float("inf")
assert parser.parse(self.bytes("-Infinity")) == float("-inf")
assert math.isnan(parser.parse(self.bytes("NaN")))
class TestPOSTAccessed(TestCase):
@ -153,28 +167,28 @@ class TestPOSTAccessed(TestCase):
self.factory = APIRequestFactory()
def test_post_accessed_in_post_method(self):
django_request = self.factory.post('/', {'foo': 'bar'})
django_request = self.factory.post("/", {"foo": "bar"})
request = Request(django_request, parsers=[FormParser(), MultiPartParser()])
django_request.POST
assert request.POST == {'foo': ['bar']}
assert request.data == {'foo': ['bar']}
assert request.POST == {"foo": ["bar"]}
assert request.data == {"foo": ["bar"]}
def test_post_accessed_in_post_method_with_json_parser(self):
django_request = self.factory.post('/', {'foo': 'bar'})
django_request = self.factory.post("/", {"foo": "bar"})
request = Request(django_request, parsers=[JSONParser()])
django_request.POST
assert request.POST == {}
assert request.data == {}
def test_post_accessed_in_put_method(self):
django_request = self.factory.put('/', {'foo': 'bar'})
django_request = self.factory.put("/", {"foo": "bar"})
request = Request(django_request, parsers=[FormParser(), MultiPartParser()])
django_request.POST
assert request.POST == {'foo': ['bar']}
assert request.data == {'foo': ['bar']}
assert request.POST == {"foo": ["bar"]}
assert request.data == {"foo": ["bar"]}
def test_request_read_before_parsing(self):
django_request = self.factory.put('/', {'foo': 'bar'})
django_request = self.factory.put("/", {"foo": "bar"})
request = Request(django_request, parsers=[FormParser(), MultiPartParser()])
django_request.read()
with pytest.raises(RawPostDataException):

Some files were not shown because too many files have changed in this diff Show More